0525 AI 펭귄

2021. 5. 25. 10:23unity/AI



Reinforcement Learning Penguins (Part 2/4) | Unity ML-Agents — Immersive Limit



유니티에서 ML Agent를 임포트 한 후에 

위 경로의 json 파일을 디스크에서 import 하면
ML Agent가 업데이트된다.

그 이후의 과정은 위의 documentation을 참고.

뺑이 도는 중
훈련이 완료되어 모델을 끼워넣고 재생한 결과.


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// NOTE(review): this listing appears to have lost its braces and some statements
// (for example, FixedUpdate shows a condition but no body). Code tokens are kept as found.
/// <summary>
/// A fish that picks a random target position, turns to face it, and moves toward it
/// at a randomized speed.
/// </summary>
public class Fish : MonoBehaviour
    // Base swim speed; scaled by a random factor in Swim and tested against zero in FixedUpdate
    public float fishSpeed;
    // Speed currently in use: fishSpeed multiplied by Random.Range(.5f, 1.5f)
    private float randomizedSpeed = 0f;
    // Time.fixedTime value at or after which a new target is chosen
    private float nextActionTime = -1f;
    // Position the fish is currently heading to
    private Vector3 targetPosition;

    // Runs on the physics tick; only the speed check is visible here.
    // NOTE(review): presumably this calls Swim() when the speed is positive — confirm against the original script.
    private void FixedUpdate()
        if (this.fishSpeed > 0)

    // Chooses a new target when the next action time has been reached, and advances the fish.
    private void Swim()

        if (Time.fixedTime >= this.nextActionTime)
            // Set the speed
            randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

            // Pick a random target position among the places a fish can go within the penguin area
            // (angle 100-260 degrees, radius 2-13, centered on the parent transform)
            targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

            // Set the rotation so the fish faces the target
            transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

            // Compute the time needed to reach the target position
            float timeToGetThere = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
            nextActionTime = Time.fixedTime + timeToGetThere;
            // NOTE(review): an `else` separating the retargeting code above from the movement
            // code below may have been lost from this listing — confirm.
            // Compute the movement vector for one fixed step
            Vector3 moveVector = randomizedSpeed * transform.forward * Time.fixedDeltaTime;

            if (moveVector.magnitude <= Vector3.Distance(transform.position, targetPosition))
                transform.position += moveVector;
                // NOTE(review): the two assignments below look like the overshoot branch
                // (an `else` seems to be missing here) — confirm.
                // Snap the current position to the target position
                transform.position = targetPosition;

                // Reset the next action time so a new target is chosen on the next check
                nextActionTime = Time.fixedTime;
using System.Collections.Generic;
using UnityEngine;

// NOTE(review): this listing appears to have lost its braces and some statements
// (several method bodies are empty or partial). Code tokens are kept as found.
/// <summary>
/// Training area that owns the penguin agent and the list of spawned fish,
/// and provides random placement helpers.
/// </summary>
public class PenguinArea : MonoBehaviour
    // The agent placed by PlacePenguin
    public PenguinAgent penguinAgent;
    // Prefab whose GameObject is instantiated by SpawnFish
    public Fish fishPrefab;
    // Fish tracked by this area; replaced with a new list in RemoveAllFish
    private List<GameObject> fishList;

    // Start is called before the first frame update
    // NOTE(review): the body is not visible in this listing.
    void Start()

    // Resets the area; only the SpawnFish(4, .5f) call is visible here.
    public void ResetArea()
        SpawnFish(4, .5f);

    // Number of entries currently in fishList
    public int FishRemaining
        get { return fishList.Count; }

    // Remove all fish
    // NOTE(review): the foreach body is not visible — presumably each fish object is destroyed; confirm.
    private void RemoveAllFish()
        if (this.fishList != null)
            foreach (var go in this.fishList)
        // Create a new collection
        this.fishList = new List<GameObject>();

    // Instantiates `count` fish at random positions with random Y rotation and assigns each the given speed.
    private void SpawnFish(int count, float fishSpeed)
        for (int i = 0; i < count; i++)
            // Spawn and place the fish
            GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
            fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Set the fish's parent to this area's transform
            // NOTE(review): the statement for this step is not visible in this listing.

            // Keep track of the fish
            // NOTE(review): the statement for this step is not visible in this listing.

            // Set the fish speed
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;

    /// <summary>
    /// Returns a position offset from <paramref name="center"/> by a radius and a rotation around the Y axis.
    /// The radius and angle are randomized only when the max exceeds the min; otherwise the min is used.
    /// </summary>
    /// <param name="center">Point the offset is added to.</param>
    /// <param name="minAngle">Minimum angle in degrees (passed to Quaternion.Euler).</param>
    /// <param name="maxAngle">Maximum angle in degrees.</param>
    /// <param name="minRadius">Minimum distance from the center.</param>
    /// <param name="maxRadius">Maximum distance from the center.</param>
    /// <returns>center + (Y rotation by angle) * Vector3.forward * radius.</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);

        if (maxAngle > minAngle)
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplies by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;

    // Stops the penguin's rigidbody and moves it to a random spot (any angle, radius 0-9),
    // raised by 0.5 on Y, with a random Y rotation.
    private void PlacePenguin()
        Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

    // NOTE(review): the body is not visible in this listing; the name suggests it removes one fish — confirm.
    public void RemoveSpecificFish(GameObject fishObject)

    // Returns the number of entries in fishList (same value as FishRemaining)
    public int GetFishCount() 
        return this.fishList.Count;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

// NOTE(review): this listing appears to have lost its braces and some statements
// (several method bodies are empty or partial). Code tokens are kept as found.
/// <summary>
/// ML-Agents agent for the penguin: turns two discrete actions into movement and rotation,
/// and supplies keyboard input through Heuristic.
/// </summary>
public class PenguinAgent : Agent
    // Forward movement speed multiplier used in OnActionReceived
    public float moveSpeed = 5f;
    // Turn speed multiplier used in OnActionReceived
    public float turnSpeed = 180f;
    // Area found on a parent object in Initialize
    private PenguinArea penguinArea;
    // Rigidbody on this object, cached in Initialize
    new private Rigidbody rigidbody;
    private bool isFull; // If true, penguin has a full stomach
    // NOTE(review): not read in the visible code; EatFish declares a local with the same name.
    private int count;

    // Caches the parent PenguinArea and this object's Rigidbody.
    public override void Initialize()
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();

    // Applies the received discrete actions: index 0 is the forward amount, index 1 selects the turn direction.
    public override void OnActionReceived(ActionBuffers actionBuffers)
        // Convert the first action to forward movement
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Convert the second action to turning left or right
        // (1 maps to -1, 2 maps to +1, anything else leaves the turn at 0)
        float turnAmount = 0f;
        if (actionBuffers.DiscreteActions[1] == 1f)
            turnAmount = -1f;
        else if (actionBuffers.DiscreteActions[1] == 2f)
            turnAmount = 1f;

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        // (skipped when MaxStep is 0, which also avoids dividing by zero)
        if (MaxStep > 0) AddReward(-1f / MaxStep);

    // Reacts to collisions with objects tagged "fish".
    // NOTE(review): the branch body is not visible — presumably it calls EatFish; confirm.
    private void OnCollisionEnter(Collision collision)
        if (collision.transform.CompareTag("fish"))
            // Try to eat the fish

    // NOTE(review): most of this body is not visible. What remains reads the area's fish count
    // into a local `count` and compares it to 5; the action taken is unknown.
    private void EatFish(GameObject fishObject)

        var count = this.penguinArea.GetFishCount();
        if (count == 5)

    // NOTE(review): the sensor calls are not visible in this listing; only the original comments remain.
    public override void CollectObservations(VectorSensor sensor)
        // Whether the penguin has eaten a fish (1 float = 1 value)

        // Direction penguin is facing (1 Vector3 = 3 values)

        // 1 + 1 + 3 + 3 = 8 total values

    // Keyboard control: W sets the forward action to 1; A sets the turn action to 1, D sets it to 2.
    public override void Heuristic(in ActionBuffers actionsOut)
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
            // move forward
            forwardAction = 1;
        if (Input.GetKey(KeyCode.A))
            // turn left
            turnAction = 1;
        else if (Input.GetKey(KeyCode.D))
            // turn right
            turnAction = 2;

        // Put the actions into the array
        actionsOut.DiscreteActions.Array[0] = forwardAction;
        actionsOut.DiscreteActions.Array[1] = turnAction;
    // Clears the full-stomach flag at the start of each episode.
    // NOTE(review): any further reset steps are not visible in this listing.
    public override void OnEpisodeBegin()
        isFull = false;


약간의 응용 추가

먹으면 안 되는 생선(RottenFish)을 하나 배치하고, 펭귄이 그것을 먹으면 리워드를 -1로 주도록 했다.

훈련 완료 영상

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// NOTE(review): this listing appears to have lost its braces and some statements
// (for example, FixedUpdate shows a condition but no body). Code tokens are kept as found.
/// <summary>
/// A rotten fish that picks a random target position, turns to face it, and moves toward it
/// at a randomized speed. The visible movement code matches the Fish class line for line.
/// </summary>
public class RottenFish : MonoBehaviour
    // Base swim speed; scaled by a random factor in Swim and tested against zero in FixedUpdate
    public float fishSpeed;
    // Speed currently in use: fishSpeed multiplied by Random.Range(.5f, 1.5f)
    private float randomizedSpeed = 0f;
    // Time.fixedTime value at or after which a new target is chosen
    private float nextActionTime = -1f;
    // Position the fish is currently heading to
    private Vector3 targetPosition;

    // Runs on the physics tick; only the speed check is visible here.
    // NOTE(review): presumably this calls Swim() when the speed is positive — confirm against the original script.
    private void FixedUpdate()
        if (this.fishSpeed > 0)

    // Chooses a new target when the next action time has been reached, and advances the fish.
    private void Swim()

        if (Time.fixedTime >= this.nextActionTime)
            // Set the speed
            randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

            // Pick a random target position among the places a fish can go within the penguin area
            // (angle 100-260 degrees, radius 2-13, centered on the parent transform)
            targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

            // Set the rotation so the fish faces the target
            transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

            // Compute the time needed to reach the target position
            float timeToGetThere = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
            nextActionTime = Time.fixedTime + timeToGetThere;
            // NOTE(review): an `else` separating the retargeting code above from the movement
            // code below may have been lost from this listing — confirm.
            // Compute the movement vector for one fixed step
            Vector3 moveVector = randomizedSpeed * transform.forward * Time.fixedDeltaTime;

            if (moveVector.magnitude <= Vector3.Distance(transform.position, targetPosition))
                transform.position += moveVector;
                // NOTE(review): the two assignments below look like the overshoot branch
                // (an `else` seems to be missing here) — confirm.
                // Snap the current position to the target position
                transform.position = targetPosition;

                // Reset the next action time so a new target is chosen on the next check
                nextActionTime = Time.fixedTime;
using System.Collections.Generic;
using UnityEngine;

// NOTE(review): this listing appears to have lost its braces and some statements
// (several method bodies are empty or partial). Code tokens are kept as found.
/// <summary>
/// Training area that owns the penguin agent and the list of spawned fish, spawns both
/// normal and rotten fish, and provides random placement helpers.
/// </summary>
public class PenguinArea : MonoBehaviour
    // The agent placed by PlacePenguin
    public PenguinAgent penguinAgent;
    // Prefab whose GameObject is instantiated by SpawnFish
    public Fish fishPrefab;
    // Prefab whose GameObject is instantiated by SpawnRottenFish
    public RottenFish rottenFishPrefab;
    // Fish tracked by this area; public in this version, and read directly by PenguinAgent.EatFish
    public List<GameObject> fishList;

    // Start is called before the first frame update
    // NOTE(review): the body is not visible in this listing.
    void Start()

    // Resets the area; only the SpawnFish(4, .5f) call is visible here.
    // NOTE(review): a SpawnRottenFish call is presumably also made here — confirm.
    public void ResetArea()
        SpawnFish(4, .5f);

    // Number of entries currently in fishList
    public int FishRemaining
        get { return fishList.Count; }

    // Remove all fish
    // NOTE(review): the foreach body is not visible — presumably each fish object is destroyed; confirm.
    private void RemoveAllFish()
        if (this.fishList != null)
            foreach (var go in this.fishList)
        // Create a new collection
        this.fishList = new List<GameObject>();

    // Instantiates `count` fish at random positions with random Y rotation and assigns each the given speed.
    private void SpawnFish(int count, float fishSpeed)
        for (int i = 0; i < count; i++)
            // Spawn and place the fish
            GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
            fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Set the fish's parent to this area's transform
            // NOTE(review): the statement for this step is not visible in this listing.

            // Keep track of the fish
            // NOTE(review): the statement for this step is not visible in this listing.

            // Set the fish speed
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;

    // Instantiates `count` rotten fish with the same placement ranges as SpawnFish and assigns each the given speed.
    // NOTE(review): parenting and list-tracking statements, if any, are not visible in this listing.
    private void SpawnRottenFish(int count, float fishSpeed)
        for (int i =0; i< count; i++) {

            GameObject rottenFishOb = Instantiate<GameObject>(rottenFishPrefab.gameObject);
            rottenFishOb.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            rottenFishOb.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);



            rottenFishOb.GetComponent<RottenFish>().fishSpeed = fishSpeed;

    /// <summary>
    /// Returns a position offset from <paramref name="center"/> by a radius and a rotation around the Y axis.
    /// The radius and angle are randomized only when the max exceeds the min; otherwise the min is used.
    /// </summary>
    /// <param name="center">Point the offset is added to.</param>
    /// <param name="minAngle">Minimum angle in degrees (passed to Quaternion.Euler).</param>
    /// <param name="maxAngle">Maximum angle in degrees.</param>
    /// <param name="minRadius">Minimum distance from the center.</param>
    /// <param name="maxRadius">Maximum distance from the center.</param>
    /// <returns>center + (Y rotation by angle) * Vector3.forward * radius.</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);

        if (maxAngle > minAngle)
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplies by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;

    // Stops the penguin's rigidbody and moves it to a random spot (any angle, radius 0-9),
    // raised by 0.5 on Y, with a random Y rotation.
    private void PlacePenguin()
        Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

    // NOTE(review): the body is not visible in this listing; the name suggests it removes one fish — confirm.
    public void RemoveSpecificFish(GameObject fishObject)

    // Returns the number of entries in fishList (same value as FishRemaining)
    public int GetFishCount()
        return this.fishList.Count;

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

// NOTE(review): this listing appears to have lost its braces and some statements
// (several method bodies are empty or partial). Code tokens are kept as found.
/// <summary>
/// ML-Agents agent for the penguin: turns two discrete actions into movement and rotation,
/// reacts to collisions with "fish" and "rottenFish" tagged objects, and supplies keyboard
/// input through Heuristic.
/// </summary>
public class PenguinAgent : Agent
    // Forward movement speed multiplier used in OnActionReceived
    public float moveSpeed = 5f;
    // Turn speed multiplier used in OnActionReceived
    public float turnSpeed = 180f;
    // Area found on a parent object in Initialize
    private PenguinArea penguinArea;
    // Rigidbody on this object, cached in Initialize
    new private Rigidbody rigidbody;
    private bool isFull; // If true, penguin has a full stomach
    // NOTE(review): not read in the visible code; EatFish declares a local with the same name.
    private int count;

    // Caches the parent PenguinArea and this object's Rigidbody.
    public override void Initialize()
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();

    // Applies the received discrete actions: index 0 is the forward amount, index 1 selects the turn direction.
    public override void OnActionReceived(ActionBuffers actionBuffers)
        // Convert the first action to forward movement
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Convert the second action to turning left or right
        // (1 maps to -1, 2 maps to +1, anything else leaves the turn at 0)
        float turnAmount = 0f;
        if (actionBuffers.DiscreteActions[1] == 1f)
            turnAmount = -1f;
        else if (actionBuffers.DiscreteActions[1] == 2f)
            turnAmount = 1f;

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        // (skipped when MaxStep is 0, which also avoids dividing by zero)
        if (MaxStep > 0) AddReward(-1f / MaxStep);

    // Reacts to collisions with objects tagged "fish" or "rottenFish".
    // NOTE(review): both branch bodies are not visible — presumably they call EatFish and EatRottenFish; confirm.
    private void OnCollisionEnter(Collision collision)
        if (collision.transform.CompareTag("fish"))
            // Try to eat the fish
        else if (collision.transform.CompareTag("rottenFish")) 

    // NOTE(review): most of this body is not visible. What remains walks the area's fish list
    // testing each entry for the "fish" tag, then checks whether the local `count` is <= 0;
    // the loop body and the action taken are unknown.
    private void EatFish(GameObject fishObject)

        int count = 0;
        foreach (var i in penguinArea.fishList)
            if (i.tag == "fish")

        if (count <= 0)

    // NOTE(review): the body is not visible in this listing. The accompanying post says eating
    // the rotten fish gives a reward of -1 — confirm against the original script.
    private void EatRottenFish(GameObject rottenFishOb)

    // NOTE(review): the sensor calls are not visible in this listing; only the original comments remain.
    public override void CollectObservations(VectorSensor sensor)
        // Whether the penguin has eaten a fish (1 float = 1 value)

        // Direction penguin is facing (1 Vector3 = 3 values)

        // 1 + 1 + 3 + 3 = 8 total values

    // Keyboard control: W sets the forward action to 1; A sets the turn action to 1, D sets it to 2.
    public override void Heuristic(in ActionBuffers actionsOut)
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
            // move forward
            forwardAction = 1;
        if (Input.GetKey(KeyCode.A))
            // turn left
            turnAction = 1;
        else if (Input.GetKey(KeyCode.D))
            // turn right
            turnAction = 2;

        // Put the actions into the array
        actionsOut.DiscreteActions.Array[0] = forwardAction;
        actionsOut.DiscreteActions.Array[1] = turnAction;
    // Clears the full-stomach flag at the start of each episode.
    // NOTE(review): any further reset steps are not visible in this listing.
    public override void OnEpisodeBegin()
        isFull = false;

'unity > AI' 카테고리의 다른 글

0517 AI Machine Running  (0) 2021.05.17