Ccmmutty logo
Commutty IT
8 min read

【強化学習】UnityでML-Agentsを使ったCartPole学習(後半)

https://cdn.magicode.io/media/notebox/f4dc24f1-557b-4ba4-b7da-fa49828cebc6.jpeg

概要

前回はCartPoleのモデルの作成を行いました. ここではml-agentを使って実際にモデルを学習してみましょう.
ML-Agentsのインストールは以下を参考にインストールしてください. https://github.com/Unity-Technologies/ml-agents

手順

モデルを学習するにあたって必要な手順は以下の通りです.
  1. 学習のためのスクリプトを記述
  2. 作成したモデルとスクリプトを紐づけ
  3. 学習の実行

学習のためのスクリプトを記述

まず,下のようにProjectフォルダ内で右クリックし,create => C# scriptを選んでスクリプトを作成しましょう.
スクリプトの名前は適当で大丈夫です.ここではCartPoleAgentとしました.
set_script
作成されたスクリプトを開くとVScodeなどのエディタが立ち上がると思います. エディタ内では以下のようにコードを記述してください.
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System.Collections;
using System.Collections.Generic;

public class CartPoleAgent : Agent
{

    public GameObject pole;
    Rigidbody poleRB;
    Rigidbody cartRB;
    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        poleRB = pole.GetComponent<Rigidbody>();
        cartRB = gameObject.GetComponent<Rigidbody>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(gameObject.transform.localPosition.z);
        sensor.AddObservation(cartRB.velocity.z);
        sensor.AddObservation(pole.transform.localRotation.eulerAngles.x);
        sensor.AddObservation(poleRB.angularVelocity.x);
    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = actionBuffers.ContinuousActions[0];
        controlSignal.z = actionBuffers.ContinuousActions[1];
        var actionZ = 200f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
        cartRB.AddForce(new Vector3(0.0f, 0.0f, actionZ), ForceMode.Force);

        float cart_z = this.gameObject.transform.localPosition.z;
        float angle_x = pole.transform.localRotation.eulerAngles.x;

        if(180f < angle_x && angle_x < 360f)
        {
            angle_x = angle_x - 360f;
        }

        if((-180f < angle_x && angle_x < -45f) || (45f < angle_x && angle_x < 180f))
        {
            SetReward(-1.0f);
            EndEpisode();
        }
        else{
            SetReward(0.1f);
        }

        if(cart_z < -10f || 10f < cart_z)
        {
            SetReward(-1.0f);
            EndEpisode();
        }
    }
    public override void OnActionReceived(float[] verctorAction)
    {
        var actionZ = 200f * Mathf.Clamp(verctorAction[0], -1f, 1f);
        cartRB.AddForce(new Vector3(0.0f, 0.0f, actionZ), ForceMode.Force);

        float cart_z = this.gameObject.transform.localPosition.z;
        float angle_x = pole.transform.localRotation.eulerAngles.x;

        if(180f < angle_x && angle_x < 360f)
        {
            angle_x = angle_x - 360f;
        }

        if((-180f < angle_x && angle_x < -45f) || (45f < angle_x && angle_x < 180f))
        {
            SetReward(-1.0f);
            EndEpisode();
        }
        else{
            SetReward(0.1f);
        }

        if(cart_z < -10f || 10f < cart_z)
        {
            SetReward(-1.0f);
            EndEpisode();
        }
    }

    public override void OnEpisodeBegin()
    {
        gameObject.transform.localPosition = new Vector3(0f, 0f, 0f);
        pole.transform.localPosition = new Vector3(0f, 2.5f, 0f);
        pole.transform.localRotation = Quaternion.Euler(0f, 0f, 0f);
        poleRB.angularVelocity = new Vector3(0f, 0f, 0f);
        poleRB.velocity = new Vector3(0f, 0f, 0f);

        poleRB.angularVelocity = new Vector3(Random.Range(-0.1f, 0.1f), 0f, 0f);
        SetResetParameters();
    }


    public void SetPole()
    {
        poleRB.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
        pole.transform.localScale = new Vector3(0.4f, 2f, 0.4f);
    }

    public void SetResetParameters()
    {
        SetPole();
    }
}
簡単に上記のコードの説明をします.

public override void Initialize()

学習のための初期設定を行います. 具体的にはスクリプト内の変数と実際のオブジェクトを紐づけることなどを行っています.

public override void CollectObservations(VectorSensor sensor)

エージェントが学習に必要な情報をここで取得します. 今回使う情報は以下の通りです.
  • カートの位置
  • カートの速度
  • 棒の角度
  • 棒の角速度

public override void OnActionReceived(ActionBuffers actionBuffers)

エージェントの行動・エピソードの終了判定・また報酬をエージェントに与えることなどをここで行います. 引数には,強化学習で学習したモデルが出力した行動をactionBuffersとして受け取ります.

public override void OnEpisodeBegin()

学習のために,エピソードが始まるときの設定を行います. 具体的には,エージェントの位置をリセット,棒をランダムに傾ける,といったことを行います.

作成したモデルとエージェントの紐付け

次に作成したスクリプトをエージェントに追加します.またinspectorのPoleという位置にhierarchyウィンドウからPoleオブジェクトをドラッグしてください.
またAdd Component から、Behavior Parameters と Decision Requester を追加しましょう.
各パラメータは以下のように設定してください. insp
学習に必要な準備は終了です.次は実際に学習してみましょう!

学習の実行

まず,configの設定が必要になります.
configの位置はダウンロードしたリポジトリのml-agents/config/ppo/になります.
ML-Agentsのダウンロードは以下から行ってください.
https://github.com/Unity-Technologies/ml-agents
configの内容は以下のようにして作成してください.またファイルの名前は「2D_cartpole.yaml」としました.
2行目のCartPoleという名前はbehavior parametersのbehavior Nameと同じ名前にしましょう.
behaviors:
  CartPole:
    trainer_type: ppo
    hyperparameters:
      batch_size: 32
      buffer_size: 12000
      learning_rate: 0.0003
      beta: 0.001
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 200000
    time_horizon: 1000
    summary_freq: 12000
あとはターミナルからコマンドを入力して学習を行います.
インストールしたリポジトリのml-agentsディレクトリから以下を実行しましょう.
mlagents-learn ../config/ppo/2D_cartpole.yaml  --run-id=2Dcartpole --train
無事に実行できると以下のような画面が表示されます. terminal
そして,出力されているようにUnity側で実行ボタン(矢印のマーク)を押すと実際に学習が始まります.
自分の環境では学習は20万ステップほどで完了し,棒を倒さずカートが動く様子が確認できました!

Discussion

コメントにはログインが必要です。