Morikatron Engineer Blog

モリカトロン開発者ブログ

UnityとAx(Adaptive Experimentation Platform)の連携

はじめまして、モリカトロン・プログラマのはっとりです。
今日からブログを始めます。

初回はUnityとPythonのAx(Adaptive Experimentation Platform)を連携させてみた件について書こうと思います。

はじめに

ことの発端は

  • シミュレーションのパラメータ自動調整をやってみよう
  • アルゴリズムはベイズ最適化を使ってみよう
  • Axというベイズ最適化のツールがあるらしい

ということで

  • まずはUnityとAxを連携させるサンプルを作ろう

となりました。

その次にくるべき

  • 実際にシミュレーションを最適化させてみよう

は、追い追い報告したいと思います。

Axとは

Axとは、数理モデルの最適化を自動化するPythonライブラリです。
内部でBoTorchを使用しています。
BoTorchとは、ベイズ最適化(Bayesian Optimization)のPythonライブラリです。
いずれも昨年Facebookから公開されました。
BoTorchは単体でも使用できますが、Axを使えばBoTorchの存在を意識することなく、実験自体に集中することができます。
https://ax.dev/
https://botorch.org/

ベイズ最適化とは

ベイズ最適化とは、未知の関数が最大値(最小値)を取るような入力を、ガウス過程回帰を利用して効率よく探索するための手法です。
これ以上の説明は筆者には荷が重いので、詳しくは他の解説記事に譲りますが、筆者自身は主に下記の書籍を参考にしました。
www.kspub.co.jp

Axの超簡単なサンプル

本題に入る前に、まずはAxの超簡単なサンプルを見てみましょう。
これは公式サイトの「GET STARTED」で紹介されているものと同じ内容です。 https://ax.dev/#quickstart

インストール

PyTorchとAxのインストール

conda install pytorch torchvision -c pytorch # OSX only

pip install ax-platform # all systems

実際には依存関係で他のライブラリも自動的にインストールされますが、ユーザーが手で入力するのはこの2行のみです。
BoTorchもAxをインストールすると依存ライブラリとして自動的にインストールされます。

PyTorchのインストールコマンドは動作させる環境によって異なるので、詳しくはPyTorchの公式サイトを参照してください。
https://pytorch.org/get-started/locally/

スクリプト

>>> from ax import optimize
>>> best_parameters, best_values, experiment, model = optimize(
        parameters=[
          {
            "name": "x1",
            "type": "range",
            "bounds": [-10.0, 10.0],
          },
          {
            "name": "x2",
            "type": "range",
            "bounds": [-10.0, 10.0],
          },
        ],
        # Booth function
        evaluation_function=lambda p: (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2,
        minimize=True,
    )
>>> best_parameters

実行結果

実際に筆者の環境で動作させてみました。

[INFO 04-16 17:28:26] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to  model-fitting.
[INFO 04-16 17:28:26] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 1...
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 2...
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 3...
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 4...
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 5...
[INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 6...
[INFO 04-16 17:28:27] ax.service.managed_loop: Running optimization trial 7...
[INFO 04-16 17:28:28] ax.service.managed_loop: Running optimization trial 8...
[INFO 04-16 17:28:30] ax.service.managed_loop: Running optimization trial 9...
[INFO 04-16 17:28:31] ax.service.managed_loop: Running optimization trial 10...
[INFO 04-16 17:28:32] ax.service.managed_loop: Running optimization trial 11...
[INFO 04-16 17:28:33] ax.service.managed_loop: Running optimization trial 12...
[INFO 04-16 17:28:34] ax.service.managed_loop: Running optimization trial 13...
[INFO 04-16 17:28:36] ax.service.managed_loop: Running optimization trial 14...
[INFO 04-16 17:28:37] ax.service.managed_loop: Running optimization trial 15...
[INFO 04-16 17:28:39] ax.service.managed_loop: Running optimization trial 16...
[INFO 04-16 17:28:40] ax.service.managed_loop: Running optimization trial 17...
[INFO 04-16 17:28:42] ax.service.managed_loop: Running optimization trial 18...
[INFO 04-16 17:28:43] ax.service.managed_loop: Running optimization trial 19...
[INFO 04-16 17:28:44] ax.service.managed_loop: Running optimization trial 20...
{'x1': 0.9242265891515515, 'x2': 3.0831835290815146}

このサンプルで扱っている最適化問題はBooth Functionというもので最適化アルゴリズムのベンチマークの一つです。 上の結果では20回の試行で最適解(x1=1,x2=3)にかなり近づいていることがわかります。 ちなみに筆者が30回で試したところほぼ最適解に辿り着きました。

Booth Functionについては下記を参照してください。
https://www.sfu.ca/~ssurjano/booth.html

APIについて

このサンプルはAxの3つある書き方(API)のうち、最もシンプルな

  • Loop API

で書かれています。上の例では

  • 最適化対象のパラメータ(parameters:例ではx1、x2の二つ)
  • 評価関数(evaluation_function:例ではlambda式)
  • 目標を最小化するか(minimize:例ではTrue最小化)

を定義しているのみです。 これ以外にもトライアル回数など、指定できるパラメータはいくつかあります。 Loop APIで十分というケースはありそうですが、より細かなカスタマイズをする場合は、他の2つのAPI

  • Service API
  • Developer API

を使用します。 これらのAPIを利用すると実験の統計データの保存や実験状態の保存・復元などが可能になります。

各APIの詳細については公式サイトのチュートリアルを参照してください。
https://ax.dev/tutorials/

Unity~Ax連携のサンプルプログラム

ここからが今回の本題です。

ターゲットへの適用前に、本当にできるの?というのを確認するために、技術調査用にUnity(WebGL)とAxを連携させるサンプルプログラムを作成しました。

最適化問題は前述のBooth Functionを使用しました。
公式サンプルと違う点は、評価関数(evaluation_function)をクライアント側に持たせたことです。
本来なら評価計算に必要な元ネタを貰ってサーバー側で評価計算してもいいのですが、Booth Functionの場合は元ネタ=調整パラメータになってしまうので、ここはわかりやすくクライアント側で計算することにしました。

クライアントとしてUnity(WebGL)を選んだ理由は、もちろんターゲットの環境に合わせたのですが、こういうのがあってもいいかなという動機もあります。

以下の順で説明していきます。

  • 動作環境
  • 実行手順
  • プログラム説明

動作環境

バージョン情報

筆者の動作環境とソフトウェアのバージョンです。

  • Windows 10 Pro 1909
  • Google Chrome 81.0.4044.113
  • Unity 2019.3.1f1
  • Miniconda3(conda 4.8.3)
  • python 3.7.7
  • torch 1.5.0
  • torchvision 0.6.0
  • ax-platform 0.1.11
  • Flask 1.1.2
  • gevent 20.4.0
  • gevent-websocket 0.10.1
GitHub

GitHubで実際に動作するUnityプロジェクトを公開しています。
https://github.com/morikatron/AxSample.git

Unity WebGLビルド
  1. git clone [URL]
  2. Unity EditorでAxSampleプロジェクトを開く
  3. Assets/Scenes/SampleSceneを選択
  4. File>Build Settings...を開く
    • Scenes In Build>Add Open Scenesを実行
    • Platform>WebGL>Switch Platformを実行
    • Player Settings...>Player>WebGL Settings>Resolution and Presentation>WebGL Template>Default2を選択
    • Buildを実行
      • ビルド先フォルダとして新しいフォルダWebGLを作成し選択
  5. ビルドエラーがなければOK
Pythonライブラリ
  • 一部前述のインストールと重複します。
  • PyTorchのインストールコマンドは動作環境によって異なります。
    詳細はPyTorch公式を参照してください。
    https://pytorch.org/get-started/locally/

仮想環境
conda create -n ax-sample python=3.7

conda activate ax-sample

PyTorch
conda install pytorch torchvision cpuonly -c pytorch

Ax他
pip install ax-platform Flask gevent gevent-websocket
または
cd AxSample; pip install -r requirements.txt

実行手順
  1. 仮想環境を起動
    conda activate ax-sample
  2. WebGLビルド先のフォルダに移動
    cd AxSample/WebGL
  3. サーバーアプリ起動
    python app.py
  4. ブラウザからURLにアクセス
    http://localhost:8080/ または http://127.0.0.1:8080/
    • URLにアクセスするとすぐ最適化処理がスタートします。
    • Unityの画面には何も表示されません。
    • 適宜ブラウザおよびサーバーのコンソールでログを確認してください。

プログラム説明

概要
  • システムはクライアントとサーバーで構成し、 クライアントに評価関数の本体を、サーバーにAxの全体ループを配置しました。
  • クライアントとサーバー間は主にWebSocketで通信し、以下のコマンドでやりとりします。
    • サーバー→クライアント
      • OPTIMIZER_PARAM(パラメータ通知)
      • OPTIMIZER_BEST_PARAM(ベストパラメータ通知=最適化終了)
    • クライアント→サーバー
      • OPTIMIZER_STEP(評価値通知=次ステップ用パラメータ要求)
      • OPTIMIZER_END(最適化終了応答)
処理フロー
  1. サーバー起動(python app.py)
  2. ブラウザでアプリをロード
    http://localhost:8080/ または http://127.0.0.1:8080/
  3. サーバー
    • パラメータを通知(OPTIMIZE_PARAM)
  4. クライアント
    • パラメータから評価値を計算し結果を通知(OPTIMIZE_STEP)
  5. サーバー
    • 次のパラメータを通知(OPTIMIZE_PARAM)
  6. (以後、OPTIMIZE_PARAM~OPTIMIZE_STEPを規定回数繰り返し)
  7. サーバー
    • 規定回数終了
    • ベストパラメータを通知(OPTIMIZE_BEST_PARAM)
  8. クライアント
    • OPTIMIZE終了を通知(OPTIMIZE_END)
    • WebSocket切断
  9. サーバー
    • WebSocket切断
主なプログラム
  • サーバー
    • app.py
    • booth_loop.py
  • クライアント
    • Optimizer.cs
    • OptimizerLib.jslib
    • OptimizerExt.js
app.py
  • アプリケーションサーバー(Flask)
  • クライアントとの通信処理(gevent-websocket)
  • 最適化処理(Optimizer)のインスタンスを生成・保持
import json

from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler

from flask import Flask, request, render_template, Blueprint

from booth_loop import Optimizer

# Unity WebGLのテンプレートのフォルダ構成に合わせるため
# Blueprintを利用してstatic_folderを追加する
app = Flask(__name__, static_folder="TemplateData")
app.config.from_object(__name__)
blueprint = Blueprint('optimizer', __name__, static_folder='Build',
                      template_folder=".")
app.register_blueprint(blueprint)


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/optimizer')
def optimizer():
    if request.environ.get('wsgi.websocket'):
        # WebSocket接続
        ws = request.environ['wsgi.websocket']
        # パラメータ通知コールバック
        def cb_step(x):
            # メッセージ作成:パラメータ通知
            data = {'m': 'OPTIMIZER_PARAM', 'x': x.tolist()}
            print(data)
            # WebSocektメッセージ送信
            ws.send(json.dumps(data))
            print('Waiting...')
            # WebSocketメッセージ待ち受け
            message = ws.receive()
            score = 0
            if message:
                print(message)
                data2 = json.loads(message)
                score = data2['score']
            return score

        # ベストパラメータ通知コールバック
        def cb_end(x):
            # メッセージ作成:ベストパラメータ通知
            data = {'m': 'OPTIMIZER_BEST_PARAM', 'x': x.tolist()}
            print(data)
            # WebSocektメッセージ送信
            ws.send(json.dumps(data))
            print('Waiting...')
            # WebSocketメッセージ待ち受け
            message = ws.receive()
            if message:
                print(message)

        obj = Optimizer(cb_step, cb_end)
        obj.optimize()
        ws.close()
    return "OK"


if __name__ == '__main__':
    app.debug = True

    host = 'localhost'
    port = 8080

    host_port = (host, port)
    server = WSGIServer(
        host_port,
        app,
        handler_class=WebSocketHandler
    )
    server.serve_forever()
booth_loop.py
  • 最適化処理、全体ループの管理(Ax、BoTorch)
import numpy as np
from ax import optimize


# Optimizerクラス
class Optimizer:
    def __init__(self, cb_step=None, cb_end=None):
        self.cb_step = cb_step  # ステップ・コールバック
        self.cb_end = cb_end  # 終了コールバック

    # 評価関数
    def evaluation_function(self, parameterization):
        l = len(parameterization)
        # 最適化対象パラメータの値を配列にセットする
        x = np.array([parameterization.get(f"x{i+1}") for i in range(l)])
        if self.cb_step:
            # パラメータ通知コールバック
            # クライアントへ今回のパラメータを通知し、評価値の受信を待ち受ける
            score = self.cb_step(x)
        else:
            # スクリプト単体で動作させる場合
            x1 = x[0]
            x2 = x[1]
            # Booth Function
            score = (x1 + 2*x2 - 7)**2 + (2*x1 + x2 - 5)**2
        return {"score": (score, 0.0)}

    # 最適化関数
    def optimize(self):
        # ax.optimizeをコールする
        # ax.optimizeは最適化が終了するまで復帰して来ない
        # ax.optimize内でevaluation_functionが試行回数分コールされる
        best_parameters, best_values, experiment, model = optimize(
            parameters=[  # 最適化対象パラメータの定義
                {
                "name": "x1",
                "type": "range",
                "bounds": [-10.0, 10.0],
                },
                {
                "name": "x2",
                "type": "range",
                "bounds": [-10.0, 10.0],
                },
            ],
            # 評価関数(オリジナルコード)
            # evaluation_function=lambda p: (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2,
            evaluation_function=self.evaluation_function,  # 評価関数
            objective_name="score",  # 評価指標の名前
            minimize=True,  # 最適化=最小化
        )
        # best_parameters contains {'x1': 1.02, 'x2': 2.97};
        # the global min is (1, 3)
        # ベストパラメータ=最適解に最も近づいた値
        print(f"best_parameters = {best_parameters}")

        # その他指標
        means, covariances = best_values
        print(f"means = {means}")

        if self.cb_end:
            # 終了コールバック:Unityアプリへベストパラメータを通知する
            x1 = best_parameters['x1']
            x2 = best_parameters['x2']
            x = np.array([x1, x2])
            self.cb_end(x)


if __name__ == '__main__':
    obj = Optimizer()
    obj.optimize()
Optimizer.cs
  • クライアント側のメインプログラム(Unity、WebGL)
  • サーバーからパラメータを貰い評価値を計算する(評価関数)
  • サンプルのため余分なGUIは省略
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

using System.Runtime.InteropServices;


// Optimizerクラス
// OptimizerオブジェクトにAttachされている
public class Optimizer : MonoBehaviour
{
    // DLL関数定義:最適化開始
    [DllImport("__Internal")]
    private static extern void OptimizerStart();

    // DLL関数定義:評価値通知
    [DllImport("__Internal")]
    private static extern void OptimizerStep(float score);

    // DLL関数定義:最適化終了
    [DllImport("__Internal")]
    private static extern void OptimizerEnd();

    // Start is called before the first frame update
    void Start()
    {
        // 最適化開始
        OptimizerStart();
    }

    // Update is called once per frame
    void Update()
    {

    }

    // パラメータ通知(DLLから呼ばれる)
    void OnParam(string str)
    {
        // CSVをパラメータに分解
        string[] arr =  str.Split(',');
        float x1 = float.Parse(arr[0]);
        float x2 = float.Parse(arr[1]);
        Debug.Log("OnParam called");
        Debug.LogFormat("x1={0}, x2={1}", x1, x2);
        // Booth Function
        float score = Mathf.Pow((x1 + 2*x2 - 7),2) + Mathf.Pow((2*x1 + x2 - 5),2);
        // 評価値通知
        OptimizerStep(score);
    }

    // ベストパラメータ通知(DLLから呼ばれる)
    void OnBestParam(string str)
    {
        // CSVをベストパラメータに分解
        string[] arr =  str.Split(',');
        float x1 = float.Parse(arr[0]);
        float x2 = float.Parse(arr[1]);
        Debug.Log("OnBestParam called");
        Debug.LogFormat("x1={0}, x2={1}", x1, x2);
        // 最適化終了
        OptimizerEnd();
    }

    // エラー通知(DLLから呼ばれる)
    void OnError(string str)
    {
        Debug.Log("OnError called");
        // エラー理由or箇所
        Debug.LogFormat("str = {0}", str);
        // 最適化終了
        OptimizerEnd();
    }
}
OptimizerLib.jslib
  • Unity WebGL DLLの本体(JavaScript)
  • コンパイル頻度を下げるため必要最低限の処理のみ記述
// DLLスクリプト
// ビルドでコンパイルされる
mergeInto(LibraryManager.library, {

  // 最適化開始
  OptimizerStart: function () {
    // 外部スクリプトの対応関数をコール
    OptimizerExt.OptimizerStart(function (oname, fname, param) {
      // Unityスクリプトの関数をコール
      SendMessage(oname, fname, param);
    });
  },

  // 評価値通知
  OptimizerStep: function (score) {
    // 外部スクリプトの対応関数をコール
    OptimizerExt.OptimizerStep(score, function (oname, fname, param) {
      // Unityスクリプトの関数をコール
      SendMessage(oname, fname, param);
    });
  },

  // 最適化終了
  OptimizerEnd: function () {
    // 外部スクリプトの対応関数をコール
    OptimizerExt.OptimizerEnd(function (oname, fname, param) {
      // Unityスクリプトの関数をコール
      SendMessage(oname, fname, param);
    });
  },

});
OptimizerExt.js
  • Unity WebGL DLLを補完する外部スクリプト(JavaScript)
  • サーバーとの通信処理(WebSocket)
  • コンパイル対象外であるためほとんどの処理をこちらに記述
// 外部スクリプト
// ビルドでコンパイルされない=ビルドなしで編集可能
var OptimizerExt = {

    // 最適化開始
    OptimizerStart: function (callback) {
        console.log('OptimizerStart');
        // URL作成
        //var host = "ws://localhost:8080/optimizer";
        var protocol = location.protocol == 'https:' ? 'wss:' : 'ws:';
        var host = protocol + "//" + location.host + location.pathname + "/optimizer";
        console.log(host);
        try {
            // WebSocket接続
            this.g_ws = new WebSocket(host);
        }
        catch (e) {
            window.alert(e);
            console.error(e);
            callback('Optimizer', 'OnError', 'OptimizerStart');
        }
        this.g_ws.onmessage = function (message) {
            // WebSocket受信ハンドラ
            var message_data = JSON.parse(message.data);
            var m = message_data['m'];
            var x = message_data['x'];
            var str = '';
            for (var i = 0; i < x.length; i++) {
                if (i == 0) {
                    str += x[i];
                } else {
                    str += ',' + x[i];
                }
            }
            if (m == 'OPTIMIZER_PARAM') {
                // パラメータ通知
                callback('Optimizer', 'OnParam', str);
            } else if (m == 'OPTIMIZER_BEST_PARAM') {
                // ベストパラメータ通知
                callback('Optimizer', 'OnBestParam', str);
            } else {
                callback('Optimizer', 'OnError', 'OptimizerStart');
            }
        };
    },

    // 評価値通知
    OptimizerStep: function (score, callback) {
        console.log('OptimizerStep');
        // メッセージ作成:評価値通知
        var data = { 'score': score, 'm': 'OPTIMIZER_STEP' };
        var message = JSON.stringify(data);
        try {
            // WebSocketメッセージ送信
            this.g_ws.send(message);
        }
        catch (e) {
            window.alert(e);
            console.error(e);
            callback('Optimizer', 'OnError', 'OptimizerStep');
        }
    },

    // 最適化終了
    OptimizerEnd: function (callback) {
        console.log('OptimizerEnd');
        // メッセージ作成:最適化終了
        var data = { 'm': 'OPTIMIZER_END' };
        var message = JSON.stringify(data);
        try {
            // WebSocketメッセージ送信
            this.g_ws.send(message);
        }
        catch (e) {
            window.alert(e);
            console.error(e);
            callback('Optimizer', 'OnError', 'OptimizerStep');
        }
        this.g_ws.onclose = function (event) {
            // WebSocket切断ハンドラ
            console.log(event);
            console.log('Disconnected');
        };
        try {
            // WebSocket切断
            this.g_ws.close()
        }
        catch (e) {
            window.alert(e);
            console.error(e);
            callback('Optimizer', 'OnError', 'OptimizerEnd');
        }
    },
}

最後に

今回はここまでですが、パラメータ自動調整の結果が出たら、あらためてこのブログで報告したいと思います。

最後まで読んでいただきありがとうございました。

それではまた。