0525 AI 펭귄

2021. 5. 25. 10:23unity/AI

https://www.immersivelimit.com/tutorials/reinforcement-learning-penguins-part-2-unity-ml-agents

 

Reinforcement Learning Penguins (Part 2/4) | Unity ML-Agents — Immersive Limit

 

www.immersivelimit.com

유니티에서 ML Agent를 임포트 한 후에 

Package Manager의 "Add package from disk" 기능으로 위 경로에 있는 json 파일(package.json)을 import 하면
ML Agent 패키지가 해당 버전으로 업데이트된다.

그 후에 대한 것은 위의 documentation 참고

뺑이 도는 중
훈련이 완료되어 모델을 끼워넣고 재생한 결과.

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Makes a fish swim in straight legs between randomly chosen target positions
/// that are picked relative to its parent transform.
/// </summary>
public class Fish : MonoBehaviour
{
    // Base swim speed; the fish only swims while this is greater than zero.
    public float fishSpeed;

    // Speed used for the current leg (fishSpeed scaled by a random factor).
    private float randomizedSpeed = 0f;

    // Fixed-time stamp at which the next target is chosen.
    private float nextActionTime = -1f;

    // World position the fish is currently heading toward.
    private Vector3 targetPosition;

    private void FixedUpdate()
    {
        // Written as a negated "> 0" test so that a NaN speed is treated as "do not swim".
        if (!(fishSpeed > 0))
        {
            return;
        }

        Swim();
    }

    /// <summary>
    /// Either starts a new leg (when the scheduled time has been reached) or
    /// advances along the current one.
    /// </summary>
    private void Swim()
    {
        if (Time.fixedTime >= nextActionTime)
        {
            BeginNewLeg();
            return;
        }

        AdvanceAlongLeg();
    }

    /// <summary>
    /// Picks a speed and a target, turns to face the target and schedules the next decision.
    /// </summary>
    private void BeginNewLeg()
    {
        // Randomize the speed for this leg.
        randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

        // Choose a random target inside the sector of the area where fish may swim.
        targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

        // Face the target.
        Vector3 toTarget = targetPosition - transform.position;
        transform.rotation = Quaternion.LookRotation(toTarget, Vector3.up);

        // Schedule the next decision for the moment the fish should arrive.
        float travelTime = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
        nextActionTime = Time.fixedTime + travelTime;
    }

    /// <summary>
    /// Moves one physics step forward, snapping onto the target if the step would overshoot it.
    /// </summary>
    private void AdvanceAlongLeg()
    {
        // Displacement covered during this physics step.
        Vector3 step = randomizedSpeed * transform.forward * Time.fixedDeltaTime;

        float remaining = Vector3.Distance(transform.position, targetPosition);
        bool stepFits = step.magnitude <= remaining;

        if (!stepFits)
        {
            // The step would overshoot: land exactly on the target ...
            transform.position = targetPosition;

            // ... and make the next FixedUpdate choose a new target immediately.
            nextActionTime = Time.fixedTime;
            return;
        }

        transform.position += step;
    }
}
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Owns one training area: places the penguin agent and spawns, tracks and removes the fish.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    // Agent that lives in this area.
    public PenguinAgent penguinAgent;

    // Prefab cloned for every spawned fish.
    public Fish fishPrefab;

    // Fish currently alive in the area.
    private List<GameObject> fishList;

    // Unity callback: set the area up once at start.
    private void Start()
    {
        ResetArea();
    }

    /// <summary>
    /// Clears the existing fish, repositions the penguin and spawns a fresh batch of fish.
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        SpawnFish(4, .5f);
    }

    /// <summary>Number of fish currently tracked by the area.</summary>
    public int FishRemaining => fishList.Count;

    /// <summary>
    /// Destroys every tracked fish and replaces the tracking list with an empty one.
    /// </summary>
    private void RemoveAllFish()
    {
        if (fishList != null)
        {
            for (int i = 0; i < fishList.Count; i++)
            {
                Destroy(fishList[i]);
            }
        }

        // Always start over with a fresh collection.
        fishList = new List<GameObject>();
    }

    /// <summary>
    /// Spawns <paramref name="count"/> fish at random positions and gives each the supplied speed.
    /// </summary>
    private void SpawnFish(int count, float fishSpeed)
    {
        int spawned = 0;
        while (spawned < count)
        {
            GameObject newFish = Instantiate(fishPrefab.gameObject);
            Transform fishTransform = newFish.transform;

            // Random spot in the fish sector, lifted by half a unit, with a random heading.
            fishTransform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Parent the fish to this area and remember it.
            fishTransform.SetParent(transform);
            fishList.Add(newFish);

            newFish.GetComponent<Fish>().fishSpeed = fishSpeed;

            spawned++;
        }
    }

    /// <summary>
    /// Returns a point at a random angle (degrees around the Y axis) and a random radius from
    /// <paramref name="center"/>. When a maximum is not greater than its minimum, the minimum is used.
    /// </summary>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        // The radius is drawn before the angle; the order of the random draws is kept as-is.
        float radius = maxRadius > minRadius
            ? UnityEngine.Random.Range(minRadius, maxRadius)
            : minRadius;

        float angle = maxAngle > minAngle
            ? UnityEngine.Random.Range(minAngle, maxAngle)
            : minAngle;

        // Forward vector rotated around Y by "angle", scaled by "radius", offset from the center.
        Vector3 direction = Quaternion.Euler(0f, angle, 0f) * Vector3.forward;
        return center + direction * radius;
    }

    /// <summary>
    /// Stops the penguin's motion and drops it at a random position with a random heading.
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody body = penguinAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;

        Transform penguinTransform = penguinAgent.transform;
        penguinTransform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Stops tracking the given fish and destroys its game object.
    /// </summary>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>Returns the number of fish currently tracked by the area.</summary>
    public int GetFishCount()
    {
        return fishList.Count;
    }
}
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that moves a penguin around its <see cref="PenguinArea"/> and is rewarded
/// for colliding with (eating) fish.
/// </summary>
public class PenguinAgent : Agent
{
    // Forward movement speed applied per second of fixed time.
    public float moveSpeed = 5f;

    // Turn speed in degrees per second of fixed time.
    public float turnSpeed = 180f;

    // Area this agent belongs to (looked up on a parent object).
    private PenguinArea penguinArea;

    // Cached rigidbody of the penguin; "new" hides the inherited member of the same name.
    private new Rigidbody rigidbody;

    // Reset to false at the start of every episode and fed to the observations.
    // NOTE(review): nothing in this class ever sets it to true - confirm whether that is intended.
    private bool isFull;

    /// <summary>Caches the parent area and the rigidbody.</summary>
    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Applies the chosen discrete actions: branch 0 is the forward amount,
    /// branch 1 selects no turn (0), turn left (1) or turn right (2).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Branch 0: forward movement amount.
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Branch 1: turning direction. Discrete actions are integers, so compare against ints.
        int turnAction = actionBuffers.DiscreteActions[1];
        float turnAmount = 0f;
        if (turnAction == 1)
        {
            turnAmount = -1f;
        }
        else if (turnAction == 2)
        {
            turnAmount = 1f;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    /// <summary>Eats any object tagged "fish" that the penguin collides with.</summary>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            EatFish(collision.gameObject);
        }
    }

    /// <summary>
    /// Rewards the agent, removes the eaten fish from the area and ends the episode
    /// once no fish are left.
    /// </summary>
    private void EatFish(GameObject fishObject)
    {
        AddReward(1f);
        penguinArea.RemoveSpecificFish(fishObject);

        // End the episode when every fish in the area has been eaten.
        if (penguinArea.GetFishCount() <= 0)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Collects 4 observation values: the isFull flag (1) and the facing direction (3).
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin is full (1 bool = 1 value)
        sensor.AddObservation(isFull);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 3 = 4 total values
        // NOTE(review): the vector observation size configured for this behavior should match - confirm.
    }

    /// <summary>
    /// Keyboard control for manual testing: W moves forward, A turns left, D turns right.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Put the actions into the array
        actionsOut.DiscreteActions.Array[0] = forwardAction;
        actionsOut.DiscreteActions.Array[1] = turnAction;
    }

    /// <summary>Resets the agent state and the whole area at the start of an episode.</summary>
    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }
}

 


약간의 응용 추가

먹으면 안되는 생선을 하나 배치해서 그걸 먹으면 리워드를 -1로 둠.

훈련 완료 영상

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Swimming behaviour for the "rotten" fish: travels in straight legs between randomly
/// chosen target positions that are picked relative to its parent transform.
/// </summary>
public class RottenFish : MonoBehaviour
{
    // Base swim speed; the fish only swims while this is greater than zero.
    public float fishSpeed;

    // Speed used for the current leg (fishSpeed scaled by a random factor).
    private float randomizedSpeed = 0f;

    // Fixed-time stamp at which the next target is chosen.
    private float nextActionTime = -1f;

    // World position the fish is currently heading toward.
    private Vector3 targetPosition;

    private void FixedUpdate()
    {
        // Written as a negated "> 0" test so that a NaN speed is treated as "do not swim".
        bool canSwim = fishSpeed > 0;
        if (!canSwim)
        {
            return;
        }

        Swim();
    }

    /// <summary>
    /// Chooses a new target when the scheduled time has been reached; otherwise keeps moving.
    /// </summary>
    private void Swim()
    {
        bool needsNewTarget = Time.fixedTime >= nextActionTime;
        if (needsNewTarget)
        {
            ChooseNextTarget();
        }
        else
        {
            MoveTowardTarget();
        }
    }

    /// <summary>
    /// Picks a speed and a target, rotates toward the target and schedules the next decision.
    /// </summary>
    private void ChooseNextTarget()
    {
        // Randomize the speed for this leg.
        randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

        // Choose a random target inside the sector of the area where fish may swim.
        Vector3 areaCenter = transform.parent.position;
        targetPosition = PenguinArea.ChooseRandomPosition(areaCenter, 100f, 260f, 2f, 13f);

        // Face the target.
        transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

        // Schedule the next decision for the moment the fish should arrive.
        float distance = Vector3.Distance(transform.position, targetPosition);
        nextActionTime = Time.fixedTime + distance / randomizedSpeed;
    }

    /// <summary>
    /// Advances one physics step, snapping onto the target if the step would overshoot it.
    /// </summary>
    private void MoveTowardTarget()
    {
        // Displacement covered during this physics step.
        Vector3 step = randomizedSpeed * transform.forward * Time.fixedDeltaTime;

        float distanceLeft = Vector3.Distance(transform.position, targetPosition);
        if (step.magnitude <= distanceLeft)
        {
            transform.position += step;
            return;
        }

        // The step would overshoot: land exactly on the target ...
        transform.position = targetPosition;

        // ... and make the next FixedUpdate choose a new target immediately.
        nextActionTime = Time.fixedTime;
    }
}
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Owns one training area: places the penguin agent and spawns, tracks and removes both
/// regular fish and rotten fish.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    // Sector (degrees around the Y axis) and radius range in which fish are spawned.
    private const float FishMinAngle = 100f;
    private const float FishMaxAngle = 260f;
    private const float FishMinRadius = 2f;
    private const float FishMaxRadius = 13f;

    // Vertical offset added to every spawn position.
    private const float SpawnHeight = .5f;

    // Agent that lives in this area.
    public PenguinAgent penguinAgent;

    // Prefab cloned for every regular fish.
    public Fish fishPrefab;

    // Prefab cloned for every rotten fish.
    public RottenFish rottenFishPrefab;

    // Every fish (regular and rotten) currently alive in the area.
    public List<GameObject> fishList;

    // Start is called before the first frame update
    void Start()
    {
        ResetArea();
    }

    /// <summary>
    /// Clears the existing fish, repositions the penguin and spawns 4 regular fish and 1 rotten fish.
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        SpawnFish(4, .5f);
        SpawnRottenFish(1, .5f);
    }

    /// <summary>
    /// Number of tracked fish (regular and rotten); 0 when the list has not been created yet.
    /// </summary>
    public int FishRemaining
    {
        get { return fishList != null ? fishList.Count : 0; }
    }

    /// <summary>
    /// Destroys every tracked fish and replaces the tracking list with an empty one.
    /// </summary>
    private void RemoveAllFish()
    {
        if (fishList != null)
        {
            foreach (GameObject fish in fishList)
            {
                Destroy(fish);
            }
        }

        // Always start over with a fresh collection.
        fishList = new List<GameObject>();
    }

    /// <summary>
    /// Spawns <paramref name="count"/> regular fish and gives each the supplied speed.
    /// </summary>
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int i = 0; i < count; i++)
        {
            GameObject fishObject = SpawnInFishSector(fishPrefab.gameObject);
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Spawns <paramref name="count"/> rotten fish and gives each the supplied speed.
    /// </summary>
    private void SpawnRottenFish(int count, float fishSpeed)
    {
        for (int i = 0; i < count; i++)
        {
            GameObject rottenFishObject = SpawnInFishSector(rottenFishPrefab.gameObject);
            rottenFishObject.GetComponent<RottenFish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Clones <paramref name="prefab"/>, places it at a random position with a random heading,
    /// parents it to this area and adds it to <see cref="fishList"/>.
    /// </summary>
    private GameObject SpawnInFishSector(GameObject prefab)
    {
        GameObject spawned = Instantiate(prefab);

        // Random spot in the fish sector, lifted above the center height, with a random heading.
        spawned.transform.position =
            ChooseRandomPosition(transform.position, FishMinAngle, FishMaxAngle, FishMinRadius, FishMaxRadius)
            + Vector3.up * SpawnHeight;
        spawned.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

        // Set the parent to this area's transform
        spawned.transform.SetParent(transform);

        // Keep track of the spawned object
        fishList.Add(spawned);

        return spawned;
    }

    /// <summary>
    /// Returns a point at a random angle (degrees around the Y axis) and a random radius from
    /// <paramref name="center"/>. When a maximum is not greater than its minimum, the minimum is used.
    /// </summary>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
        {
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);
        }

        if (maxAngle > minAngle)
        {
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);
        }

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
    }

    /// <summary>
    /// Stops the penguin's motion and drops it at a random position with a random heading.
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * SpawnHeight;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Stops tracking the given fish and destroys its game object.
    /// </summary>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        // Guarded because the list only exists after the first ResetArea call.
        if (fishList != null)
        {
            fishList.Remove(fishObject);
        }

        Destroy(fishObject);
    }

    /// <summary>Returns the same value as <see cref="FishRemaining"/>.</summary>
    public int GetFishCount()
    {
        return FishRemaining;
    }
}
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that moves a penguin around its <see cref="PenguinArea"/>. Eating a regular
/// fish is rewarded; eating a rotten fish is penalised and ends the episode.
/// </summary>
public class PenguinAgent : Agent
{
    // Tags used to tell the two kinds of fish apart.
    private const string FishTag = "fish";
    private const string RottenFishTag = "rottenFish";

    // Forward movement speed applied per second of fixed time.
    public float moveSpeed = 5f;

    // Turn speed in degrees per second of fixed time.
    public float turnSpeed = 180f;

    // Area this agent belongs to (looked up on a parent object).
    private PenguinArea penguinArea;

    // Cached rigidbody of the penguin; "new" hides the inherited member of the same name.
    private new Rigidbody rigidbody;

    // Reset to false at the start of every episode and fed to the observations.
    // NOTE(review): nothing in this class ever sets it to true - confirm whether that is intended.
    private bool isFull;

    /// <summary>Caches the parent area and the rigidbody.</summary>
    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Applies the chosen discrete actions: branch 0 is the forward amount,
    /// branch 1 selects no turn (0), turn left (1) or turn right (2).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Branch 0: forward movement amount.
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Branch 1: turning direction. Discrete actions are integers, so compare against ints.
        int turnAction = actionBuffers.DiscreteActions[1];
        float turnAmount = 0f;
        if (turnAction == 1)
        {
            turnAmount = -1f;
        }
        else if (turnAction == 2)
        {
            turnAmount = 1f;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    /// <summary>Dispatches collisions with regular and rotten fish to the matching handler.</summary>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag(FishTag))
        {
            // Try to eat the fish
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag(RottenFishTag))
        {
            EatRottenFish(collision.gameObject);
        }
    }

    /// <summary>
    /// Removes the eaten fish, rewards the agent and ends the episode once no regular
    /// fish are left in the area.
    /// </summary>
    private void EatFish(GameObject fishObject)
    {
        penguinArea.RemoveSpecificFish(fishObject);
        AddReward(1f);

        if (CountRegularFish() <= 0)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Counts the entries of the area's fish list that carry the regular fish tag.
    /// </summary>
    private int CountRegularFish()
    {
        int remaining = 0;
        foreach (GameObject candidate in penguinArea.fishList)
        {
            // fishList is a public list, so skip empty slots instead of throwing on them.
            if (candidate != null && candidate.CompareTag(FishTag))
            {
                remaining++;
            }
        }

        return remaining;
    }

    /// <summary>
    /// Penalises the agent for eating a rotten fish, removes it and ends the episode.
    /// </summary>
    private void EatRottenFish(GameObject rottenFishOb)
    {
        AddReward(-1f);
        penguinArea.RemoveSpecificFish(rottenFishOb);
        EndEpisode();
    }

    /// <summary>
    /// Collects 4 observation values: the isFull flag (1) and the facing direction (3).
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin is full (1 bool = 1 value)
        sensor.AddObservation(isFull);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 3 = 4 total values
        // NOTE(review): the vector observation size configured for this behavior should match - confirm.
    }

    /// <summary>
    /// Keyboard control for manual testing: W moves forward, A turns left, D turns right.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Put the actions into the array
        actionsOut.DiscreteActions.Array[0] = forwardAction;
        actionsOut.DiscreteActions.Array[1] = turnAction;
    }

    /// <summary>Resets the agent state and the whole area at the start of an episode.</summary>
    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }
}

'unity > AI' 카테고리의 다른 글

0517 AI Machine Running  (0) 2021.05.17