はじめまして、モリカトロン・プログラマのはっとりです。
今日からブログを始めます。
初回はUnityとPythonのAx(Adaptive Experimentation Platform)を連携させてみた件について書こうと思います。
はじめに
ことの発端は
- シミュレーションのパラメータ自動調整をやってみよう
- アルゴリズムはベイズ最適化を使ってみよう
- Axというベイズ最適化のツールがあるらしい
ということで
- まずはUnityとAxを連携させるサンプルを作ろう
となりました。
その次にくるべき
- 実際にシミュレーションを最適化させてみよう
は、追い追い報告したいと思います。
Axとは
Axとは、数理モデルの最適化を自動化するPythonライブラリです。
内部でBoTorchを使用しています。
BoTorchとは、ベイズ最適化(Bayesian Optimization)のPythonライブラリです。
いずれも昨年Facebookから公開されました。
BoTorchは単体でも使用できますが、Axを使えばBoTorchの存在を意識することなく、実験自体に集中することができます。
https://ax.dev/
https://botorch.org/
ベイズ最適化とは
ベイズ最適化とは、未知の関数が最大値(最小値)を取るような入力を、ガウス過程回帰を利用して効率よく探索するための手法です。
これ以上の説明は筆者には荷が重いので、詳しくは他の解説記事に譲りますが、筆者自身は主に下記の書籍を参考にしました。
www.kspub.co.jp
Axの超簡単なサンプル
本題に入る前に、まずはAxの超簡単なサンプルを見てみましょう。
これは公式サイトの「GET STARTED」で紹介されているものと同じ内容です。
https://ax.dev/#quickstart
インストール
PyTorchとAxのインストール
conda install pytorch torchvision -c pytorch # OSX only
pip install ax-platform # all systems
実際には依存関係で他のライブラリも自動的にインストールされますが、ユーザーが手で入力するのはこの2行のみです。
BoTorchもAxをインストールすると依存ライブラリとして自動的にインストールされます。
PyTorchのインストールコマンドは動作させる環境によって異なるので、詳しくはPyTorchの公式サイトを参照してください。
https://pytorch.org/get-started/locally/
スクリプト
>>> from ax import optimize >>> best_parameters, best_values, experiment, model = optimize( parameters=[ { "name": "x1", "type": "range", "bounds": [-10.0, 10.0], }, { "name": "x2", "type": "range", "bounds": [-10.0, 10.0], }, ], # Booth function evaluation_function=lambda p: (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2, minimize=True, )
>>> best_parameters
実行結果
実際に筆者の環境で動作させてみました。
[INFO 04-16 17:28:26] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting. [INFO 04-16 17:28:26] ax.service.managed_loop: Started full optimization with 20 steps. [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 1... [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 2... [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 3... [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 4... [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 5... [INFO 04-16 17:28:26] ax.service.managed_loop: Running optimization trial 6... [INFO 04-16 17:28:27] ax.service.managed_loop: Running optimization trial 7... [INFO 04-16 17:28:28] ax.service.managed_loop: Running optimization trial 8... [INFO 04-16 17:28:30] ax.service.managed_loop: Running optimization trial 9... [INFO 04-16 17:28:31] ax.service.managed_loop: Running optimization trial 10... [INFO 04-16 17:28:32] ax.service.managed_loop: Running optimization trial 11... [INFO 04-16 17:28:33] ax.service.managed_loop: Running optimization trial 12... [INFO 04-16 17:28:34] ax.service.managed_loop: Running optimization trial 13... [INFO 04-16 17:28:36] ax.service.managed_loop: Running optimization trial 14... [INFO 04-16 17:28:37] ax.service.managed_loop: Running optimization trial 15... [INFO 04-16 17:28:39] ax.service.managed_loop: Running optimization trial 16... [INFO 04-16 17:28:40] ax.service.managed_loop: Running optimization trial 17... [INFO 04-16 17:28:42] ax.service.managed_loop: Running optimization trial 18... [INFO 04-16 17:28:43] ax.service.managed_loop: Running optimization trial 19... [INFO 04-16 17:28:44] ax.service.managed_loop: Running optimization trial 20...
{'x1': 0.9242265891515515, 'x2': 3.0831835290815146}
このサンプルで扱っている最適化問題はBooth Functionというもので最適化アルゴリズムのベンチマークの一つです。 上の結果では20回の試行で最適解(x1=1,x2=3)にかなり近づいていることがわかります。 ちなみに筆者が30回で試したところほぼ最適解に辿り着きました。
Booth Functionについては下記を参照してください。
https://www.sfu.ca/~ssurjano/booth.html
APIについて
このサンプルはAxの3つある書き方(API)のうち、最もシンプルな
- Loop API
で書かれています。上の例では
- 最適化対象のパラメータ(parameters:例ではx1、x2の二つ)
- 評価関数(evaluation_function:例ではlambda式)
- 目標を最小化するか(minimize:例ではTrue最小化)
を定義しているのみです。 これ以外にもトライアル回数など、指定できるパラメータはいくつかあります。 Loop APIで十分というケースはありそうですが、より細かなカスタマイズをする場合は、他の2つのAPI
- Service API
- Developer API
を使用します。 これらのAPIを利用すると実験の統計データの保存や実験状態の保存・復元などが可能になります。
各APIの詳細については公式サイトのチュートリアルを参照してください。
https://ax.dev/tutorials/
Unity~Ax連携のサンプルプログラム
ここからが今回の本題です。
ターゲットへの適用前に、本当にできるの?というのを確認するために、技術調査用にUnity(WebGL)とAxを連携させるサンプルプログラムを作成しました。
最適化問題は前述のBooth Functionを使用しました。
公式サンプルと違う点は、評価関数(evaluation_function)をクライアント側に持たせたことです。
本来なら評価計算に必要な元ネタを貰ってサーバー側で評価計算してもいいのですが、Booth Functionの場合は元ネタ=調整パラメータになってしまうので、ここはわかりやすくクライアント側で計算することにしました。
クライアントとしてUnity(WebGL)を選んだ理由は、もちろんターゲットの環境に合わせたのですが、こういうのがあってもいいかなという動機もあります。
以下の順で説明していきます。
- 動作環境
- 実行手順
- プログラム説明
動作環境
バージョン情報
筆者の動作環境とソフトウェアのバージョンです。
- Windows 10 Pro 1909
- Google Chrome 81.0.4044.113
- Unity 2019.3.1f1
- Miniconda3(conda 4.8.3)
- python 3.7.7
- torch 1.5.0
- torchvision 0.6.0
- ax-platform 0.1.11
- Flask 1.1.2
- gevent 20.4.0
- gevent-websocket 0.10.1
GitHub
GitHubで実際に動作するUnityプロジェクトを公開しています。
https://github.com/morikatron/AxSample.git
Unity WebGLビルド
- git clone [URL]
- Unity EditorでAxSampleプロジェクトを開く
- Assets/Scenes/SampleSceneを選択
- File>Build Settings...を開く
- Scenes In Build>Add Open Scenesを実行
- Platform>WebGL>Switch Platformを実行
- Player Settings...>Player>WebGL Settings>Resolution and Presentation>WebGL Template>Default2を選択
- Buildを実行
- ビルド先フォルダとして新しいフォルダWebGLを作成し選択
- ビルドエラーがなければOK
Pythonライブラリ
- 一部前述のインストールと重複します。
- PyTorchのインストールコマンドは動作環境によって異なります。
詳細はPyTorch公式を参照してください。
https://pytorch.org/get-started/locally/
仮想環境
conda create -n ax-sample python=3.7
conda activate ax-sample
PyTorch
conda install pytorch torchvision cpuonly -c pytorch
Ax他
pip install ax-platform Flask gevent gevent-websocket
または
cd AxSample; pip install -r requirements.txt
実行手順
- 仮想環境を起動
conda activate ax-sample
- WebGLビルド先のフォルダに移動
cd AxSample/WebGL
- サーバーアプリ起動
python app.py
- ブラウザからURLにアクセス
http://localhost:8080/
またはhttp://127.0.0.1:8080/
- URLにアクセスするとすぐ最適化処理がスタートします。
- Unityの画面には何も表示されません。
- 適宜ブラウザおよびサーバーのコンソールでログを確認してください。
プログラム説明
概要
- システムはクライアントとサーバーで構成し、 クライアントに評価関数の本体を、サーバーにAxの全体ループを配置しました。
- クライアントとサーバー間は主にWebSocketで通信し、以下のコマンドでやりとりします。
- サーバー→クライアント
- OPTIMIZER_PARAM(パラメータ通知)
- OPTIMIZER_BEST_PARAM(ベストパラメータ通知=最適化終了)
- クライアント→サーバー
- OPTIMIZER_STEP(評価値通知=次ステップ用パラメータ要求)
- OPTIMIZER_END(最適化終了応答)
- サーバー→クライアント
処理フロー
- サーバー起動(python app.py)
- ブラウザでアプリをロード
(http://localhost:8080/ または http://127.0.0.1:8080/)- WebSocket接続
(http://localhost:8080/optimizer または http://127.0.0.1:8080/optimizer)
- WebSocket接続
- サーバー
- パラメータを通知(OPTIMIZE_PARAM)
- クライアント
- パラメータから評価値を計算し結果を通知(OPTIMIZE_STEP)
- サーバー
- 次のパラメータを通知(OPTIMIZE_PARAM)
- (以後、OPTIMIZE_PARAM~OPTIMIZE_STEPを規定回数繰り返し)
- サーバー
- 規定回数終了
- ベストパラメータを通知(OPTIMIZE_BEST_PARAM)
- クライアント
- OPTIMIZE終了を通知(OPTIMIZE_END)
- WebSocket切断
- サーバー
- WebSocket切断
主なプログラム
- サーバー
- app.py
- booth_loop.py
- クライアント
- Optimizer.cs
- OptimizerLib.jslib
- OptimizerExt.js
app.py
- アプリケーションサーバー(Flask)
- クライアントとの通信処理(gevent-websocket)
- 最適化処理(Optimizer)のインスタンスを生成・保持
import json from gevent.pywsgi import WSGIServer from geventwebsocket.handler import WebSocketHandler from flask import Flask, request, render_template, Blueprint from booth_loop import Optimizer # Unity WebGLのテンプレートのフォルダ構成に合わせるため # Blueprintを利用してstatic_folderを追加する app = Flask(__name__, static_folder="TemplateData") app.config.from_object(__name__) blueprint = Blueprint('optimizer', __name__, static_folder='Build', template_folder=".") app.register_blueprint(blueprint) @app.route('/') def index(): return render_template('index.html') @app.route('/optimizer') def optimizer(): if request.environ.get('wsgi.websocket'): # WebSocket接続 ws = request.environ['wsgi.websocket'] # パラメータ通知コールバック def cb_step(x): # メッセージ作成:パラメータ通知 data = {'m': 'OPTIMIZER_PARAM', 'x': x.tolist()} print(data) # WebSocektメッセージ送信 ws.send(json.dumps(data)) print('Waiting...') # WebSocketメッセージ待ち受け message = ws.receive() score = 0 if message: print(message) data2 = json.loads(message) score = data2['score'] return score # ベストパラメータ通知コールバック def cb_end(x): # メッセージ作成:ベストパラメータ通知 data = {'m': 'OPTIMIZER_BEST_PARAM', 'x': x.tolist()} print(data) # WebSocektメッセージ送信 ws.send(json.dumps(data)) print('Waiting...') # WebSocketメッセージ待ち受け message = ws.receive() if message: print(message) obj = Optimizer(cb_step, cb_end) obj.optimize() ws.close() return "OK" if __name__ == '__main__': app.debug = True host = 'localhost' port = 8080 host_port = (host, port) server = WSGIServer( host_port, app, handler_class=WebSocketHandler ) server.serve_forever()
booth_loop.py
- 最適化処理、全体ループの管理(Ax、BoTorch)
import numpy as np from ax import optimize # Optimizerクラス class Optimizer: def __init__(self, cb_step=None, cb_end=None): self.cb_step = cb_step # ステップ・コールバック self.cb_end = cb_end # 終了コールバック # 評価関数 def evaluation_function(self, parameterization): l = len(parameterization) # 最適化対象パラメータの値を配列にセットする x = np.array([parameterization.get(f"x{i+1}") for i in range(l)]) if self.cb_step: # パラメータ通知コールバック # クライアントへ今回のパラメータを通知し、評価値の受信を待ち受ける score = self.cb_step(x) else: # スクリプト単体で動作させる場合 x1 = x[0] x2 = x[1] # Booth Function score = (x1 + 2*x2 - 7)**2 + (2*x1 + x2 - 5)**2 return {"score": (score, 0.0)} # 最適化関数 def optimize(self): # ax.optimizeをコールする # ax.optimizeは最適化が終了するまで復帰して来ない # ax.optimize内でevaluation_functionが試行回数分コールされる best_parameters, best_values, experiment, model = optimize( parameters=[ # 最適化対象パラメータの定義 { "name": "x1", "type": "range", "bounds": [-10.0, 10.0], }, { "name": "x2", "type": "range", "bounds": [-10.0, 10.0], }, ], # 評価関数(オリジナルコード) # evaluation_function=lambda p: (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2, evaluation_function=self.evaluation_function, # 評価関数 objective_name="score", # 評価指標の名前 minimize=True, # 最適化=最小化 ) # best_parameters contains {'x1': 1.02, 'x2': 2.97}; # the global min is (1, 3) # ベストパラメータ=最適解に最も近づいた値 print(f"best_parameters = {best_parameters}") # その他指標 means, covariances = best_values print(f"means = {means}") if self.cb_end: # 終了コールバック:Unityアプリへベストパラメータを通知する x1 = best_parameters['x1'] x2 = best_parameters['x2'] x = np.array([x1, x2]) self.cb_end(x) if __name__ == '__main__': obj = Optimizer() obj.optimize()
Optimizer.cs
- クライアント側のメインプログラム(Unity、WebGL)
- サーバーからパラメータを貰い評価値を計算する(評価関数)
- サンプルのため余分なGUIは省略
using System.Collections; using System.Collections.Generic; using UnityEngine; using System.Runtime.InteropServices; // Optimizerクラス // OptimizerオブジェクトにAttachされている public class Optimizer : MonoBehaviour { // DLL関数定義:最適化開始 [DllImport("__Internal")] private static extern void OptimizerStart(); // DLL関数定義:評価値通知 [DllImport("__Internal")] private static extern void OptimizerStep(float score); // DLL関数定義:最適化終了 [DllImport("__Internal")] private static extern void OptimizerEnd(); // Start is called before the first frame update void Start() { // 最適化開始 OptimizerStart(); } // Update is called once per frame void Update() { } // パラメータ通知(DLLから呼ばれる) void OnParam(string str) { // CSVをパラメータに分解 string[] arr = str.Split(','); float x1 = float.Parse(arr[0]); float x2 = float.Parse(arr[1]); Debug.Log("OnParam called"); Debug.LogFormat("x1={0}, x2={1}", x1, x2); // Booth Function float score = Mathf.Pow((x1 + 2*x2 - 7),2) + Mathf.Pow((2*x1 + x2 - 5),2); // 評価値通知 OptimizerStep(score); } // ベストパラメータ通知(DLLから呼ばれる) void OnBestParam(string str) { // CSVをベストパラメータに分解 string[] arr = str.Split(','); float x1 = float.Parse(arr[0]); float x2 = float.Parse(arr[1]); Debug.Log("OnBestParam called"); Debug.LogFormat("x1={0}, x2={1}", x1, x2); // 最適化終了 OptimizerEnd(); } // エラー通知(DLLから呼ばれる) void OnError(string str) { Debug.Log("OnError called"); // エラー理由or箇所 Debug.LogFormat("str = {0}", str); // 最適化終了 OptimizerEnd(); } }
OptimizerLib.jslib
- Unity WebGL DLLの本体(JavaScript)
- コンパイル頻度を下げるため必要最低限の処理のみ記述
// DLLスクリプト // ビルドでコンパイルされる mergeInto(LibraryManager.library, { // 最適化開始 OptimizerStart: function () { // 外部スクリプトの対応関数をコール OptimizerExt.OptimizerStart(function (oname, fname, param) { // Unityスクリプトの関数をコール SendMessage(oname, fname, param); }); }, // 評価値通知 OptimizerStep: function (score) { // 外部スクリプトの対応関数をコール OptimizerExt.OptimizerStep(score, function (oname, fname, param) { // Unityスクリプトの関数をコール SendMessage(oname, fname, param); }); }, // 最適化終了 OptimizerEnd: function () { // 外部スクリプトの対応関数をコール OptimizerExt.OptimizerEnd(function (oname, fname, param) { // Unityスクリプトの関数をコール SendMessage(oname, fname, param); }); }, });
OptimizerExt.js
- Unity WebGL DLLを補完する外部スクリプト(JavaScript)
- サーバーとの通信処理(WebSocket)
- コンパイル対象外であるためほとんどの処理をこちらに記述
// 外部スクリプト // ビルドでコンパイルされない=ビルドなしで編集可能 var OptimizerExt = { // 最適化開始 OptimizerStart: function (callback) { console.log('OptimizerStart'); // URL作成 //var host = "ws://localhost:8080/optimizer"; var protocol = location.protocol == 'https:' ? 'wss:' : 'ws:'; var host = protocol + "//" + location.host + location.pathname + "/optimizer"; console.log(host); try { // WebSocket接続 this.g_ws = new WebSocket(host); } catch (e) { window.alert(e); console.error(e); callback('Optimizer', 'OnError', 'OptimizerStart'); } this.g_ws.onmessage = function (message) { // WebSocket受信ハンドラ var message_data = JSON.parse(message.data); var m = message_data['m']; var x = message_data['x']; var str = ''; for (var i = 0; i < x.length; i++) { if (i == 0) { str += x[i]; } else { str += ',' + x[i]; } } if (m == 'OPTIMIZER_PARAM') { // パラメータ通知 callback('Optimizer', 'OnParam', str); } else if (m == 'OPTIMIZER_BEST_PARAM') { // ベストパラメータ通知 callback('Optimizer', 'OnBestParam', str); } else { callback('Optimizer', 'OnError', 'OptimizerStart'); } }; }, // 評価値通知 OptimizerStep: function (score, callback) { console.log('OptimizerStep'); // メッセージ作成:評価値通知 var data = { 'score': score, 'm': 'OPTIMIZER_STEP' }; var message = JSON.stringify(data); try { // WebSocketメッセージ送信 this.g_ws.send(message); } catch (e) { window.alert(e); console.error(e); callback('Optimizer', 'OnError', 'OptimizerStep'); } }, // 最適化終了 OptimizerEnd: function (callback) { console.log('OptimizerEnd'); // メッセージ作成:最適化終了 var data = { 'm': 'OPTIMIZER_END' }; var message = JSON.stringify(data); try { // WebSocketメッセージ送信 this.g_ws.send(message); } catch (e) { window.alert(e); console.error(e); callback('Optimizer', 'OnError', 'OptimizerStep'); } this.g_ws.onclose = function (event) { // WebSocket切断ハンドラ console.log(event); console.log('Disconnected'); }; try { // WebSocket切断 this.g_ws.close() } catch (e) { window.alert(e); console.error(e); callback('Optimizer', 'OnError', 'OptimizerEnd'); } }, }
最後に
今回はここまでですが、パラメータ自動調整の結果が出たら、あらためてこのブログで報告したいと思います。
最後まで読んでいただきありがとうございました。
それではまた。