【LINEbot】Pythonで女優識別botを作ろう Part4【モデル構築編】

こんにちは!

 

「女優識別botを作ろう」の第4回目となる今回は、いよいよモデル構築に関して解説していきます。

 

確認ですが、今回は、画像収集・モデル入力用データに加工済みという前提で進めていきます。

 

 

 今回の内容

  • モデル構築
  • 転移学習

 

 

第3回目の記事はこちら。

>>【LINEbot】Pythonで女優識別botを作ろう Part3【入力用データ作成編】

 

では、さっそく見ていきましょう。

 

 

今回使うモデル紹介

今回、自作したデータセットを学習させるために、転移学習を行います。

 

転移学習を知らない人のために簡単に説明すると、転移学習は、ある分野ですでに学習済みのモデルを別の分野で活用することによって、学習時間を大幅に減らしつつ高い精度のモデルを作れる、というものです。

 

詳しくは、転移学習とは【数式無し】をご覧ください。

 

自分で層を何層も重ねてモデルを作ってもいいのですが、これだと精度がうまく上がらない可能性が高いので、今回は転移学習を行います。

 

調べるといろんなモデルが出てくるのですが、特に有名なものを挙げると、VGGやResNetなどがあります。

 

今回は、VGG16というモデルで転移学習を行います。

 

では、次の章でコードを紹介します。

 

スポンサーリンク

 

モデル学習用コード

モデル学習用のコードはこちらになります。

 

 

かなり長々とした印象を受ける方もいるかもしれませんが、大部分が PyTorch Lightning の基本的なコードですので、少しずつ理解してみてください。

 

記事の最後の方に、PyTorchの使い方を一通りマスターできるサイトを紹介しているので、そちらで先に勉強してみることもおすすめです。

 

主な流れは PyTorch Lightning で共通しているため、ここではモデル学習部分のみに焦点を当てて解説します。

 

まず、最初の13~25行目までの部分で、前回の記事で作成したデータセットを入力用データとして読み込んでいます。

 

 

PATH  には前回作成した、 model.pt のパスを指定します。

そのすぐ下の部分でデータを読み込み、21~25行目のところでデータを学習用と検証用に分割しています。

 

では、次に実際にモデルを定義している部分を見てみます。

モデル定義は Net  クラスの中の72~76行目で行っています。

 

 

73行目の self.model  に、転移学習に使用するモデルを入力します。

ちなみに、転移学習に使用できるモデルは、 torchvision.models  に格納されているので、興味があればいろいろ試してみてください。

 

76行目では、転移学習に使用するモデルに追加する層を定義しています。

これは、元のモデルが1000クラス分類用のモデルであるため、出力結果を女優の人数分にするために追加しています。

 

78~83行目では、パラメータの一部を再学習するように設定しています。

 

 

というのも、転移学習では、別の分野で学習したモデルを使うわけですが、その際に学習済みのパラメータを使わなければほとんど恩恵を受けません。

 

画像認識モデルでは、はじめの層付近では大抵の場合、どんな分野であっても共通した特徴を抽出することが知られています。

そのため、なるべく学習済みパラメータをそのまま使いつつ、後半部分だけを再学習させることが一般的です。

 

それほど、膨大な量のパラメータがあるため、どこまで再学習させるかを明示的に示す必要があるというわけです。

デフォルトでは、すべてのパラメータが再学習する設定になっているので、上で紹介したコードで後半部分のみを再学習させるように、 param.requires_grad = True として設定し直しています。

 

学習し終わった後、114行目でパラメータを保存しています。

 

 

ちなみに、今回紹介したコードは、CPUで実行するとめちゃくちゃ時間がかかるので、必ずGoogle colabratoryを使用してください。

使用する際はランタイムをGPUに切り替えることを忘れずに!

 

スポンサーリンク

 

おわりに

今回までで、画像収集、データセット作り、モデル構築までの作業がすべて終わりました。

 

次回から、このモデルをLINEbotに組み込むところを解説するのですが、ぶっちゃけそこはちょっとした応用になるだけなので、メイン部分は今回で終了といった感じになりますね。

 

ただ、学習モデルをどのように予測用として扱うのかを解説していないので、気になる方は次回も見ていただけるとありがたいです。

 

では、今回はここまでとします。

お疲れさまでした。

 

追記:第5回を更新しました。

>>【LINEbot】Pythonで女優識別botを作ろう Part5【Bot作成編】

 

PyTorchの扱い方が一通り学べるサイトはこちら