Laboro

エンジニアコラム

広い技術領域をカバーする当社の機械学習エンジニアが、アカデミア発のAI&機械学習技術を紹介&解説いたします。

ディープラーニングを軽量化「宝くじ仮説」について

2020.2.25
株式会社Laboro.AI リード機械学習エンジニア 内木 賢吾

概 要

ディープラーニングモデルの開発・導入時、とくに結果の出力にリアルタイム性をもとめられるようなビジネス現場への実装においては、モデルの精度を保ちつついかに高速処理を実現するかが重要なポイントになってきます。こうした観点から、前回のコラムではディープニューラルネットワークの「モデル圧縮」をテーマにpruning(枝刈り)、quantize(量子化)、distillation(蒸留)の3つのモデル軽量化の手法をご紹介しました。

今回はさらに知見を深めていただくことを目的に、2019年に話題になった「宝くじ仮説 (The Lottery Ticket Hypothesis)」というディープニューラルネットワークのモデル圧縮に関する仮説を紹介します。

「宝くじ仮説」を理解する上では、モデル圧縮技術に関するベース知識があった方が内容が掴みやすいと思いますので、前回のコラム「ディープラーニングを軽量化する『モデル圧縮』3手法」を、まずはお読みいただけると良いと思います。

目 次

「宝くじ仮説」とは
当たりくじの発見方法
当たりくじの検証
関連研究について
まとめ

「宝くじ仮説」とは

「宝くじ仮説」とは、MITのJonathan Frankle氏とMichael Carbin氏による論文『 The Lottery Ticket Hypothesis: Training Pruned Neural Networks 』にて提唱された、ディープニューラルネットワークのモデル圧縮に関する仮説のことです。その論文はICLR 2019のBest Paper Awardに選ばれています。

宝くじ仮説について簡単に説明しますと「学習済みのニューラルネットワーク(original network)には、単独で同じ程度学習させても、元のネットワークの性能に匹敵する部分ネットワーク(subnetwork)が存在する」という仮説のことです。つまり、宝くじ仮説は下図のようなoriginal networkと同等の精度を持つsubnetworkが存在することを意味しています。

図1 宝くじ仮説のイメージ

この部分ネットワーク(subnetwork)を「当たりくじ(winning tickets)」と例えており、膨大な組合せ数がある部分ネットワークから「精度の高い部分ネットワーク(subnetwork)」を見つけることを、宝くじに当選することになぞらえて「宝くじ仮説」と提唱しているようです。
(このコラムでも「当たりくじ」という表現を使用していきます。)

論文内の検証において、元のネットワークの10%〜20%のサイズ、かつ元のネットワークよりも精度の高い当たりくじが発見できた、と報告されています。では次に、当たりくじの探し方について説明します。

当たりくじの発見方法

論文では、1回だけ当たりくじを探す方法(one-shot pruning)と、繰り返し処理で当たりくじを探す方法(iterative pruning)が記載されています。iterative pruningの繰り返し回数を1回にすればone-shot pruningと同じであること、iterative pruningの方が得られるネットワークの性能がより良いことが示されているといった理由から、本コラムではiterative pruningについて説明します。論文には以下のように記されています。

Strategy 1: Iterative pruning with resetting.
1.Randomly initialize a neural network \(f(x;m\odot \theta)\) where \(\theta = \theta_0\) and \(m = 1^{|θ|}\)is a mask.
2.Train the network for \(j\) iterations, reaching parameters \(m \odot \theta_j\).
3.Prune \(s\)% of the parameters, creating an updated mask \(m’\) where \(P_{m’} = (P_m − s)\)%.
4.Reset the weights of the remaining portion of the network to their values in \(\theta_0\). That is, let \(\theta = \theta_0\).
5.Let \(m = m’\) and repeat steps 2 through 4 until a sufficiently pruned network has been obtained.

当たりくじを探す手順を補足を加えながら説明します。図2に当たりくじを探す手順を示します。論文では、手順4で得られたネットワークのテスト精度と元のネットワークのテスト精度を比較することで、得られたネットワークが当たりくじかを確認しています。

図2 宝くじ仮説:当たりくじを探す手順

1.初期ネットワークを作成する

パラメータをランダムに初期化したネットワーク\(f(x;m\odot \theta_0)\)を作成し、マスク値を全て1に設定します。このとき、\(x\)は入力変数を表し、\(\theta\)は、ネットワークの重みの集合を表します。 \(\theta_0\) は重みの初期値を示します。\(m\)はマスクと呼ばれるネットワークの枝刈りに用いるパラメータを表し、0か1の2値をとります。\(m \odot \theta\)は、マスク\(m\)と重み\(\theta\)の要素ごとの積(アダマール積)を表します。つまり、マスク値0のときは重みとの積が0になるため、その接続がなくなることを意味します。マスク値が1のときは、重みの値がそのまま残ります。

2.モデルの初期値を保存し、モデルを訓練する

重み\(\theta_0\)は後の作業で使用するため保存し、モデル\(f(x;m\odot \theta)\)を\(j\)回訓練します。このとき、重みが更新されて\(\theta = \theta_j\)となります。

3.枝刈り用のマスクを作成し、枝刈りを実施する。

学習終了後、枝刈り率\(s%\)に応じたマスク\(m’\)を作成し、\(m’ \odot \theta_j\)を計算します。マスク\(m’\)の値は、重みの絶対値が低い順番から下位\(s\)%に対してマスク値0を割り当て、残りに1を割り当てます。このとき、マスク\(m\)のときに枝刈り後にネットワークに残っている重みの割合を\(P_m\)とした場合、本処理で残る割合\(P_{m’}\)は\((P_m – s)%\)になります。つまり、枝刈り率20%のときは、最初の枝刈りで80%の重みが残り(マスク値=1の割合が20%)、2回目の枝刈りで80*(100 – 20) = 64%の重みが残ることになります。

4.枝刈り実施後のパラメータを更新する

枝刈りで残ったパラメータは\(\theta_j\)なので、重みの値を初期値\(\theta_0\)に更新します。

5.マスクを保存し、手順2-4を繰り返す

次のイテレーションで使用するためにマスク\(m’\)を保存し、手順2から4を繰り返すことで削除するパラメータを増やします。論文には明記されていませんが、枝刈り率が閾値以上となるか、モデルの精度が閾値を下回るまで繰り返すことがよいと思われます。なお、繰り返し処理を行わない場合はone-shot pruningとなります。


重みの絶対値が小さい箇所から削除する考え方は、単純なpruningの手法と同じです。宝くじ仮説における重要なポイントは、「ネットワークの初期の重みを使用する」ことです。つまり、宝くじ仮説ではネットワークの構造だけではなく、ネットワークの初期の重みとの両方を重要視しており、この点が一般的なpruningとの違いになります。また、重みの絶対値が小さいものから削除した後に初期の重みと置き換えるだけという比較的シンプルなアプローチで性能低下を防ぐことができるモデル圧縮技術であると考えることもできます。

論文では、上記の手順で当たりくじが見つかるか、またその当たりくじの結果を評価しています。次にその一部を抜粋して紹介したいと思います。

当たりくじの検証

宝くじ仮説の検証には、次のデータセットとモデルを使用します。

1 手書き数字の画像データセット「MNIST」とモデルLeNet
2 一般物体認識のベンチマークの画像データセット「CIFAR-10」とモデルCNN(2層、4層、6層)
3 「CIFAR-10」とモデルResNet-18、VGG-19

1回の枝刈り率\(s\)は、データセットやモデルによって異なりますが、全結合層に対して20%、畳み込み層に対して10%程度の値を取っています。詳細は元論文をご参照ください。

全体を通して次のような結果が観測されています。

●より少ない学習回数で元のネットワークと同等以上のテスト精度を持ち、元のネットワークの10%~20%程度までサイズを削減できた当たりくじを発見した。

●重みの初期値に置き換えず、重みをランダムに更新してネットワークを再学習すると、学習の進みが遅くなり、かつ、テスト精度も元のモデルと同等までしか到達しなかった。

これらの結果より、元のネットワークと同等以上の性能を持つ当たりくじが存在し、当たりくじを発見するために元のネットワークの重みを初期値として用いることが重要であることがわかります。

個々の実験から、次のような結果が観測されています。

●one-shot pruningとiterative pruningは、どちらも重みをランダムに初期化する場合と比べて高い性能を示した。また、iterative pruningの方が性能のよい当たりくじを発見できた。ただし、パラメータを変更させながら訓練を行うため、計算コストはiterative pruningの方が高い。(論文2章 MNISR × LeNetの結果より)

●訓練時の精度とテスト精度の差が小さいことから、当たりくじは元のネットワークよりも汎化性能が高い可能性がある。(3章 CIFAR10 × CNN(2層、4層、6層)の結果より)

●dropoutと組み合わせることで、テスト精度が全体的に向上する。(3章 CIFAR10 × CNN(2層、4層、6層)の結果より)

●学習率が大きい場合では当たりくじの発見が難しいが、学習率を小さくする、または、学習率を徐々に大きくしながら学習するwarm upを導入することで当たりくじを発見できた。(4章 CIFAR10 × VGG-19、ResNet-18の結果より)

以上の複数のデータセットに対して様々なモデルの結果を確認しても当たりくじを発見できたことから、宝くじ仮説の信用性は高いと言えるでしょう。また、ハイパーパラメータの調整は必要であるものの論文で提案されたiterative pruningの手法によって当たりくじを発見できること、dropoutやwarm upといった深層学習で用いられる手法と併用できることもおわかりいただけたかと思います。

関連研究について

詳しい説明は割愛しますが、宝くじ仮説はモデル圧縮の研究として注目を集めており、NeurIPS 2019でも関連研究が発表されました。

1つ目に『 One ticket to win them all: generalizing lottery ticket
initializations across datasets and optimizers
』です。この論文は、宝くじ仮説における当たりくじの計算コストを削減するために、より大きなデータセットに対して作成した当たりくじを再利用できないかを検証した論文です。その結果は、大きなデータセットを使用して作成した当たりくじを他のデータセットに対して再利用できる可能性が高いことを示唆しています。

2つ目に『 Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask 』です。この論文では、当たりくじの性能が高い理由について考察し、重みを0にすることが重要な理由などを実験を通して検証しています。

まとめ

今回のコラムでは、「宝くじ仮説 (The Lottery Ticket Hypothesis) 」について説明してきました。宝くじ仮説は、計算コストは高いものの比較的シンプルな方法でモデルを圧縮することができる点から、有効な技術だと言えます。発表のあった年に別の国際会議で関連研究が報告されていることからも、さらに研究が加速し、ビジネスシーンでの活用機会も増えてくるのではないかと予想されます。

参考文献

The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks
One ticket to win them all: generalizing lottery ticket initializations across datasets and optimizers
Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask

コラム執筆者

機械学習エンジニア 内木 賢吾

名古屋大学大学院 工学研究科 博士前期課程修了。在学中に質問応答システムを研究。卒業後、ハードウェアエンジニアとして実験機器の開発やデジタル/アナログ回路設計ならびにハーネス設計を経験。その後、車載向け音声認識システムのフロントエンド処理に関する研究開発を担当。2019年7月よりLaboro.AIに参画。

その他の執筆コラム

声や音を聞き分ける、『音源分離』とは
ディープラーニングを軽量化する「モデル圧縮」3手法

カスタムAIの導入に関する
ご相談はこちらから

お名前(必須)
御社名(必須)
部署名(必須)
役職名(任意)
メールアドレス(必須)
電話番号(任意)
件名(必須)
本文(必須)

(プライバシーポリシーはこちら