0525 AI 펭귄
2021. 5. 25. 10:23ㆍunity/AI
https://www.immersivelimit.com/tutorials/reinforcement-learning-penguins-part-2-unity-ml-agents
유니티에서 ML-Agents 패키지를 임포트한 후의 설정 과정은
위 링크의 문서(documentation)를 참고.
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Swims a fish between random points inside its parent penguin area.
/// Each "leg" gets a freshly randomized speed and target; the fish faces the target and
/// moves in a straight line. A new leg starts when the estimated travel time has elapsed
/// or, on the physics step after the fish reaches its target.
/// </summary>
public class Fish : MonoBehaviour
{
    // Base swim speed; the fish only moves while this is positive.
    public float fishSpeed;
    // Speed used for the current leg (fishSpeed scaled by a random factor).
    private float randomizedSpeed = 0f;
    // Fixed-time timestamp at which the next leg should start (-1 forces a leg on the first step).
    private float nextActionTime = -1f;
    // Destination of the current leg.
    private Vector3 targetPosition;

    private void FixedUpdate()
    {
        if (fishSpeed > 0)
        {
            Swim();
        }
    }

    // Either starts a new leg or advances along the current one.
    private void Swim()
    {
        bool legFinished = Time.fixedTime >= nextActionTime;
        if (legFinished)
        {
            BeginNextLeg();
            return;
        }
        AdvanceTowardTarget();
    }

    // Chooses a new speed and target, turns toward the target and schedules the next leg.
    private void BeginNextLeg()
    {
        // Randomize this leg's speed to 50%-150% of the base speed.
        randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);
        // Pick a random point in the area's fish region (yaw 100-260 deg, radius 2-13).
        targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);
        Vector3 toTarget = targetPosition - transform.position;
        transform.rotation = Quaternion.LookRotation(toTarget, Vector3.up);
        // Estimated time to arrive at the chosen speed.
        float travelTime = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
        nextActionTime = Time.fixedTime + travelTime;
    }

    // Moves one physics step forward, snapping onto the target instead of overshooting it.
    private void AdvanceTowardTarget()
    {
        Vector3 step = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
        float remaining = Vector3.Distance(transform.position, targetPosition);
        if (step.magnitude <= remaining)
        {
            transform.position += step;
        }
        else
        {
            transform.position = targetPosition;
            // Arrived: start a new leg on the next physics step.
            nextActionTime = Time.fixedTime;
        }
    }
}
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// One training area: owns the penguin agent and the fish spawned around it, and
/// rebuilds that layout (clear fish, re-place penguin, spawn new fish) on reset.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    public PenguinAgent penguinAgent;
    public Fish fishPrefab;
    // Every fish currently spawned in this area.
    private List<GameObject> fishList;

    // Unity message: build the initial layout when the scene starts.
    void Start()
    {
        ResetArea();
    }

    /// <summary>Clears all fish, re-places the penguin and spawns 4 fish at speed 0.5.</summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        SpawnFish(4, .5f);
    }

    /// <summary>Number of fish still tracked by this area.</summary>
    public int FishRemaining => fishList.Count;

    // Destroys every tracked fish and starts a fresh, empty list.
    private void RemoveAllFish()
    {
        fishList?.ForEach(fish => Destroy(fish));
        fishList = new List<GameObject>();
    }

    // Instantiates `prefab` at a random point 2-13 units from the area centre (yaw 100-260 deg),
    // raised by 0.5, with a random heading; parents it to the area and tracks it in fishList.
    private GameObject SpawnInArea(GameObject prefab)
    {
        GameObject spawned = Instantiate<GameObject>(prefab);
        Transform spawnedTransform = spawned.transform;
        spawnedTransform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
        spawnedTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
        spawnedTransform.SetParent(transform);
        fishList.Add(spawned);
        return spawned;
    }

    // Spawns `count` fish, each swimming with base speed `fishSpeed`.
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int spawnedSoFar = 0; spawnedSoFar < count; spawnedSoFar++)
        {
            SpawnInArea(fishPrefab.gameObject).GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Returns <paramref name="center"/> offset horizontally by a random radius in
    /// [minRadius, maxRadius) along a random yaw angle (degrees) in [minAngle, maxAngle).
    /// When a max is not greater than its min, the min is used unchanged.
    /// </summary>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        // Radius is drawn before angle; keep this order so the random sequence is unchanged.
        float radius = maxRadius > minRadius ? UnityEngine.Random.Range(minRadius, maxRadius) : minRadius;
        float angle = maxAngle > minAngle ? UnityEngine.Random.Range(minAngle, maxAngle) : minAngle;
        Quaternion yaw = Quaternion.Euler(0f, angle, 0f);
        return center + yaw * Vector3.forward * radius;
    }

    // Stops the penguin's rigidbody and puts it at a random spot within 9 units of the centre.
    private void PlacePenguin()
    {
        Rigidbody body = penguinAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
        Transform penguinTransform = penguinAgent.transform;
        penguinTransform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>Stops tracking <paramref name="fishObject"/> and destroys it.</summary>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>Number of fish still tracked by this area.</summary>
    public int GetFishCount() => fishList.Count;
}
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents penguin: moves and turns via two discrete action branches, earns +1 for each
/// object tagged "fish" it collides with, and ends the episode once no fish remain.
/// </summary>
public class PenguinAgent : Agent
{
    public float moveSpeed = 5f;
    public float turnSpeed = 180f;
    private PenguinArea penguinArea;
    new private Rigidbody rigidbody;
    // If true, penguin has a full stomach.
    // NOTE(review): this class never sets it to true (only reset to false each episode),
    // so the observation is constant; kept to preserve the observation vector size.
    private bool isFull;

    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Branch 0 is used directly as a forward-speed multiplier (the heuristic emits 0 or 1).
    /// Branch 1: 1 = turn left, 2 = turn right, anything else = no turn.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var discrete = actionBuffers.DiscreteActions;
        float forwardAmount = discrete[0];

        float turnAmount = 0f;
        if (discrete[1] == 1)
        {
            turnAmount = -1f;
        }
        else if (discrete[1] == 2)
        {
            turnAmount = 1f;
        }

        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Tiny per-step penalty so the agent is encouraged to finish quickly.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            EatFish(collision.gameObject);
        }
    }

    // Rewards the agent, removes the eaten fish and ends the episode when none are left.
    private void EatFish(GameObject fishObject)
    {
        AddReward(1f);
        penguinArea.RemoveSpecificFish(fishObject);
        // The count only decreases as fish are eaten, so "all eaten" means it reached zero.
        if (penguinArea.GetFishCount() <= 0)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Observations: isFull (1 value) + facing direction (3 values) = 4 values.
    /// NOTE(review): must match the vector observation size set in Behavior Parameters — confirm.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(isFull);
        sensor.AddObservation(transform.forward);
    }

    /// <summary>Keyboard control: W = forward, A = turn left, D = turn right.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = Input.GetKey(KeyCode.W) ? 1 : 0;

        int turnAction = 0;
        if (Input.GetKey(KeyCode.A))
        {
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            turnAction = 2;
        }

        // Write through the segment indexer (respects the segment's offset), not .Array[i].
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = forwardAction;
        discreteActionsOut[1] = turnAction;
    }

    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }
}
약간의 응용 추가:
먹으면 안 되는 생선(RottenFish)을 하나 배치하고, 펭귄이 그것을 먹으면 보상을 -1로 주고 에피소드를 종료함.
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// A fish the penguin should avoid. It swims between random points in its parent penguin
/// area: each leg gets a random speed and target, the fish faces the target and moves
/// straight toward it. A new leg starts when the estimated travel time has elapsed or,
/// on the physics step after the fish arrives.
/// </summary>
public class RottenFish : MonoBehaviour
{
    // Base swim speed; the fish only moves while this is positive.
    public float fishSpeed;
    // Speed for the current leg.
    private float randomizedSpeed = 0f;
    // Fixed-time timestamp at which the next leg starts (-1 forces a leg on the first step).
    private float nextActionTime = -1f;
    // Destination of the current leg.
    private Vector3 targetPosition;

    private void FixedUpdate()
    {
        if (fishSpeed > 0)
        {
            Swim();
        }
    }

    // Starts a new leg when due, otherwise keeps moving along the current one.
    private void Swim()
    {
        if (Time.fixedTime < nextActionTime)
        {
            StepForward();
            return;
        }
        PlanLeg();
    }

    // Randomizes speed (50%-150% of base) and target, faces the target, schedules the next leg.
    private void PlanLeg()
    {
        randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);
        targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);
        Vector3 heading = targetPosition - transform.position;
        transform.rotation = Quaternion.LookRotation(heading, Vector3.up);
        float secondsToArrive = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
        nextActionTime = Time.fixedTime + secondsToArrive;
    }

    // Moves one physics step, snapping onto the target rather than overshooting it.
    private void StepForward()
    {
        Vector3 delta = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
        float distanceLeft = Vector3.Distance(transform.position, targetPosition);
        if (delta.magnitude <= distanceLeft)
        {
            transform.position += delta;
        }
        else
        {
            transform.position = targetPosition;
            nextActionTime = Time.fixedTime;
        }
    }
}
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// One training area with edible fish and rotten fish. On reset it clears every spawned
/// fish, re-places the penguin, and spawns 4 fish plus 1 rotten fish.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    public PenguinAgent penguinAgent;
    public Fish fishPrefab;
    public RottenFish rottenFishPrefab;
    // Every spawned object in this area, edible and rotten alike.
    public List<GameObject> fishList;

    // Unity message: build the initial layout when the scene starts.
    void Start()
    {
        ResetArea();
    }

    /// <summary>Clears all fish, re-places the penguin, spawns 4 fish and 1 rotten fish.</summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        SpawnFish(4, .5f);
        SpawnRottenFish(1, .5f);
    }

    /// <summary>Number of tracked objects, including rotten fish.</summary>
    public int FishRemaining => fishList.Count;

    // Destroys every tracked object and starts a fresh, empty list.
    private void RemoveAllFish()
    {
        fishList?.ForEach(fish => Destroy(fish));
        fishList = new List<GameObject>();
    }

    // Instantiates `prefab` at a random point 2-13 units from the area centre (yaw 100-260 deg),
    // raised by 0.5, with a random heading; parents it to the area and tracks it in fishList.
    private GameObject SpawnInArea(GameObject prefab)
    {
        GameObject spawned = Instantiate<GameObject>(prefab);
        Transform spawnedTransform = spawned.transform;
        spawnedTransform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
        spawnedTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
        spawnedTransform.SetParent(transform);
        fishList.Add(spawned);
        return spawned;
    }

    // Spawns `count` edible fish with base speed `fishSpeed`.
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int made = 0; made < count; made++)
        {
            SpawnInArea(fishPrefab.gameObject).GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

    // Spawns `count` rotten fish with base speed `fishSpeed`.
    private void SpawnRottenFish(int count, float fishSpeed)
    {
        for (int made = 0; made < count; made++)
        {
            SpawnInArea(rottenFishPrefab.gameObject).GetComponent<RottenFish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Returns <paramref name="center"/> offset horizontally by a random radius in
    /// [minRadius, maxRadius) along a random yaw angle (degrees) in [minAngle, maxAngle).
    /// When a max is not greater than its min, the min is used unchanged.
    /// </summary>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        // Radius is drawn before angle; keep this order so the random sequence is unchanged.
        float radius = maxRadius > minRadius ? UnityEngine.Random.Range(minRadius, maxRadius) : minRadius;
        float angle = maxAngle > minAngle ? UnityEngine.Random.Range(minAngle, maxAngle) : minAngle;
        Quaternion yaw = Quaternion.Euler(0f, angle, 0f);
        return center + yaw * Vector3.forward * radius;
    }

    // Stops the penguin's rigidbody and puts it at a random spot within 9 units of the centre.
    private void PlacePenguin()
    {
        Rigidbody body = penguinAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
        Transform penguinTransform = penguinAgent.transform;
        penguinTransform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinTransform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>Stops tracking <paramref name="fishObject"/> and destroys it.</summary>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>Number of tracked objects, including rotten fish.</summary>
    public int GetFishCount() => fishList.Count;
}
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents penguin with a rotten-fish variant: +1 for each "fish" eaten (episode ends
/// when no edible fish remain), -1 and episode end for eating a "rottenFish".
/// </summary>
public class PenguinAgent : Agent
{
    public float moveSpeed = 5f;
    public float turnSpeed = 180f;
    private PenguinArea penguinArea;
    new private Rigidbody rigidbody;
    // If true, penguin has a full stomach.
    // NOTE(review): this class never sets it to true (only reset to false each episode),
    // so the observation is constant; kept to preserve the observation vector size.
    private bool isFull;

    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Branch 0 is used directly as a forward-speed multiplier (the heuristic emits 0 or 1).
    /// Branch 1: 1 = turn left, 2 = turn right, anything else = no turn.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var discrete = actionBuffers.DiscreteActions;
        float forwardAmount = discrete[0];

        float turnAmount = 0f;
        if (discrete[1] == 1)
        {
            turnAmount = -1f;
        }
        else if (discrete[1] == 2)
        {
            turnAmount = 1f;
        }

        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Tiny per-step penalty so the agent is encouraged to finish quickly.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag("rottenFish"))
        {
            EatRottenFish(collision.gameObject);
        }
    }

    // Rewards the agent, removes the fish, and ends the episode once no edible fish remain.
    private void EatFish(GameObject fishObject)
    {
        penguinArea.RemoveSpecificFish(fishObject);
        AddReward(1f);

        // fishList also holds rotten fish, so count only the edible ones.
        // The null check skips entries whose GameObject has already been destroyed.
        int edibleRemaining = 0;
        foreach (GameObject remaining in penguinArea.fishList)
        {
            if (remaining != null && remaining.CompareTag("fish"))
            {
                edibleRemaining++;
            }
        }

        if (edibleRemaining <= 0)
        {
            EndEpisode();
        }
    }

    // Penalizes the agent, removes the rotten fish and ends the episode.
    private void EatRottenFish(GameObject rottenFishOb)
    {
        AddReward(-1f);
        penguinArea.RemoveSpecificFish(rottenFishOb);
        EndEpisode();
    }

    /// <summary>
    /// Observations: isFull (1 value) + facing direction (3 values) = 4 values.
    /// NOTE(review): must match the vector observation size set in Behavior Parameters — confirm.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(isFull);
        sensor.AddObservation(transform.forward);
    }

    /// <summary>Keyboard control: W = forward, A = turn left, D = turn right.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = Input.GetKey(KeyCode.W) ? 1 : 0;

        int turnAction = 0;
        if (Input.GetKey(KeyCode.A))
        {
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            turnAction = 2;
        }

        // Write through the segment indexer (respects the segment's offset), not .Array[i].
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = forwardAction;
        discreteActionsOut[1] = turnAction;
    }

    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }
}
'unity > AI' 카테고리의 다른 글
0517 AI Machine Running (0) | 2021.05.17 |
---|