はじめに
Transformerのみの画像処理モデルとして、話題を集めたVisionTransformerですが、今回研究元であるGoogleからGoogleBlogによって発表が行われ、コードとモデルがオープンソース化されました。
それに伴い、Vision Transformerの内容を再度確認しながら、コードとモデルについて紹介します。Vision Transformerについての詳細は、以下の記事をご確認ください。『画像認識の革新モデル!脱CNNを果たしたVision Transformerを徹底解説!』
Transformers for Image Recognition at Scale
https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
論文
An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale
https://arxiv.org/abs/2010.11929
GitHub
https://github.com/google-research/vision_transformer/
VisionTransformerの概略
VisionTransformerはCNNを使わずに、Transfromerのみを利用したモデルとして注目を集めました。論文内では、計算リソースを最先端モデルから約4分の1まで削減するだけでなく、今後のスケーラビリティの可能性が提示されました。
NLPで利用される場合のTransformerは、単語のシーケンスを入力として受け取り、分類、翻訳などを行います。Vision Transformerのように画像処理で利用される場合は、画像をシーケンスとして受け取り、分類などのタスクを行っています。
Vision Transformerではまず画像を正方形のパッチのグリッドに分割します。各パッチは、パッチ内のすべてのピクセルのチャネルを連結し、それを目的の入力次元に線形に投影することによって、単一のベクトルに平坦化されます。Transformerは入力要素の構造に依存しないため、各パッチに学習可能な位置の埋め込みを追加します。これにより、モデルは画像の構造について学習できます。
Vision Transformerは画像内のパッチの相対位置、または画像が2D構造を持っていることさえ学習する前は知りませんが、トレーニングデータからそのような関連情報を学習し、位置埋め込みに構造情報をエンコードすることができるようになります。
(Vision Transformerの詳細については、こちらの記事『画像認識の革新モデル!脱CNNを果たしたVision Transformerを徹底解説!』をご覧ください。)
実際の利用について
利用環境
Python>=3.6が要求されています。また、Jaxをインストールする場合には、以下のコマンドを叩いてください。
pip install -r vit_jax/requirements.txt
Google Colabでの利用について
Google Colab上で利用することができます。Colabでは、GitHub上のリポジトリからコードをロードし、デフォルトで8コアのTPUで実行されるようになっています。なお、すべてのデータをエフェメラルVMに保存するのと同じように実行できます。また、個人のGoogleドライブにログインして、コードとデータを保持するように利用することもできます。
利用できるVision Transformerモデルについて
ImageNet-21kで事前学習されたものが公開されています。(なお、ImageNet2012でFine-tuningされたものも公開されています。)
・Vision Transformer
ViT-B/16、ViT-B/32、 ViT-L/16、 ViT-L/32、 ViT-H/14
・R50+ViT-B/16ハイブリッドモデル(Resnet50バックボーンの上にViT-B / 16)
ImageNet21kで事前学習すると、Finetuningコストが半分未満で、L / 16モデルと同様のパフォーマンスを実現します。
GCP上からダウンロードすることで利用します。(https://console.cloud.google.com/storage/vit_models/)
ImageNet21Kで事前学習されたViT-B/16を利用する場合、以下のコマンドで取得することができます。
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
Fine-tuningの方法について
以下のコマンドを入力することで、利用することが可能です。
python3 -m vit_jax.train --name ViT-B_16-cifar10_`date +%F_%H%M%S` --model ViT-B_16 --logdir /tmp/vit_logs --dataset cifar10
現在、コードはCIFAR-10およびCIFAR-100データセットを自動的にダウンロードするよになっているようです。なお、他のパブリックデータセットまたはカスタムデータセットは、tensorflowデータセットライブラリを使用して簡単に統合できるとしています。なお、追加されたデータセットに関するいくつかのパラメーターを指定するには、vit_jax /input_pipeline.pyも更新する必要があることに注意が必要です。