Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The agent does not achieve its goal #6172

Open
begosik opened this issue Nov 6, 2024 · 0 comments
Open

The agent does not achieve its goal #6172

begosik opened this issue Nov 6, 2024 · 0 comments
Labels
request Issue contains a feature request.

Comments

@begosik
Copy link

begosik commented Nov 6, 2024

Hello, I tried to replicate the Crawler example, but my agent doesn't learn. Maybe I didn't set up its rewards and observations correctly, or chose the wrong training config? Here's the code:

using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgentsExamples;
using Unity.VisualScripting;
using UnityEngine;
using UnityEngine.PlayerLoop;

public class AgentMove : Agent
{
    // Target the agent moves toward; repositioned by SpawnTarget().
    [SerializeField] private Transform target;
    //float speed = 2;

    // Leg0..Leg3 are driven on two joint axes (x, y) and Leg00..Leg30 on one (x) in
    // OnActionReceived; all eight are registered with the joint drive controller in Initialize().
    // NOTE(review): `body` is not referenced by any visible method.
    public Transform body, Leg0, Leg1, Leg2, Leg3;
    public Transform Leg00, Leg10, Leg20, Leg30;


    // NOTE(review): legX / legY / legX1 / legY1 are not referenced in the visible code.
    public float legX;

    public float legY;

    public float legX1;

    public float legY1;

    //public float strength;

    // Joint drive controller fetched from this GameObject in Initialize().
    JointDriveController m_JdController;

    // Distance to the target recorded on the previous FixedUpdate; reset in OnEpisodeBegin.
    float lastDis;

    // Last per-step change in distance (lastDis - current distance): positive when the
    // agent moved closer. Used for the shaping reward and as an observation.
    float dis;

    // NOTE(review): not referenced in the visible code.
    public float strength;

    // Reset to 0 in OnEpisodeBegin / OnCollisionEnter.
    // NOTE(review): never incremented in the visible code.
    public float sec;

    
     
    /// <summary>
    /// Caches the joint drive controller and registers every leg segment with it.
    /// </summary>
    public override void Initialize()
    {
        m_JdController = GetComponent<JointDriveController>();

        // Registration order is kept stable on purpose: CollectObservations reads
        // bodyPartsList positionally (indices 0..7). NOTE(review): this presumes
        // SetupBodyPart appends to bodyPartsList in call order — confirm.
        Transform[] legSegments = { Leg0, Leg1, Leg2, Leg3, Leg00, Leg10, Leg20, Leg30 };
        foreach (Transform segment in legSegments)
        {
            m_JdController.SetupBodyPart(segment);
        }
    }
    /// <summary>
    /// Resets the agent for a new episode: respawns the target, relaxes every joint,
    /// puts the root back at its start pose, resets all body parts, and seeds the
    /// distance baseline that FixedUpdate uses for reward shaping.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Rigidbody rgBody = GetComponent<Rigidbody>();
        sec = 0;

        SpawnTarget();

        // Neutral joint targets for every leg segment.
        var bpDict = m_JdController.bodyPartsDict;
        bpDict[Leg0].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg1].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg2].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg3].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg00].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg10].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg20].SetJointTargetRotation(0, 0, 0);
        bpDict[Leg30].SetJointTargetRotation(0, 0, 0);

        // Root start pose; velocities are cleared once, after the move.
        transform.localPosition = new Vector3(0, 2.8f, 0);
        transform.localRotation = Quaternion.Euler(0, 0, 0);
        rgBody.velocity = Vector3.zero;
        rgBody.angularVelocity = Vector3.zero;

        foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
        {
            bodyPart.Reset(bodyPart);
        }

        // Seed the shaping baseline with the real starting distance. Resetting it to 0
        // made the first FixedUpdate see "progress" of -distance and apply a large
        // spurious penalty at the start of every episode.
        lastDis = Vector3.Distance(transform.localPosition, target.localPosition);
        // Don't leak the previous episode's last delta into the first observation.
        dis = 0f;
    }

    /// <summary>
    /// Emits 15 floats: the agent's local position (3), the target's local position (3),
    /// the last per-step distance change (1), and the normalized joint strength of the
    /// first eight registered body parts (8).
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);
        sensor.AddObservation(target.localPosition);
        sensor.AddObservation(dis);

        // NOTE(review): no body orientation, velocity or joint-angle observations are
        // collected; consider whether the policy has enough state to learn to walk.
        // Changing the count also requires updating the Behavior Parameters vector size.
        var parts = m_JdController.bodyPartsList;
        float forceLimit = m_JdController.maxJointForceLimit;
        for (int idx = 0; idx < 8; idx++)
        {
            sensor.AddObservation(parts[idx].currentStrength / forceLimit);
        }
    }

    /// <summary>
    /// Dense distance-based shaping reward, applied every physics step: the agent earns
    /// 5x the distance it closed toward the target since the previous step (negative
    /// when it moved away). Also refreshes <c>dis</c>, which is fed to observations.
    /// </summary>
    void FixedUpdate()
    {
        float currentDistance = Vector3.Distance(transform.localPosition, target.localPosition);

        // A non-positive baseline means no distance has been recorded yet this episode
        // (lastDis may have been reset to 0). Seed it rather than treating the whole
        // distance to the target as a one-step loss of progress.
        if (lastDis <= 0f)
        {
            lastDis = currentDistance;
            dis = 0f;
            return;
        }

        // Always reward the delta and advance the baseline. The old guard
        // `lastDis != dis` compared a distance with a distance *delta*, which is
        // meaningless and could skip updating the baseline.
        dis = lastDis - currentDistance;
        AddReward(dis * 5f);
        lastDis = currentDistance;
    }

    /// <summary>
    /// Maps the 20 continuous actions onto the leg joints.
    /// Layout: [0..7] Leg0..Leg3 target rotation as (x, y) pairs; [8..11] Leg00..Leg30
    /// target rotation on x only; [12..15] Leg0..Leg3 strength; [16..19] Leg00..Leg30 strength.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var act = actions.ContinuousActions;
        var parts = m_JdController.bodyPartsDict;

        var twoAxisLegs = new[] { Leg0, Leg1, Leg2, Leg3 };
        var oneAxisLegs = new[] { Leg00, Leg10, Leg20, Leg30 };

        for (int i = 0; i < twoAxisLegs.Length; i++)
        {
            parts[twoAxisLegs[i]].SetJointTargetRotation(act[2 * i], act[2 * i + 1], 0);
        }
        for (int i = 0; i < oneAxisLegs.Length; i++)
        {
            parts[oneAxisLegs[i]].SetJointTargetRotation(act[8 + i], 0, 0);
        }

        for (int i = 0; i < twoAxisLegs.Length; i++)
        {
            parts[twoAxisLegs[i]].SetJointStrength(act[12 + i]);
        }
        for (int i = 0; i < oneAxisLegs.Length; i++)
        {
            parts[oneAxisLegs[i]].SetJointStrength(act[16 + i]);
        }
    }

    /// <summary>
    /// Terminal rewards on contact with this GameObject's collider: +1 for reaching the
    /// target, -1 for hitting a wall or the ground. Each of these ends the episode.
    /// Other tags are ignored.
    /// </summary>
    private void OnCollisionEnter(Collision collision)
    {
        switch (collision.gameObject.tag)
        {
            case "Target":
                print("pobeda"); // "victory"
                AddReward(1f);
                // EndEpisode() runs OnEpisodeBegin() immediately, which already calls
                // SpawnTarget() and resets `sec`. Repeating them here re-randomized the
                // target a second time after the new episode had been set up.
                EndEpisode();
                break;

            case "Wall":
            case "ground":
                AddReward(-1f);
                EndEpisode();
                break;
        }
    }

    /// <summary>
    /// Places the target at height 1 in a square ring around the origin. One coordinate
    /// has magnitude between 4 and 7 on a randomly chosen side; the other lies anywhere
    /// in [-7, 7].
    /// </summary>
    public void SpawnTarget()
    {
        // Random.Range(int, int) excludes its upper bound, so each draw is -1 or 0:
        // effectively a coin flip.
        bool positiveSide = Random.Range(-1, 1) == 0;
        bool nearOnX = Random.Range(-1, 1) == 0;

        float Near() => positiveSide ? Random.Range(4f, 7f) : Random.Range(-4f, -7f);

        // The x coordinate is always sampled before z, matching the original call order.
        float x, z;
        if (nearOnX)
        {
            x = Near();
            z = Random.Range(7f, -7f);
        }
        else
        {
            x = Random.Range(7f, -7f);
            z = Near();
        }

        target.transform.localPosition = new Vector3(x, 1, z);
    }

CONFIG:
behaviors:
  Agent1:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      buffer_size: 20480
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 256
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 10000000
    time_horizon: 1000
    summary_freq: 30000

@begosik begosik added the request Issue contains a feature request. label Nov 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
request Issue contains a feature request.
Projects
None yet
Development

No branches or pull requests

1 participant