PyTorchで単純な1層の線形モデルの実装と自動微分についての基礎的なメモ
はじめに
本稿はPyTorchの使い方を学んだ際の個人的な備忘録である。今回は「線形層1枚で順伝播・損失計算・逆伝播の流れを手作業で確認する」ことに主眼を置き、PyTorchを少しだけ学ぶ。
このメモの目的
本稿は、ニューラルネットワークを実用的に組むための網羅的な解説ではなく、単純な実装を通して処理の全体像を自分の手で確かめる事を目的とする。そのためニューラルネットワークを最小の構成で組んでおり、活性化関数、多層化、 torch.nn.Module 、 torch.optim などの実用的な部分には触れない。
注意
本稿ではPythonやPyTorchのインストール手順は省略する。インストール方法は公式ドキュメントを参照すると良い。
PyTorchのソースからのビルドとインストールについては、備忘録として以下の記事を書いた。
内容はビルド時の作業ログだが、ビルド環境は環境依存が大きい。ビルドの作業は公式ドキュメントを元に行い、この記事は参考程度の情報と考えた方が良い。あくまで公式ドキュメントに沿って、綺麗な環境で実施することを勧める。
PyTorchとテンソルをざっくり触る
PyTorch はニューラルネットワークを構築するためのライブラリであり、動的計算グラフ、テンソル操作、自動微分などを提供する。
テンソル
僕は数学に詳しいわけじゃないから、テンソルについては大雑把に把握していく。
テンソルとは、形状が (d1, d2, ..., dn) のn次元配列を指す。0階のテンソルはスカラーの事、1階のテンソルはベクトルの事、2階のテンソルは行列の事であり、3階以上はたいていテンソルと呼ばれる。数学的に正しい表現ではないと思うが大雑把なイメージとしては、値がN回リストにネストされて格納されているかという感じだ。ただし各次元の形状が揃っていなければテンソルとは呼べず、不揃いの場合は不規則データや可変長データとして扱うしかなくなる。
| テンソル的な表現 | 一般的な呼び名 | イメージ |
|---|---|---|
| 0階のテンソル | スカラー | 値そのもの |
| 1階のテンソル | ベクトル | スカラーのリスト |
| 2階のテンソル | 行列 | スカラーのリストのリスト |
| 3階のテンソル | テンソル | スカラーのリストのリストのリスト |
PyTorchでは基本的にデータをテンソルという形で取り扱う。 torch.tensor() を使用してテンソルを作る事ができる。
Pythonでの表現(直感的なイメージ)
まずはスカラー、ベクトル、行列、テンソル(3階のテンソル)のイメージを整理する。以下のコードの実際の値は、Pythonではint型の数値やリストだ。括弧書きの表現はあくまで概念的なイメージとして捉えるだけのものだ。
1[1, 2, 3][[1, 2, 3],
[4, 5, 6]][[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]]恐らくこれら組み込み型を使えば様々な実装できるが、それだと大抵の場合は計算の遅さや計算の複雑さが問題になるだろう。それを解消するために、これらの計算に特化した型や処理を提供しているライブラリがある。それがNumpyであるし、PyTorchもテンソル用の型を提供している。
PyTorchでの表現
先程の各値をPyTorchではどのように表現するのかを確認する。PyTorchは、Pythonのモジュール名としては torch として作られているので、まずは torch をインポートする。
import torch
torch.tensor を使って、次のように各値を作る事ができる。
torch.tensor(1)torch.tensor([1, 2, 3])torch.tensor(
[[1, 2, 3],
[4, 5, 6]])torch.tensor(
[[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]])
これらは torch.Tensor 型の値が作られる。PyTorchは、この型の値を高速に計算できるように実装されている。これらは全て torch.Tensor 型になる。括弧内はあくまでそのように考える事ができるというものだ。
今回の課題設定
本稿では分かりやすさを優先するために、「入力 [0.1, 0.2, 0.3] を線形変換し、どんなときも出力 [1.0] を返す」 という極端に単純な課題を設定する。この課題を解く1 層のモデルを PyTorch で手動実装し、勾配計算によって係数を更新する流れをなぞる。
期待する入出力
- 入力:
torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) - 期待値:
torch.tensor([1.0], dtype=torch.float32)
学びたいポイント
本稿で特に以下のポイントを順を追って実行し、その動きを確認する。
- テンソルの概念と
torch.tensor()の基本的な使い方 torch.nn.functional.linear()とtorch.nn.functional.mse_loss()を用いた順伝播と損失計算requires_grad=Trueとloss.backward()による自動微分の仕組みtorch.no_grad()と勾配リセットで更新ループを自前で書くとどうなるか
単純なニューラルネットワークを作る
ここからは前節で定めた課題を torch.nn.functional を中心に最小構成で実装する。なお多層化や活性化関数はゴールから外れるため扱わない。
重みとバイアス
ニューラルネットワークでは重みやバイアスといった数値の集まりを扱う。この重みやバイアスは、PyTorchでは torch.Tensor 型の値として扱う。
それでは重みとバイアスを適当に定義してみる。通常、重みやバイアスに含まれる値は浮動小数点を使う。 torch.tensor() は dtype を指定する事で値の型を指定できる。ここでは torch.float32 を指定した。
weight = torch.tensor([[.4, .5, .6]], dtype=torch.float32)
bias = torch.tensor([.1], dtype=torch.float32)順伝播
この値を使ってデータ変換をしてみよう。ここでは例として適当な入力を作り、その入力を線形変換する。
まず入力となるデータを作る。
input_data = torch.tensor([.1, .2, .3], dtype=torch.float32)
print(input_data)tensor([0.1000, 0.2000, 0.3000])
次にこのデータの変換処理を行う。PyTorchには変換処理を行う関数も多数用意されているが、ここでは線形変換を適用する関数 torch.nn.functional.linear() を使用する。この関数は引数として入力(input)、重み(weight)、バイアス(bias、切片とも言うらしい)を受け取り、入力に対して線形変換を適用し、その値を返す。
import torch.nn.functional
output_data = torch.nn.functional.linear(input_data, weight, bias)
print(output_data)tensor([0.4200])
tensor([0.1000, 0.2000, 0.3000]) という値が、重みとバイアスを使って線形変換が適用され、 tensor([0.4200]) という値に変換された。このように重みとバイアスを使って入力を変換する処理は 順伝播 と呼ばれる。
逆伝播と損失の計算
この変換の結果は、重みとバイアスによって変化する。いわゆるモデルの訓練とは、この重みとバイアスの値を調整する事だ。しかし、この例では weight も bias も更新される仕組みがない。そこで、重みとパラメータを調整するための仕組みを組み込んでいく事にする。
先程の例では、ただ数値を変換し結果を確認した。この変換結果は期待通りだったのだろうか。重みやバイアスを調整するためには、どのような変換結果が期待通りなのかを決める必要がある。そして実際に変換された値と期待する値との差が小さくなるように、重みやバイアスを調整する事になる。
何が起きているのかを分かりやすくするために、ここではどんな入力も tensor([1.000]) に変換されるような、重みとバイアスの調整を例として考えてみる。
損失関数による損失値の計算
ここからは期待通りの値は期待値、実際の変換結果を予測値と呼ぶ事にしよう。つまり今回の例では期待値は tensor([1.000]) で固定となる。
重みとバイアスを調整するためには、期待値と予測値の異なり具合を計算し、それを反映する事になる。この異なり具合の計算で使われる関数は 損失関数 と呼ばれ、その計算方法は様々ある。PyTorchでは torch.nn.functional に 〜_loss という名前で、よく使われる損失関数が実装されている。ここでは、平均二乗誤差を実装している torch.nn.functional.mse_loss() を使用する。平均二乗誤差は回帰問題で損失関数としてよく使われる。期待値と出力の差を二乗し、その平均を計算する。
expect_data = torch.tensor([1.000], dtype=torch.float32)
loss = torch.nn.functional.mse_loss(output_data, expect_data)
print(loss)tensor(0.3364)
損失として tensor(0.3364) を計算できた。
逆伝播
次はこの損失値を、重みとバイアスに反映する必要がある。そのためには、損失値を計算する時に使用したテンソルを逆方向に辿り、勾配を計算していく必要がある。
PyTorchはそれを簡単に行う機能を提供している。 torch.Torch の計算は、計算グラフという独自に参照のようなものを組み込んでおり、そのテンソルの計算がどのテンソル由来のものかを辿れる機能がある。この機能を有効にするには torch.tensor() を呼び出す際に requires_grad=True を指定する。
example_tensor= torch.tensor([1.], dtype=torch.float32, requires_grad=True)
この requires_grad=True を指定したテンソルは、計算を構成する各演算とテンソルをノードとして、依存関係をエッジで表現し保持している。そのため、結果を元に依存関係を辿る事ができる。
この機能を利用して損失値を元に逆方向にグラフを辿り、各ノードで勾配計算を行って周れるようになっている。これは逆伝播(バックプロパゲーション)と呼ばれる。
最初の例で重みやバイアスを定義していたが requires_grad=True は指定していなかった。これでは計算グラフを辿る事ができないため、重みやバイアスの定義をしなおし、線形変換を行い、損失関数による損失値の計算までをやり直す。
weight = torch.tensor([[.4, .5, .6]], dtype=torch.float32, requires_grad=True)
bias = torch.tensor([.1], dtype=torch.float32, requires_grad=True)
input_data = torch.tensor([.1, .2, .3], dtype=torch.float32)
expect_data = torch.tensor([1.000], dtype=torch.float32)
output_data = torch.nn.functional.linear(input_data, weight, bias)
loss = torch.nn.functional.mse_loss(output_data, expect_data)
損失値から逆伝播を行う。 逆伝播は .backward() で呼び出せる。
loss.backward()
計算された勾配の結果は各テンソルの .grad に格納されている。これを使用して重みとバイアスを更新する。更新の影響が適度な大きさになるように 学習率 も合わせて指定する。この学習率は、大きすぎると発散し小さすぎると収束が遅くなる。ここでは 0.01 を設定する。
また重みやバイアスに値を反映する時の計算は、その後の勾配計算に使う事はなく影響されてもいけないたため torch.no_grad() コンテキストマネージャにより、計算グラフを作成を一時的に無効化する。
lr = 0.01 # 学習率 (learning rate)
with torch.no_grad():
weight -= lr * weight.grad
bias -= lr * bias.grad勾配計算のための情報をリセットする
これで重みとバイアスを更新できた。最後にこれまでの処理が次のステップの勾配の計算に影響ないように、保持されている勾配をリセットする。
weight.grad.zero_()
bias.grad.zero_()訓練を繰り返す
ここまでの処理を何度も繰り返し、重みやバイアスが調整される事で、期待する結果が得られるようにする。ここでは1000回繰り返してみよう。
weight = torch.tensor([[.4, .5, .6]], dtype=torch.float32, requires_grad=True)
bias = torch.tensor([.1], dtype=torch.float32, requires_grad=True)
input_data = torch.tensor([.1, .2, .3], dtype=torch.float32)
expect_data = torch.tensor([1.000], dtype=torch.float32)
lr = 0.01 # 学習率 (learning rate)
for ii in range(1000):
output_data = torch.nn.functional.linear(input_data, weight, bias)
loss = torch.nn.functional.mse_loss(output_data, expect_data)
loss.backward()
with torch.no_grad():
weight -= lr * weight.grad
bias -= lr * bias.grad
weight.grad.zero_()
bias.grad.zero_()訓練結果
結果を確認する。
with torch.no_grad():
for ii in range(10):
print(torch.nn.functional.linear(input_data, weight, bias))
初期状態での出力は tensor([0.4200]) 、損失は tensor(0.3364) と目標から遠い値だったが、ループを回して係数を更新すると徐々に出力が tensor([0.999...]) へ近づき、損失も十分に小さくなる。極小な問題設定なので、単一サンプルでも勾配降下が収束する様子を観察しやすい。
| ステップ | 出力 | 損失 |
|---|---|---|
| 0 | 0.4200 | 3.364e-01 |
| 1 | 0.4332 | 3.212e-01 |
| 2 | 0.4461 | 3.067e-01 |
| 9 | 0.5287 | 2.221e-01 |
| 99 | 0.9409 | 3.496e-03 |
| 199 | 0.9941 | 3.470e-05 |
| 499 | 0.99999 | 3.391e-11 |
| 999 | 1.00000 | 3.264e-21 |
この問題は [.1, .2, .3] を入力されたら [1.] を出力したいという、あまり意味のない単純なものだから、1000回程度訓練すれば期待する値を得られるようになった。Pythonで書くとこんな感じだろうか(ちょっと違うし、不動小数点比較するなみたいなのはあるけど、あくまでイメージなのでそういうのは無視する)。
def func(val):
if [.1, .2, .3] == val: # 近い数値は考慮していない
return [1.]
return [0.] # ここは何も考慮していないこれだけだとあまり有用には思えないが、昨今のAIの技術がこれらの技術をベースにしている事を考えると、認知機能の中心的な部分にこのような単純な事が行われているのだろう。
次に学ぶなら
この次に学ぶなら以下のような内容を学んでいくと良さそうだ。
torch.nn.ReLUなどの活性化関数を差し込み、非線形性がないと表現力が制限されることを確かめる(公式チュートリアル「Neural Networks」参照)。torch.nn.Moduleでモデルクラスを組み、torch.optim.SGDなどのオプティマイザに更新を任せて今回手動で書いた部分と比較する(「Training a Classifier」チュートリアルが参考になる)。- 複数サンプルを
DataLoaderで扱い、バッチ学習と損失平均化がどのように実装されるかを追う(公式ドキュメントの「Data Loading and Processing Tutorial」)。
まとめ
今回得たこと
- 1層の線形モデルなら
torch.nn.functionalの最小APIでも順伝播→損失計算→逆伝播→更新を一通り体験できる。 - `requires_grad=True` と `torch.no_grad()` の使い分けを通じて、PyTorchの計算グラフと自動微分の仕組みを肌で感じられた。
- 学習率0.01・1000ステップ程度で、単一サンプルでも損失が十分に小さくなる挙動を確認できた。
今回触れなかったが重要なこと
- 活性化関数や多層化を行わない限り線形モデルの表現力は限定される。
torch.nn.Moduleとtorch.optimによる構造化・更新高速化は実用コードでは必須になる。- データのバッチ化、正則化、評価指標など、本番運用に求められる周辺要素はここでは扱っていない。
実際の実装では torch.nn.functional の関数を直接呼び出すのではなく、 torch.nn.Module や torch.optim で提供されているものを使用したり、もっと多層にして複雑な問題を解決できるようにしていく事になる。それは別の機会にまとめたい。
ただ、このように基本的な機能のみを使う事で、PyTorchが何を肩代わりしてくれているのかを把握しやすくなる。今回は少しだけPyTorchの事が理解できた気がする。