【Python】PyTorch Lightningのインストール

【Python】PyTorch Lightningのインストール データ分析

「PyTorchでもっと効率的に機械学習を行いたい」

そのように思う方は、PyTorch Lightningを使いましょう。
PyTorch Lightningを使えば、今よりは機械学習の研究・検証がスムーズに進むはずです。

本記事の内容

  • PyTorch Lightningとは?
  • PyTorch Lightningのシステム要件
  • PyTorch Lightningのインストール
  • PyTorch Lightningの動作確認

それでは、上記に沿って解説していきます。

PyTorch Lightningとは?

PyTorch Lightningとは、高性能AI研究のための軽量なPyTorchラッパーです。
PyTorchで機械学習を行うことは、コーディングを行うことでもあります。

同時に、コーディング以外にもシステム的な知識・スキルが必要とされます。
範囲を広げて言うと、 機械学習ではエンジニアリングをも行っていると言えます。

本来、AIの研究者はエンジニアリングではなく、研究自体にもっと労力を割くべきです。
PyTorch Lightningを使えば、研究者はエンジニアリングに割く労力を削減できます。

つまり、研究者は研究にもっと時間や労力を投入できるということです。
PyTorch Lightningの公式では、次のように表現されています。

LightningはPyTorchのコードを分離し、サイエンスとエンジニアリングを切り離します。

PyTorch Lightningは、研究者だけが対象ではないと思います。
エンジニア以外であれば、マーケッター・アナリストなども対象になるはずです。

もちろん、エンジニアが利用しても良いでしょう。
便利なモノなら、誰が使ってもOKなので。

以上、PyTorch Lightningについても説明でした。
次は、PyTorch Lightningのシステム要件を確認します。

PyTorch Lightningのシステム要件

現時点(2021年8月末)でのPyTorch Lightningの最新バージョンは、1.4.4となります。
この最新バージョンは、2021年8月25日にリリースされています。
結構、高い頻度で更新されているようです。

PyTorch Lightningのシステム要件は、以下の点を確認しましょう。

  • OS
  • Python
  • PyTorch

それぞれを下記で説明します。

OS

サポートOSに関しては、以下を含むクロスプラットフォーム対応です。

  • Windows
  • macOS
  • Linux

そもそも、PyTorchがクロスプラットフォーム対応となります。

Python

サポート対象となるPythonのバージョンは、以下。

  • Python 3.6
  • Python 3.7
  • Python 3.8
  • Python 3.9

下記は、Pythonの公式開発サイクルです。

バージョンリリース日サポート期限
3.62016年12月23日2021年12月
3.72018年6月27日2023年6月
3.82019年10月14日2024年10月
3.92020年10月5日2025年10月

PyTorch Lightningは、上記の開発サイクルに適切に従っています。
これに対して、PyTorchはPython 3.xというアバウトなサポート状況です。

バージョン云々で気になるなら、最新バージョンにアップグレードしておきましょう。
アップグレードは簡単にできます。

PyTorch

PyTorch Lightningは、その名称にあるようにPyTorchありきです。
PyTorchがインストールされていることが、大前提となります。

そして、PyTorchのバージョンはPyTorch 1.6以降が必要です。
なお、PyTorchの最新版はPyTorch 1.9.0となります。

PyTorchのインストールについては、次の記事で解説しています。

まとめ

最も重要なことは、PyTorch 1.6以降ということでしょうね。
そのことは、requirements.txtに記載されています。

とにかく、最新のPyTorchをインストールしておけば問題はないでしょう。

以上、PyTorch Lightningのシステム要件について説明しました。
次は、PyTorch Lightningのインストールを行います。

PyTorch Lightningのインストール

まずは、現状のインストール済みパッケージを確認しておきます。
この時点では、GPU版のPyTorchをインストールしているだけです。

>pip list 
Package           Version 
----------------- ------------ 
numpy             1.21.2 
Pillow            8.3.1 
pip               21.2.4 
setuptools        57.4.0 
torch             1.9.0+cu111 
torchaudio        0.9.0 
torchvision       0.10.0+cu111 
typing-extensions 3.10.0.0

次にするべきことは、pipとsetuptoolsの更新です。
pipコマンドを使う場合、常に以下のコマンドを実行しておきましょう。

python -m pip install --upgrade pip setuptools

では、PyTorch Lightningのインストールです。
PyTorch Lightningのインストールは、以下のコマンドとなります。

pip install pytorch-lightning

インストールは、少し時間はかかります。
では、どんなパッケージがインストールされたのかを確認しましょう。

>pip list
Package                 Version
----------------------- ------------
absl-py                 0.13.0
aiohttp                 3.7.4.post0
async-timeout           3.0.1
attrs                   21.2.0
cachetools              4.2.2
certifi                 2021.5.30
chardet                 4.0.0
charset-normalizer      2.0.4
colorama                0.4.4
fsspec                  2021.7.0
future                  0.18.2
google-auth             1.35.0
google-auth-oauthlib    0.4.5
grpcio                  1.39.0
idna                    3.2
Markdown                3.3.4
multidict               5.1.0
numpy                   1.21.2
oauthlib                3.1.1
packaging               21.0
Pillow                  8.3.1
pip                     21.2.4
protobuf                3.17.3
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pyDeprecate             0.3.1
pyparsing               2.4.7
pytorch-lightning       1.4.4
PyYAML                  5.4.1
requests                2.26.0
requests-oauthlib       1.3.0
rsa                     4.7.2
setuptools              57.4.0
six                     1.16.0
tensorboard             2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.0
torch                   1.9.0+cu111
torchaudio              0.9.0
torchmetrics            0.5.0
torchvision             0.10.0+cu111
tqdm                    4.62.2
typing-extensions       3.10.0.0
urllib3                 1.26.6
Werkzeug                2.0.1
wheel                   0.37.0
yarl                    1.6.3

多くのパッケージが、インストールされました。
PyTorch Lightningが、それだけ多くのパッケージに依存しているということです。

この結果を見ると、PyTorch LightningはPythonの仮想環境にインストールすべきでしょうね。
依存関係がこれだけ多いと、インストール時にトラブルが起きやすいです。

Windowsなら、IDEにPyCharmを使えば簡単に仮想環境を利用できます。

また、以下のようにコマンドでも簡単に仮想環境を利用できます。

以上、PyTorch Lightningのインストールを説明しました。
最後は、PyTorch Lightningの動作確認を行います。

PyTorch Lightningの動作確認

公式のサンプルコードを試しましょう。
内容は、有名なMNISTの学習でとなります。

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))

「クラスを用意して、コーディングが簡潔になった」
プログラマーの視点で言うと、上記の意見になります。

だから、特別に何か別の処理ができるようになったとかではありません。
機械学習のコーディングが、簡潔に記述できるようになったということに過ぎません。

また、上記を実行すると、かなりの時間がかかります。
学習が開始されれば、PyTorch Lightningの動作としては問題ありません。

なお、動作確認をする程度であれば、60000件のデータ数は減らした方がよいでしょう。
GPUで処理しても、3時間で終わりませんでした。

以上、PyTorch Lightningの動作確認の説明でした。

タイトルとURLをコピーしました