はじめに
Normalization-Free Net (NFNet) は、Deep Mind社から発表された最新の画像認識モデルで、今非常に注目されているモデルです。EfficientNetと同水準を維持しながらも、訓練が8.7倍速く、SOTA(State-of-the-art、歴代最高水準のこと)を更新しました。画像認識において常識となっているバッチ正規化を取り除いたモデルです。
今回はこのモデルを、非専門家の方でも概要がつかめるよう、わかりやすく解説していきたいと思います。
背景
バッチ正規化ありのResNet(レズネット、非常に有名なモデルの一つ)により、極めて深いDeep learningが使えるようになり、精度が大きく向上しました。
バッチ正規化は、大きな学習率を使用可能にし、またregularization(過学習を回避する)効果も有しています。ただし下記の欠点もあります。
- 計算が大変で、メモリ不足に注意しないといけない
- 学習時とテスト時でモデルの挙動が変わってしまう
- 1つ1つのデータ間の独立性が壊れる
- 小さいバッチサイズでは機能しない
バッチ正規化を取り除く
これらの欠点を克服するため、バッチ正規化を取り除くことを試みます。しかし、バッチ正規化の以下の機能は保持するようにしないといけません。
- 残差結合の影響を小さく保つ
- Mean-Shiftを取り除く
- Regularization効果
- 大きなバッチによる訓練を可能にする
それぞれ順番に対処法を見ていきます。
まず1のために、下記の式が用いられました。
次に2のために、下記のScaled Weight Standardizationという方法で、Convolutionの重みを調整しました。※Convolution(畳み込み)は画像認識の特徴を抽出する手法。
さらに3のために、DropoutとStochastic Depthという手法を使用。
最後に4のために、下記Gradient Clippingを導入。
さらにλを微調整しなくて済むように、これをAdaptive Gradient Clipping(AGC)に改良しました。
要するに、重みWに比べて勾配Gが大きすぎるときは、Gの大きさをクリップする(小さくする)ということのようです。正確にはこれを行単位で行う、Unit-wiseというアプローチが採用されています。
NFNet
完成したNFNetは、Stem ⇒ Stage 1 ⇒ Stage 2 ⇒ Stage 3 ⇒ Stage 4 ⇒ Classifierという構成になっています。上述の手法たちが取り入れられていますが、ClassifierにAGCを入れると成績が悪化するようなので、入れません。
Stemは4つのConvolutionから成り、それぞれストライドを2,1,1,2、出力チャネル数を16,32,64,128に設定します。
Stage1-4はそれぞれ、N, 2N, 6N, 3N個のブロックから構成されます。たとえばNFNet-F0では、それぞれ1,2,6,3個のブロックから構成されます。
ブロックには基本的に下図右のNon-transition blockを使い、Stage2-4の最初のブロックだけは下図左のTransition blockを使用します。
※WS-Conv:Scaled Weight Standardization(上述)
※S+E:Squeeze & Exciteという層
※C/G=128:Grouped convolutionにおいて、グループごとのチャネル数が128に固定されている
※Scaled Activation:出力の分散を1に保つ
モデルの訓練
ImageNetデータに対して、Nesterov’s Momentum =0.9、エポック=360、バッチサイズ=4096、AGC閾値=0.01で訓練。
初めの5エポックで学習率を0から徐々に増やし、その後Cosine annealingという手法で減衰させました。
MixUp、RandAugment、CutMixによる強力な画像のAugumentationが成績向上に貢献しました。※Augumentationは画像を変形させることで、データを水増しする手法です
またさらに成績を向上させるため、最新のOptimizerであるSharpness-Aware Minimization (SAM)も追加されました。※Optimizerにより、訓練を通じてモデルは、問題を解くために最適化される。
結果
上図は訓練時間と正答率の関係を見たものです。NFNet-F1(赤線)は8.7倍短い訓練時間で、EfficientNet-B7(青線)と同程度の成績を叩き出しました。さらにSAMを追加することで86.5%の正答率を達成し、ImageNetにおけるSOTA(歴代最高水準)を更新しました。
追加の実験として、転移学習にもNFNetが使用されました。巨大な訓練データにおいては、バッチ正規化がむしろ有害かもしれないと考えられたからです。結果として、転移学習においても良好な成績が得られました。※転移学習とはあらかじめ訓練済のモデルを使うことで精度を向上させる技術です。
まとめ
常識に反してバッチ正規化を取り除くことで、画像分類における学習スピードを大幅に短縮することが可能になり、今後も注目され発展していくモデルだと思います。
今回は非専門家の方でも概要がつかめるように、なるべくわかりやすく解説することを心がけました。この記事が、少しでも読んでいただいた方の参考になれば嬉しいです。
コメント