「PyTorchでもっと効率的に機械学習を行いたい」
そのように思う方は、PyTorch Lightningを使いましょう。
PyTorch Lightningを使えば、今よりは機械学習の研究・検証がスムーズに進むはずです。
本記事の内容
- PyTorch Lightningとは?
- PyTorch Lightningのシステム要件
- PyTorch Lightningのインストール
- PyTorch Lightningの動作確認
それでは、上記に沿って解説していきます。
PyTorch Lightningとは?
PyTorch Lightningとは、高性能AI研究のための軽量なPyTorchラッパーです。
PyTorchで機械学習を行うことは、コーディングを行うことでもあります。
同時に、コーディング以外にもシステム的な知識・スキルが必要とされます。
範囲を広げて言うと、 機械学習ではエンジニアリングをも行っていると言えます。
本来、AIの研究者はエンジニアリングではなく、研究自体にもっと労力を割くべきです。
PyTorch Lightningを使えば、研究者はエンジニアリングに割く労力を削減できます。
つまり、研究者は研究にもっと時間や労力を投入できるということです。
PyTorch Lightningの公式では、次のように表現されています。
LightningはPyTorchのコードを分離し、サイエンスとエンジニアリングを切り離します。
PyTorch Lightningは、研究者だけが対象ではないと思います。
エンジニア以外であれば、マーケッター・アナリストなども対象になるはずです。
もちろん、エンジニアが利用しても良いでしょう。
便利なモノなら、誰が使ってもOKなので。
以上、PyTorch Lightningについても説明でした。
次は、PyTorch Lightningのシステム要件を確認します。
PyTorch Lightningのシステム要件
現時点(2021年8月末)でのPyTorch Lightningの最新バージョンは、1.4.4となります。
この最新バージョンは、2021年8月25日にリリースされています。
結構、高い頻度で更新されているようです。
PyTorch Lightningのシステム要件は、以下の点を確認しましょう。
- OS
- Python
- PyTorch
それぞれを下記で説明します。
OS
サポートOSに関しては、以下を含むクロスプラットフォーム対応です。
- Windows
- macOS
- Linux
そもそも、PyTorchがクロスプラットフォーム対応となります。
Python
サポート対象となるPythonのバージョンは、以下。
- Python 3.6
- Python 3.7
- Python 3.8
- Python 3.9
下記は、Pythonの公式開発サイクルです。
バージョン | リリース日 | サポート期限 |
3.6 | 2016年12月23日 | 2021年12月 |
3.7 | 2018年6月27日 | 2023年6月 |
3.8 | 2019年10月14日 | 2024年10月 |
3.9 | 2020年10月5日 | 2025年10月 |
PyTorch Lightningは、上記の開発サイクルに適切に従っています。
これに対して、PyTorchはPython 3.xというアバウトなサポート状況です。
バージョン云々で気になるなら、最新バージョンにアップグレードしておきましょう。
アップグレードは簡単にできます。
PyTorch
PyTorch Lightningは、その名称にあるようにPyTorchありきです。
PyTorchがインストールされていることが、大前提となります。
そして、PyTorchのバージョンはPyTorch 1.6以降が必要です。
なお、PyTorchの最新版はPyTorch 1.9.0となります。
PyTorchのインストールについては、次の記事で解説しています。
まとめ
最も重要なことは、PyTorch 1.6以降ということでしょうね。
そのことは、requirements.txtに記載されています。
とにかく、最新のPyTorchをインストールしておけば問題はないでしょう。
以上、PyTorch Lightningのシステム要件について説明しました。
次は、PyTorch Lightningのインストールを行います。
PyTorch Lightningのインストール
まずは、現状のインストール済みパッケージを確認しておきます。
この時点では、GPU版のPyTorchをインストールしているだけです。
>pip list Package Version ----------------- ------------ numpy 1.21.2 Pillow 8.3.1 pip 21.2.4 setuptools 57.4.0 torch 1.9.0+cu111 torchaudio 0.9.0 torchvision 0.10.0+cu111 typing-extensions 3.10.0.0
次にするべきことは、pipとsetuptoolsの更新です。
pipコマンドを使う場合、常に以下のコマンドを実行しておきましょう。
python -m pip install --upgrade pip setuptools
では、PyTorch Lightningのインストールです。
PyTorch Lightningのインストールは、以下のコマンドとなります。
pip install pytorch-lightning
インストールは、少し時間はかかります。
では、どんなパッケージがインストールされたのかを確認しましょう。
>pip list Package Version ----------------------- ------------ absl-py 0.13.0 aiohttp 3.7.4.post0 async-timeout 3.0.1 attrs 21.2.0 cachetools 4.2.2 certifi 2021.5.30 chardet 4.0.0 charset-normalizer 2.0.4 colorama 0.4.4 fsspec 2021.7.0 future 0.18.2 google-auth 1.35.0 google-auth-oauthlib 0.4.5 grpcio 1.39.0 idna 3.2 Markdown 3.3.4 multidict 5.1.0 numpy 1.21.2 oauthlib 3.1.1 packaging 21.0 Pillow 8.3.1 pip 21.2.4 protobuf 3.17.3 pyasn1 0.4.8 pyasn1-modules 0.2.8 pyDeprecate 0.3.1 pyparsing 2.4.7 pytorch-lightning 1.4.4 PyYAML 5.4.1 requests 2.26.0 requests-oauthlib 1.3.0 rsa 4.7.2 setuptools 57.4.0 six 1.16.0 tensorboard 2.6.0 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.0 torch 1.9.0+cu111 torchaudio 0.9.0 torchmetrics 0.5.0 torchvision 0.10.0+cu111 tqdm 4.62.2 typing-extensions 3.10.0.0 urllib3 1.26.6 Werkzeug 2.0.1 wheel 0.37.0 yarl 1.6.3
多くのパッケージが、インストールされました。
PyTorch Lightningが、それだけ多くのパッケージに依存しているということです。
この結果を見ると、PyTorch LightningはPythonの仮想環境にインストールすべきでしょうね。
依存関係がこれだけ多いと、インストール時にトラブルが起きやすいです。
Windowsなら、IDEにPyCharmを使えば簡単に仮想環境を利用できます。
また、以下のようにコマンドでも簡単に仮想環境を利用できます。
以上、PyTorch Lightningのインストールを説明しました。
最後は、PyTorch Lightningの動作確認を行います。
PyTorch Lightningの動作確認
公式のサンプルコードを試しましょう。
内容は、有名なMNISTの学習でとなります。
import os import torch from torch import nn import torch.nn.functional as F from torchvision.datasets import MNIST from torch.utils.data import DataLoader, random_split from torchvision import transforms import pytorch_lightning as pl class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__() self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) def forward(self, x): # in lightning, forward defines the prediction/inference actions embedding = self.encoder(x) return embedding def training_step(self, batch, batch_idx): # training_step defines the train loop. It is independent of forward x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) self.log("train_loss", loss) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) train, val = random_split(dataset, [55000, 5000]) autoencoder = LitAutoEncoder() trainer = pl.Trainer() trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
「クラスを用意して、コーディングが簡潔になった」
プログラマーの視点で言うと、上記の意見になります。
だから、特別に何か別の処理ができるようになったとかではありません。
機械学習のコーディングが、簡潔に記述できるようになったということに過ぎません。
また、上記を実行すると、かなりの時間がかかります。
学習が開始されれば、PyTorch Lightningの動作としては問題ありません。
なお、動作確認をする程度であれば、60000件のデータ数は減らした方がよいでしょう。
GPUで処理しても、3時間で終わりませんでした。
以上、PyTorch Lightningの動作確認の説明でした。