Morikatron Engineer Blog

モリカトロン開発者ブログ

自動微分+XLA付き機械学習フレームワークJAXを使用してMNISTを学習させてみる

こんにちは、エンジニアの竹内です。
深層学習を行う際によく利用されるフレームワークといえばGoogleが開発しているTensorflowとFacebookが開発しているPytorchの2大巨頭に加えて、Kerasなどが挙げられるかと思いますが、今回はそのような選択肢の一つとしてGoogleが新しく開発している*1新進気鋭(?)の機械学習フレームワークJAXを紹介したいと思います。
github.com



今回JAXを紹介するきっかけですが、最近話題になったVision Transformerの公式実装のソースコードを読む際に、モデルの実装にJAXが使用されており、少し気になったので勉強がてら触ってみたというのが経緯です。
github.com

ディープラーニングのフレームワークの入門といえばMNISTデータセットを使った画像分類ですので、今回はJAXの入門編としてシンプルな多層パーセプトロン(以下MLP)によるMNISTの分類器を作っていきたいと思います。
なお、例の如くブログに載せているコードは解説用に簡略化しているため、全体像についてはgithubを参照していただければと思います。

JAXとは?

JAXを一言で表現すると「自動微分が搭載された高速化可能なNumPy」といった感じです。TensorflowやPytorchなどはそれぞれのフレームワークにのっとった書き方を覚える必要がありますが、JAXであれば慣れ親しまれているNumPyとほぼ同じ書き方で実装できます。
GPUやTPU*2といったアクセラレータにも対応していますが、Windowsに対応していない(当面対応するつもりがないらしい)ので注意が必要です。

公式ガイドによるとインポートする際はjax.numpyにjnpという省略名をつけてNumPyと同じように扱っていくようです。(機能面もほぼ完全にNumPyと同じです)

import jax.numpy as jnp

また、JAXはTensorflowと同様にXLA(Accelerated Linear Algebra)と呼ばれる線形代数の演算に特化したコンパイラに対応しており、@jitデコレータを指定されたフォーマットの関数につけるだけで実行時に自動的にコンパイルされ、大規模な行列計算やfor文による繰り返しなどを高速化することが出来ます。
今回はjitの有無による学習時間の比較も簡単に行いたいと思います。

自動微分を試してみる

JAXの特徴の1つに自動微分があります。自動微分とはPythonで記述した関数に対する(偏)導関数を自動的に求めてくれる機能であり、ディープラーニングを行うためには必要不可欠となります。TensorflowだとGradientTape, Pytorchだとautogradに相当する機能になります。
具体的な使い方としてはgrad関数の第一引数に関数、第二引数に微分したい変数のIndexを渡し、返ってくる(偏)導関数に対して勾配を求めたい点を入れるとその点における勾配が取得できる、という形になります。
自動微分を求められる演算についてはjnpの他、Pyhtonの組み込み関数や基本的な代数演算にも対応しています。

grad関数を使うことで、以下の例のように二変数関数の勾配を簡単に求めることができます。

# 勾配を求める関数の定義
def f(x1, x2):
    # x1 で微分 -> 2x1 + x2
    # x2 で微分 -> x1 + 2x2
    return x1 ** 2 + x1 * x2 + x2 ** 2


def main():
    # 勾配を求める点
    x1 = 1.0
    x2 = 2.0
    x1_grad = grad(f, argnums=0)(x1, x2)  # x1について微分
    print(f"x1 grad: {x1_grad}")  # -> 4.0
    x2_grad = grad(f, argnums=1)(x1, x2)  # x2について微分
    print(f"x2 grad: {x2_grad}")  # -> 5.0


微分したい変数についてはlistやdictなどに入れて複数の変数をまとめて勾配を求めることも可能です。
ニューラルネットの勾配を求める際は、weightとbiasパラメータを全てlistに入れて勾配を求めていきます。

MNISTを学習させてみる

一通りJAXの使い方がわかったところで、MLPを実装してMNISTデータセットで学習させていきます。
ちなみにJAXにはFlaxという専用のニューラルネットワークライブラリがあるのですが、一応今回はあくまでJAXの基本的な機能の紹介ということで、そういったライブラリを使わずに頑張ってjnpとgradだけで書いていこうかと思います。
github.com

MLPの順伝播の実装

今回は簡単な実験なので拡張性などはあまり意識せず、単にパラメータと入力値を入れたらforward計算して予測値を返してくれる関数だけを実装します。
レイヤーや活性化関数などのパーツも分けずに一つの関数にまとめてしまいます。

# 順伝播の計算
@jit
def mlp_predict(params, x):
    w, b = params[0]
    x = jnp.dot(x, w) + b
    for w, b in params[1:]:
        x = jnp.maximum(0, x)
        x = jnp.dot(x, w) + b
    y = softmax(x)
    return y

def softmax(x):
    x = x - jnp.max(x, axis=1, keepdims=True)  # eの階乗でオーバーフローを起こさないための工夫です。
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=1, keepdims=True)

損失の計算

学習の目的関数となる損失関数には交差エントロピー誤差を使用します。
この関数に先ほどのgradを適用すると各パラメータの勾配が返ってくるようにします。

# 損失関数の定義
def mlp_loss(params, x, y):
    probs = mlp_predict(params, x)
    loss = jnp.mean(cross_entropy(probs, y))
    return loss

def cross_entropy(x, y):
    return jnp.sum(- y * jnp.log(x + 1e-8), axis=1)

パラメータの初期化

パラメータの初期値としてはHeの初期値He et al., 2015, Delving Deep into Rectifiersと呼ばれる値を使います。
これはあるレイヤーのユニット数を nとしてその重みとバイアスが平均 0 標準偏差 \sqrt{\frac{2}{n}}の正規分布に従うように生成するというものです。
JAXのrandom.normal関数では正規分布のパラメータを指定できないので、標準正規分布に標準偏差を掛けて初期値を生成しています。

また、JAXの乱数生成の仕組みは少し独特で、乱数を生成するたびにシード値を指定してキーを生成する必要があります。(同じシード値に対しては同じキーが生成され、同じキーに対しては同じ乱数が生成されます。)
生成したキーをrandom.splitによって分割することで新しいキーを生成することができ、繰り返し乱数を生成したい場合は分割を繰り返していくことになります。

# パラメータの初期値を得る
def init_mlp_params(input_size: int, output_size: int, num_units: list, seed: int = 0):
    params = []
    num_units.append(output_size)
    key = random.PRNGKey(seed)
    last_out = input_size
    for unit in num_units:
        key, subkey = random.split(key)
        # using He initialization method
        x, w = random.normal(key, (last_out, unit), dtype=jnp.float32) * jnp.sqrt(2 / unit), random.normal(subkey, (unit, ), dtype=jnp.float32) * jnp.sqrt(2 / unit)
        params.append((x, w))
        last_out = unit
    return params

パラメータの更新

grad関数を使って取得した勾配をもとにパラメータを更新する部分を実装します。
今回はoptimizerにはモーメントを用いない最もシンプルなSGDを使用します。

# SGDを用いた勾配降下
@jit
def apply_grads(params, grads, lr=0.001):
    return [(w - w_grad * lr, b - b_grad * lr)
            for (w, b), (w_grad, b_grad) in zip(params, grads)]

1エポックの学習

一通り関数の定義ができたところで、学習セットからミニバッチを生成し、1エポック分の学習を回す部分を実装していきます。
パラメータを更新する部分については、損失関数の引数paramsに対してgradを適用し、パラメータの勾配を取得した後それをapply_grads関数に渡すという流れになります。一応学習時の損失も監視できるように計算しておきます。

# 1エポック分の学習
def train_one_epoch(params, X_train, y_train, epoch):
    num_samples = X_train.shape[0]
    random_sample_idx = random.permutation(random.PRNGKey(epoch), jnp.arange(num_samples))
    for idx in tqdm(range(0, num_samples, BATCH_SIZE)):
        mini_batch_idx = random_sample_idx[idx:idx + BATCH_SIZE]  # 今回は実験用なのでハイパラは全部グローバルで定義しています。
        mini_batch_x = X_train[mini_batch_idx]
        mini_batch_y = y_train[mini_batch_idx]
        params_grad = grad(mlp_loss, argnums=0)(params, mini_batch_x, mini_batch_y)
        params = apply_grads(params, params_grad, LEARNING_RATE)
        loss = mlp_loss(params, mini_batch_x, mini_batch_y)
    return loss, params

学習を実行

MNISTデータセットはsklearn.datasetsから引っ張ってきます。(実行するたびにロードしてくるので20秒弱かかります)
引っ張ってきたデータセットの2割をテストセットとして、残りの8割を学習に使用します。今回特にハイパーパラメータのチューニングやearly stoppingなどは行いませんが、一応学習セットの2割をvalidationとして分けておき、lossの減少をモニタリングしておきます。

あくまでチュートリアルということで、データの前処理はone-hot化や正規化など最低限のものに留め、ハイパーパラメータもかなりざっくりと決めています。

# hyper-parameters
BATCH_SIZE = 256
NUM_UNITS = [512, 512]
NUM_EPOCHS = 20
LEARNING_RATE = 0.01
SEED = 1234

def main():
    print("fetching mnisit datasets…")
    X, y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
    X /= 255.0  # データの正規化
    one_hot_encoder = OneHotEncoder()
    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).A  # ターゲットのone-hot化
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED)
    X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=SEED)
    params = init_mlp_params(X_train.shape[1], y_train.shape[1], NUM_UNITS, SEED)
    for i in range(NUM_EPOCHS):
        start_time = time.time()
        training_loss, params = train_one_epoch(params, X_train, y_train, i)
        validation_loss = validation(params, X_valid, y_valid)
        epoch_time = time.time() - start_time
        print(f"EPOCH: {i} time: {epoch_time:.3f} training_loss: {training_loss:.3f} validation_loss: {validation_loss:.3f}")
    accuracy = compute_accuracy(params, X_test, y_test)
    print(f"training finished. test acc: {accuracy:.4f}")

結果

まずjitコンパイラを使用して学習を回してみた結果が以下となります。

EPOCH: 0 time: 3.191 training_loss: 1.773 validation_loss: 1.751
EPOCH: 1 time: 1.047 training_loss: 1.249 validation_loss: 1.241
EPOCH: 2 time: 1.167 training_loss: 0.953 validation_loss: 0.934
EPOCH: 3 time: 1.174 training_loss: 0.793 validation_loss: 0.766
EPOCH: 4 time: 1.368 training_loss: 0.613 validation_loss: 0.659
EPOCH: 5 time: 1.192 training_loss: 0.568 validation_loss: 0.586
EPOCH: 6 time: 1.106 training_loss: 0.592 validation_loss: 0.533
EPOCH: 7 time: 1.032 training_loss: 0.472 validation_loss: 0.493
EPOCH: 8 time: 1.151 training_loss: 0.583 validation_loss: 0.461
EPOCH: 9 time: 1.193 training_loss: 0.503 validation_loss: 0.436
EPOCH: 10 time: 0.978 training_loss: 0.428 validation_loss: 0.415
EPOCH: 11 time: 1.182 training_loss: 0.378 validation_loss: 0.399
EPOCH: 12 time: 1.364 training_loss: 0.361 validation_loss: 0.385
EPOCH: 13 time: 1.089 training_loss: 0.372 validation_loss: 0.373
EPOCH: 14 time: 1.078 training_loss: 0.342 validation_loss: 0.363
EPOCH: 15 time: 1.058 training_loss: 0.416 validation_loss: 0.354
EPOCH: 16 time: 1.205 training_loss: 0.295 validation_loss: 0.346
EPOCH: 17 time: 1.096 training_loss: 0.361 validation_loss: 0.339
EPOCH: 18 time: 0.979 training_loss: 0.337 validation_loss: 0.333
EPOCH: 19 time: 1.171 training_loss: 0.348 validation_loss: 0.327
training finished. test acc: 0.9040

テストスコアは90.4%とまあまあですが、とりあえず正しく学習はできていそうです。1エポックの計算時間は大体平均1.1秒といったところでしょうか。validation_lossが増えていないのでまだまだエポックを回しても良さそうですね。
次に各関数の@jitデコレーションを全てコメントアウトして学習を回してみます。

EPOCH: 0 time: 6.571 training_loss: 1.720 validation_loss: 1.743
EPOCH: 1 time: 3.762 training_loss: 1.241 validation_loss: 1.228
EPOCH: 2 time: 3.829 training_loss: 0.875 validation_loss: 0.920
EPOCH: 3 time: 4.009 training_loss: 0.773 validation_loss: 0.752
EPOCH: 4 time: 4.009 training_loss: 0.662 validation_loss: 0.646
EPOCH: 5 time: 4.134 training_loss: 0.508 validation_loss: 0.572
EPOCH: 6 time: 3.884 training_loss: 0.525 validation_loss: 0.519
EPOCH: 7 time: 4.321 training_loss: 0.510 validation_loss: 0.478
EPOCH: 8 time: 4.748 training_loss: 0.550 validation_loss: 0.446
EPOCH: 9 time: 4.533 training_loss: 0.402 validation_loss: 0.421
EPOCH: 10 time: 3.529 training_loss: 0.361 validation_loss: 0.402
EPOCH: 11 time: 3.712 training_loss: 0.348 validation_loss: 0.384
EPOCH: 12 time: 4.295 training_loss: 0.325 validation_loss: 0.370
EPOCH: 13 time: 4.101 training_loss: 0.299 validation_loss: 0.359
EPOCH: 14 time: 4.147 training_loss: 0.304 validation_loss: 0.349
EPOCH: 15 time: 4.176 training_loss: 0.375 validation_loss: 0.340
EPOCH: 16 time: 4.084 training_loss: 0.370 validation_loss: 0.333
EPOCH: 17 time: 4.197 training_loss: 0.277 validation_loss: 0.326
EPOCH: 18 time: 4.666 training_loss: 0.350 validation_loss: 0.320
EPOCH: 19 time: 3.832 training_loss: 0.274 validation_loss: 0.314
training finished. test acc: 0.9052

jitコンパイラを使用しない場合だと1エポックに平均4秒強かかっていることがわかります。jitコンパイラを使用した際のおよそ4倍と、パフォーマンスにかなり差が出ていますね。

おまけ

おまけとしてオプティマイザを通常のSGDからRMSPropに変えてどれくらい精度が上がるのか検証してみました。*3
(複雑なオプティマイザを入れると管理するパラメータが増えるのでそろそろクラスを書きたくなってきますね…)

def apply_grads(params, grads, r, lr=0.01, beta=0.9):
    new_r = []
    new_params = []
    for (w, b), (w_grad, b_grad), (w_r, b_r) in zip(params, grads, r):
        w_r = beta * w_r + (1 - beta) * w_grad ** 2
        b_r = beta * b_r + (1 - beta) * b_grad ** 2
        w_eta = lr / jnp.sqrt(w_r + 1e-6)
        b_eta = lr / jnp.sqrt(b_r + 1e-6)
        new_w = w - w_grad * w_eta
        new_b = b - b_grad * b_eta
        new_r.append((w_r, b_r))
        new_params.append((new_w, new_b))
    return new_params, new_r

# 学習させた結果
EPOCH: 0 time: 5.836 training_loss: 0.222 validation_loss: 0.302
EPOCH: 1 time: 2.283 training_loss: 0.101 validation_loss: 0.212
EPOCH: 2 time: 2.024 training_loss: 0.080 validation_loss: 0.165
EPOCH: 3 time: 1.976 training_loss: 0.060 validation_loss: 0.182
EPOCH: 4 time: 2.238 training_loss: 0.087 validation_loss: 0.161
EPOCH: 5 time: 2.577 training_loss: 0.044 validation_loss: 0.140
EPOCH: 6 time: 2.306 training_loss: 0.013 validation_loss: 0.116
EPOCH: 7 time: 1.992 training_loss: 0.046 validation_loss: 0.153
EPOCH: 8 time: 2.141 training_loss: 0.038 validation_loss: 0.121
EPOCH: 9 time: 2.054 training_loss: 0.007 validation_loss: 0.119
EPOCH: 10 time: 1.973 training_loss: 0.003 validation_loss: 0.120
EPOCH: 11 time: 2.139 training_loss: 0.006 validation_loss: 0.139
EPOCH: 12 time: 1.949 training_loss: 0.002 validation_loss: 0.124
EPOCH: 13 time: 2.459 training_loss: 0.004 validation_loss: 0.138
EPOCH: 14 time: 2.479 training_loss: 0.016 validation_loss: 0.140
EPOCH: 15 time: 2.138 training_loss: 0.007 validation_loss: 0.130
EPOCH: 16 time: 2.010 training_loss: 0.090 validation_loss: 0.169
EPOCH: 17 time: 2.124 training_loss: 0.010 validation_loss: 0.137
EPOCH: 18 time: 2.074 training_loss: 0.010 validation_loss: 0.132
EPOCH: 19 time: 1.705 training_loss: 0.008 validation_loss: 0.152
training finished. test acc: 0.963357150554657

若干過学習気味ですが96.34%とSGDよりはだいぶ良いスコアになりますね。

まとめ

自動微分とXLAを併せ持つNumPyライクな機械学習フレームワークJAXを紹介しました。
TPU未対応であったりWindowsでは動かないなど、大御所のフレームワークと比較するとまだまだ発展途上ではありますが、Pytorchのようなdefine-by-runのニューラルネットワークの構築にNumPyを組み合わせることができるという点を生かし、デバッグを容易にしつつ、損失や勾配計算などに対して独自の複雑な処理を自由に行いたい場合などは選択肢の一つとして有力な候補となりうるかもしれません。

random keyの生成についてはややクセがありますが、論文などの再現性を担保する上では必要となってくる機能なのかもしれません。

機械学習(ディープラーニング)フレームワークのトレンドは時代とともに移り変わっていく傾向もあり、今現在はTensorflowとPytorchがメジャーではありますが、今後はJAXが覇権を握る日が来るかも?しれません。

*1:あくまでリサーチプロジェクトの一つであり、公式の製品では無いらしいです。

*2:公式ReferenceにTPU coming soon!と書かれているのでTPUはまだっぽいです。

*3:今回は趣味でゼロから実装しましたが、Flaxには一通りメジャーなオプティマイザーは用意されています。