オープンソース化されたVision Transformer(ViT)を紹介!

オープンソース化されたVision Transformer(ViT)を紹介!

はじめに
 Transformerのみの画像処理モデルとして、話題を集めたVisionTransformerですが、今回研究元であるGoogleからGoogleBlogによって発表が行われ、コードとモデルがオープンソース化されました。
 それに伴い、Vision Transformerの内容を再度確認しながら、コードとモデルについて紹介します。Vision Transformerについての詳細は、以下の記事をご確認ください。『画像認識の革新モデル!脱CNNを果たしたVision Transformerを徹底解説!』

GoogleBlog
Transformers for Image Recognition at Scale
https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html

論文
An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale
https://arxiv.org/abs/2010.11929

GitHub
https://github.com/google-research/vision_transformer/

 

VisionTransformerの概略

 VisionTransformerはCNNを使わずに、Transfromerのみを利用したモデルとして注目を集めました。論文内では、計算リソースを最先端モデルから約4分の1まで削減するだけでなく、今後のスケーラビリティの可能性が提示されました。
 NLPで利用される場合のTransformerは、単語のシーケンスを入力として受け取り、分類、翻訳などを行います。Vision Transformerのように画像処理で利用される場合は、画像をシーケンスとして受け取り、分類などのタスクを行っています。
 Vision Transformerではまず画像を正方形のパッチのグリッドに分割します。各パッチは、パッチ内のすべてのピクセルのチャネルを連結し、それを目的の入力次元に線形に投影することによって、単一のベクトルに平坦化されます。Transformerは入力要素の構造に依存しないため、各パッチに学習可能な位置の埋め込みを追加します。これにより、モデルは画像の構造について学習できます。
 Vision Transformerは画像内のパッチの相対位置、または画像が2D構造を持っていることさえ学習する前は知りませんが、トレーニングデータからそのような関連情報を学習し、位置埋め込みに構造情報をエンコードすることができるようになります。
 (Vision Transformerの詳細については、こちらの記事『画像認識の革新モデル!脱CNNを果たしたVision Transformerを徹底解説!』をご覧ください。)

実際の利用について

利用環境

 Python>=3.6が要求されています。また、Jaxをインストールする場合には、以下のコマンドを叩いてください。

pip install -r vit_jax/requirements.txt

Google Colabでの利用について

 Google Colab上で利用することができます。Colabでは、GitHub上のリポジトリからコードをロードし、デフォルトで8コアのTPUで実行されるようになっています。なお、すべてのデータをエフェメラルVMに保存するのと同じように実行できます。また、個人のGoogleドライブにログインして、コードとデータを保持するように利用することもできます。

https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb

利用できるVision Transformerモデルについて

 ImageNet-21kで事前学習されたものが公開されています。(なお、ImageNet2012でFine-tuningされたものも公開されています。)

・Vision Transformer 
 ViT-B/16、ViT-B/32、 ViT-L/16、 ViT-L/32、 ViT-H/14
・R50+ViT-B/16ハイブリッドモデル(Resnet50バックボーンの上にViT-B / 16)
 ImageNet21kで事前学習すると、Finetuningコストが半分未満で、L / 16モデルと同様のパフォーマンスを実現します。

GCP上からダウンロードすることで利用します。(https://console.cloud.google.com/storage/vit_models/
ImageNet21Kで事前学習されたViT-B/16を利用する場合、以下のコマンドで取得することができます。

wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz

Fine-tuningの方法について

 以下のコマンドを入力することで、利用することが可能です。

python3 -m vit_jax.train --name ViT-B_16-cifar10_`date +%F_%H%M%S` --model ViT-B_16 --logdir /tmp/vit_logs --dataset cifar10

現在、コードはCIFAR-10およびCIFAR-100データセットを自動的にダウンロードするよになっているようです。なお、他のパブリックデータセットまたはカスタムデータセットは、tensorflowデータセットライブラリを使用して簡単に統合できるとしています。なお、追加されたデータセットに関するいくつかのパラメーターを指定するには、vit_jax /input_pipeline.pyも更新する必要があることに注意が必要です。