こんにちは、エンジニアの竹内です。
深層学習を行う際によく利用されるフレームワークといえばGoogleが開発しているTensorflowとFacebookが開発しているPytorchの2大巨頭に加えて、Kerasなどが挙げられるかと思いますが、今回はそのような選択肢の一つとしてGoogleが新しく開発している*1新進気鋭(?)の機械学習フレームワークJAXを紹介したいと思います。
github.com
今回JAXを紹介するきっかけですが、最近話題になったVision Transformerの公式実装のソースコードを読む際に、モデルの実装にJAXが使用されており、少し気になったので勉強がてら触ってみたというのが経緯です。
github.com
ディープラーニングのフレームワークの入門といえばMNISTデータセットを使った画像分類ですので、今回はJAXの入門編としてシンプルな多層パーセプトロン(以下MLP)によるMNISTの分類器を作っていきたいと思います。
なお、例の如くブログに載せているコードは解説用に簡略化しているため、全体像についてはgithubを参照していただければと思います。
JAXとは?
JAXを一言で表現すると「自動微分が搭載された高速化可能なNumPy」といった感じです。TensorflowやPytorchなどはそれぞれのフレームワークにのっとった書き方を覚える必要がありますが、JAXであれば慣れ親しまれているNumPyとほぼ同じ書き方で実装できます。
GPUやTPU*2といったアクセラレータにも対応していますが、Windowsに対応していない(当面対応するつもりがないらしい)ので注意が必要です。
公式ガイドによるとインポートする際はjax.numpyにjnpという省略名をつけてNumPyと同じように扱っていくようです。(機能面もほぼ完全にNumPyと同じです)
import jax.numpy as jnp
また、JAXはTensorflowと同様にXLA(Accelerated Linear Algebra)と呼ばれる線形代数の演算に特化したコンパイラに対応しており、@jitデコレータを指定されたフォーマットの関数につけるだけで実行時に自動的にコンパイルされ、大規模な行列計算やfor文による繰り返しなどを高速化することが出来ます。
今回はjitの有無による学習時間の比較も簡単に行いたいと思います。
自動微分を試してみる
JAXの特徴の1つに自動微分があります。自動微分とはPythonで記述した関数に対する(偏)導関数を自動的に求めてくれる機能であり、ディープラーニングを行うためには必要不可欠となります。TensorflowだとGradientTape, Pytorchだとautogradに相当する機能になります。
具体的な使い方としてはgrad関数の第一引数に関数、第二引数に微分したい変数のIndexを渡し、返ってくる(偏)導関数に対して勾配を求めたい点を入れるとその点における勾配が取得できる、という形になります。
自動微分を求められる演算についてはjnpの他、Pyhtonの組み込み関数や基本的な代数演算にも対応しています。
grad関数を使うことで、以下の例のように二変数関数の勾配を簡単に求めることができます。
# 勾配を求める関数の定義 def f(x1, x2): # x1 で微分 -> 2x1 + x2 # x2 で微分 -> x1 + 2x2 return x1 ** 2 + x1 * x2 + x2 ** 2 def main(): # 勾配を求める点 x1 = 1.0 x2 = 2.0 x1_grad = grad(f, argnums=0)(x1, x2) # x1について微分 print(f"x1 grad: {x1_grad}") # -> 4.0 x2_grad = grad(f, argnums=1)(x1, x2) # x2について微分 print(f"x2 grad: {x2_grad}") # -> 5.0
微分したい変数についてはlistやdictなどに入れて複数の変数をまとめて勾配を求めることも可能です。
ニューラルネットの勾配を求める際は、weightとbiasパラメータを全てlistに入れて勾配を求めていきます。
MNISTを学習させてみる
一通りJAXの使い方がわかったところで、MLPを実装してMNISTデータセットで学習させていきます。
ちなみにJAXにはFlaxという専用のニューラルネットワークライブラリがあるのですが、一応今回はあくまでJAXの基本的な機能の紹介ということで、そういったライブラリを使わずに頑張ってjnpとgradだけで書いていこうかと思います。
github.com
MLPの順伝播の実装
今回は簡単な実験なので拡張性などはあまり意識せず、単にパラメータと入力値を入れたらforward計算して予測値を返してくれる関数だけを実装します。
レイヤーや活性化関数などのパーツも分けずに一つの関数にまとめてしまいます。
# 順伝播の計算 @jit def mlp_predict(params, x): w, b = params[0] x = jnp.dot(x, w) + b for w, b in params[1:]: x = jnp.maximum(0, x) x = jnp.dot(x, w) + b y = softmax(x) return y def softmax(x): x = x - jnp.max(x, axis=1, keepdims=True) # eの階乗でオーバーフローを起こさないための工夫です。 return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=1, keepdims=True)
損失の計算
学習の目的関数となる損失関数には交差エントロピー誤差を使用します。
この関数に先ほどのgradを適用すると各パラメータの勾配が返ってくるようにします。
# 損失関数の定義 def mlp_loss(params, x, y): probs = mlp_predict(params, x) loss = jnp.mean(cross_entropy(probs, y)) return loss def cross_entropy(x, y): return jnp.sum(- y * jnp.log(x + 1e-8), axis=1)
パラメータの初期化
パラメータの初期値としてはHeの初期値He et al., 2015, Delving Deep into Rectifiersと呼ばれる値を使います。
これはあるレイヤーのユニット数をとしてその重みとバイアスが平均 標準偏差の正規分布に従うように生成するというものです。
JAXのrandom.normal関数では正規分布のパラメータを指定できないので、標準正規分布に標準偏差を掛けて初期値を生成しています。
また、JAXの乱数生成の仕組みは少し独特で、乱数を生成するたびにシード値を指定してキーを生成する必要があります。(同じシード値に対しては同じキーが生成され、同じキーに対しては同じ乱数が生成されます。)
生成したキーをrandom.splitによって分割することで新しいキーを生成することができ、繰り返し乱数を生成したい場合は分割を繰り返していくことになります。
# パラメータの初期値を得る def init_mlp_params(input_size: int, output_size: int, num_units: list, seed: int = 0): params = [] num_units.append(output_size) key = random.PRNGKey(seed) last_out = input_size for unit in num_units: key, subkey = random.split(key) # using He initialization method x, w = random.normal(key, (last_out, unit), dtype=jnp.float32) * jnp.sqrt(2 / unit), random.normal(subkey, (unit, ), dtype=jnp.float32) * jnp.sqrt(2 / unit) params.append((x, w)) last_out = unit return params
パラメータの更新
grad関数を使って取得した勾配をもとにパラメータを更新する部分を実装します。
今回はoptimizerにはモーメントを用いない最もシンプルなSGDを使用します。
# SGDを用いた勾配降下 @jit def apply_grads(params, grads, lr=0.001): return [(w - w_grad * lr, b - b_grad * lr) for (w, b), (w_grad, b_grad) in zip(params, grads)]
1エポックの学習
一通り関数の定義ができたところで、学習セットからミニバッチを生成し、1エポック分の学習を回す部分を実装していきます。
パラメータを更新する部分については、損失関数の引数paramsに対してgradを適用し、パラメータの勾配を取得した後それをapply_grads関数に渡すという流れになります。一応学習時の損失も監視できるように計算しておきます。
# 1エポック分の学習 def train_one_epoch(params, X_train, y_train, epoch): num_samples = X_train.shape[0] random_sample_idx = random.permutation(random.PRNGKey(epoch), jnp.arange(num_samples)) for idx in tqdm(range(0, num_samples, BATCH_SIZE)): mini_batch_idx = random_sample_idx[idx:idx + BATCH_SIZE] # 今回は実験用なのでハイパラは全部グローバルで定義しています。 mini_batch_x = X_train[mini_batch_idx] mini_batch_y = y_train[mini_batch_idx] params_grad = grad(mlp_loss, argnums=0)(params, mini_batch_x, mini_batch_y) params = apply_grads(params, params_grad, LEARNING_RATE) loss = mlp_loss(params, mini_batch_x, mini_batch_y) return loss, params
学習を実行
MNISTデータセットはsklearn.datasetsから引っ張ってきます。(実行するたびにロードしてくるので20秒弱かかります)
引っ張ってきたデータセットの2割をテストセットとして、残りの8割を学習に使用します。今回特にハイパーパラメータのチューニングやearly stoppingなどは行いませんが、一応学習セットの2割をvalidationとして分けておき、lossの減少をモニタリングしておきます。
あくまでチュートリアルということで、データの前処理はone-hot化や正規化など最低限のものに留め、ハイパーパラメータもかなりざっくりと決めています。
# hyper-parameters BATCH_SIZE = 256 NUM_UNITS = [512, 512] NUM_EPOCHS = 20 LEARNING_RATE = 0.01 SEED = 1234 def main(): print("fetching mnisit datasets…") X, y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True) X /= 255.0 # データの正規化 one_hot_encoder = OneHotEncoder() y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).A # ターゲットのone-hot化 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED) X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=SEED) params = init_mlp_params(X_train.shape[1], y_train.shape[1], NUM_UNITS, SEED) for i in range(NUM_EPOCHS): start_time = time.time() training_loss, params = train_one_epoch(params, X_train, y_train, i) validation_loss = validation(params, X_valid, y_valid) epoch_time = time.time() - start_time print(f"EPOCH: {i} time: {epoch_time:.3f} training_loss: {training_loss:.3f} validation_loss: {validation_loss:.3f}") accuracy = compute_accuracy(params, X_test, y_test) print(f"training finished. test acc: {accuracy:.4f}")
結果
まずjitコンパイラを使用して学習を回してみた結果が以下となります。
EPOCH: 0 time: 3.191 training_loss: 1.773 validation_loss: 1.751 EPOCH: 1 time: 1.047 training_loss: 1.249 validation_loss: 1.241 EPOCH: 2 time: 1.167 training_loss: 0.953 validation_loss: 0.934 EPOCH: 3 time: 1.174 training_loss: 0.793 validation_loss: 0.766 EPOCH: 4 time: 1.368 training_loss: 0.613 validation_loss: 0.659 EPOCH: 5 time: 1.192 training_loss: 0.568 validation_loss: 0.586 EPOCH: 6 time: 1.106 training_loss: 0.592 validation_loss: 0.533 EPOCH: 7 time: 1.032 training_loss: 0.472 validation_loss: 0.493 EPOCH: 8 time: 1.151 training_loss: 0.583 validation_loss: 0.461 EPOCH: 9 time: 1.193 training_loss: 0.503 validation_loss: 0.436 EPOCH: 10 time: 0.978 training_loss: 0.428 validation_loss: 0.415 EPOCH: 11 time: 1.182 training_loss: 0.378 validation_loss: 0.399 EPOCH: 12 time: 1.364 training_loss: 0.361 validation_loss: 0.385 EPOCH: 13 time: 1.089 training_loss: 0.372 validation_loss: 0.373 EPOCH: 14 time: 1.078 training_loss: 0.342 validation_loss: 0.363 EPOCH: 15 time: 1.058 training_loss: 0.416 validation_loss: 0.354 EPOCH: 16 time: 1.205 training_loss: 0.295 validation_loss: 0.346 EPOCH: 17 time: 1.096 training_loss: 0.361 validation_loss: 0.339 EPOCH: 18 time: 0.979 training_loss: 0.337 validation_loss: 0.333 EPOCH: 19 time: 1.171 training_loss: 0.348 validation_loss: 0.327 training finished. test acc: 0.9040
テストスコアは90.4%とまあまあですが、とりあえず正しく学習はできていそうです。1エポックの計算時間は大体平均1.1秒といったところでしょうか。validation_lossが増えていないのでまだまだエポックを回しても良さそうですね。
次に各関数の@jitデコレーションを全てコメントアウトして学習を回してみます。
EPOCH: 0 time: 6.571 training_loss: 1.720 validation_loss: 1.743 EPOCH: 1 time: 3.762 training_loss: 1.241 validation_loss: 1.228 EPOCH: 2 time: 3.829 training_loss: 0.875 validation_loss: 0.920 EPOCH: 3 time: 4.009 training_loss: 0.773 validation_loss: 0.752 EPOCH: 4 time: 4.009 training_loss: 0.662 validation_loss: 0.646 EPOCH: 5 time: 4.134 training_loss: 0.508 validation_loss: 0.572 EPOCH: 6 time: 3.884 training_loss: 0.525 validation_loss: 0.519 EPOCH: 7 time: 4.321 training_loss: 0.510 validation_loss: 0.478 EPOCH: 8 time: 4.748 training_loss: 0.550 validation_loss: 0.446 EPOCH: 9 time: 4.533 training_loss: 0.402 validation_loss: 0.421 EPOCH: 10 time: 3.529 training_loss: 0.361 validation_loss: 0.402 EPOCH: 11 time: 3.712 training_loss: 0.348 validation_loss: 0.384 EPOCH: 12 time: 4.295 training_loss: 0.325 validation_loss: 0.370 EPOCH: 13 time: 4.101 training_loss: 0.299 validation_loss: 0.359 EPOCH: 14 time: 4.147 training_loss: 0.304 validation_loss: 0.349 EPOCH: 15 time: 4.176 training_loss: 0.375 validation_loss: 0.340 EPOCH: 16 time: 4.084 training_loss: 0.370 validation_loss: 0.333 EPOCH: 17 time: 4.197 training_loss: 0.277 validation_loss: 0.326 EPOCH: 18 time: 4.666 training_loss: 0.350 validation_loss: 0.320 EPOCH: 19 time: 3.832 training_loss: 0.274 validation_loss: 0.314 training finished. test acc: 0.9052
jitコンパイラを使用しない場合だと1エポックに平均4秒強かかっていることがわかります。jitコンパイラを使用した際のおよそ4倍と、パフォーマンスにかなり差が出ていますね。
おまけ
おまけとしてオプティマイザを通常のSGDからRMSPropに変えてどれくらい精度が上がるのか検証してみました。*3
(複雑なオプティマイザを入れると管理するパラメータが増えるのでそろそろクラスを書きたくなってきますね…)
def apply_grads(params, grads, r, lr=0.01, beta=0.9): new_r = [] new_params = [] for (w, b), (w_grad, b_grad), (w_r, b_r) in zip(params, grads, r): w_r = beta * w_r + (1 - beta) * w_grad ** 2 b_r = beta * b_r + (1 - beta) * b_grad ** 2 w_eta = lr / jnp.sqrt(w_r + 1e-6) b_eta = lr / jnp.sqrt(b_r + 1e-6) new_w = w - w_grad * w_eta new_b = b - b_grad * b_eta new_r.append((w_r, b_r)) new_params.append((new_w, new_b)) return new_params, new_r # 学習させた結果 EPOCH: 0 time: 5.836 training_loss: 0.222 validation_loss: 0.302 EPOCH: 1 time: 2.283 training_loss: 0.101 validation_loss: 0.212 EPOCH: 2 time: 2.024 training_loss: 0.080 validation_loss: 0.165 EPOCH: 3 time: 1.976 training_loss: 0.060 validation_loss: 0.182 EPOCH: 4 time: 2.238 training_loss: 0.087 validation_loss: 0.161 EPOCH: 5 time: 2.577 training_loss: 0.044 validation_loss: 0.140 EPOCH: 6 time: 2.306 training_loss: 0.013 validation_loss: 0.116 EPOCH: 7 time: 1.992 training_loss: 0.046 validation_loss: 0.153 EPOCH: 8 time: 2.141 training_loss: 0.038 validation_loss: 0.121 EPOCH: 9 time: 2.054 training_loss: 0.007 validation_loss: 0.119 EPOCH: 10 time: 1.973 training_loss: 0.003 validation_loss: 0.120 EPOCH: 11 time: 2.139 training_loss: 0.006 validation_loss: 0.139 EPOCH: 12 time: 1.949 training_loss: 0.002 validation_loss: 0.124 EPOCH: 13 time: 2.459 training_loss: 0.004 validation_loss: 0.138 EPOCH: 14 time: 2.479 training_loss: 0.016 validation_loss: 0.140 EPOCH: 15 time: 2.138 training_loss: 0.007 validation_loss: 0.130 EPOCH: 16 time: 2.010 training_loss: 0.090 validation_loss: 0.169 EPOCH: 17 time: 2.124 training_loss: 0.010 validation_loss: 0.137 EPOCH: 18 time: 2.074 training_loss: 0.010 validation_loss: 0.132 EPOCH: 19 time: 1.705 training_loss: 0.008 validation_loss: 0.152 training finished. test acc: 0.963357150554657
若干過学習気味ですが96.34%とSGDよりはだいぶ良いスコアになりますね。
まとめ
自動微分とXLAを併せ持つNumPyライクな機械学習フレームワークJAXを紹介しました。
TPU未対応であったりWindowsでは動かないなど、大御所のフレームワークと比較するとまだまだ発展途上ではありますが、Pytorchのようなdefine-by-runのニューラルネットワークの構築にNumPyを組み合わせることができるという点を生かし、デバッグを容易にしつつ、損失や勾配計算などに対して独自の複雑な処理を自由に行いたい場合などは選択肢の一つとして有力な候補となりうるかもしれません。
random keyの生成についてはややクセがありますが、論文などの再現性を担保する上では必要となってくる機能なのかもしれません。
機械学習(ディープラーニング)フレームワークのトレンドは時代とともに移り変わっていく傾向もあり、今現在はTensorflowとPytorchがメジャーではありますが、今後はJAXが覇権を握る日が来るかも?しれません。