Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to PhysicsSensors to include angular velocities #5771

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
return m_Bodies[index].velocity;
}

/// <inheritdoc/>
protected internal override Vector3 GetAngularVelocityAt(int index)
{
return m_Bodies[index].angularVelocity;
}

/// <inheritdoc/>
protected internal override Pose GetPoseAt(int index)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,21 @@ public struct PhysicsSensorSettings
/// </summary>
public bool UseModelSpaceLinearVelocity;

/// <summary>
/// Whether to use model space (relative to the root body) angular velocities as observations.
/// </summary>
public bool UseModelSpaceAngularVelocity;

/// <summary>
/// Whether to use local space (relative to the parent body) linear velocities as observations.
/// </summary>
public bool UseLocalSpaceLinearVelocity;

/// <summary>
/// Whether to use local space (relative to the parent body) angular velocities as observations.
/// </summary>
public bool UseLocalSpaceAngularVelocity;

/// <summary>
/// Whether to use joint-specific positions and angles as observations.
/// </summary>
Expand Down Expand Up @@ -67,15 +77,17 @@ public static PhysicsSensorSettings Default()
/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity ||
UseModelSpaceAngularVelocity; }
}

/// <summary>
/// Whether any local space observations are being used.
/// </summary>
public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity ||
UseLocalSpaceAngularVelocity; }
}
}

Expand Down Expand Up @@ -109,9 +121,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
}
}

foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
if (settings.UseModelSpaceLinearVelocity)
{
if (settings.UseModelSpaceLinearVelocity)
foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
{
writer.Add(vel, offset);
offset += 3;
}
}

if (settings.UseModelSpaceAngularVelocity)
{
foreach (var vel in poseExtractor.GetEnabledModelSpaceAngularVelocities())
{
writer.Add(vel, offset);
offset += 3;
Expand All @@ -136,9 +157,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
}
}

foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
if (settings.UseLocalSpaceLinearVelocity)
{
foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
{
writer.Add(vel, offset);
offset += 3;
}
}

if (settings.UseLocalSpaceAngularVelocity)
{
if (settings.UseLocalSpaceLinearVelocity)
foreach (var vel in poseExtractor.GetEnabledLocalSpaceAngularVelocities())
{
writer.Add(vel, offset);
offset += 3;
Expand Down
66 changes: 63 additions & 3 deletions com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ public abstract class PoseExtractor
Pose[] m_LocalSpacePoses;

Vector3[] m_ModelSpaceLinearVelocities;
Vector3[] m_ModelSpaceAngularVelocities;
Vector3[] m_LocalSpaceLinearVelocities;
Vector3[] m_LocalSpaceAngularVelocities;

bool[] m_PoseEnabled;

Expand Down Expand Up @@ -83,6 +85,25 @@ public IEnumerable<Vector3> GetEnabledModelSpaceVelocities()
}
}

/// <summary>
/// Read iterator for the enabled model space angular velocities.
/// </summary>
public IEnumerable<Vector3> GetEnabledModelSpaceAngularVelocities()
{
if (m_ModelSpaceAngularVelocities == null)
{
yield break;
}

for (var i = 0; i < m_ModelSpaceAngularVelocities.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_ModelSpaceAngularVelocities[i];
}
}
}

/// <summary>
/// Read iterator for the enabled local space linear velocities.
/// </summary>
Expand All @@ -102,6 +123,25 @@ public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities()
}
}

/// <summary>
/// Read iterator for the enabled local space angular velocities.
/// </summary>
public IEnumerable<Vector3> GetEnabledLocalSpaceAngularVelocities()
{
if (m_LocalSpaceAngularVelocities == null)
{
yield break;
}

for (var i = 0; i < m_LocalSpaceAngularVelocities.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_LocalSpaceAngularVelocities[i];
}
}
}

/// <summary>
/// Number of enabled poses in the hierarchy (read-only).
/// </summary>
Expand Down Expand Up @@ -181,7 +221,9 @@ protected void Setup(int[] parentIndices)
m_LocalSpacePoses = new Pose[numPoses];

m_ModelSpaceLinearVelocities = new Vector3[numPoses];
m_ModelSpaceAngularVelocities = new Vector3[numPoses];
m_LocalSpaceLinearVelocities = new Vector3[numPoses];
m_LocalSpaceAngularVelocities = new Vector3[numPoses];

m_PoseEnabled = new bool[numPoses];
// All poses are enabled by default. Generally we'll want to disable the root though.
Expand All @@ -205,6 +247,13 @@ protected void Setup(int[] parentIndices)
/// <returns></returns>
protected internal abstract Vector3 GetLinearVelocityAt(int index);

/// <summary>
/// Return the world space angular velocity of the i'th object.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected internal abstract Vector3 GetAngularVelocityAt(int index);

/// <summary>
/// Return the underlying object at the given index. This is only
/// used for display in the inspector.
Expand Down Expand Up @@ -232,6 +281,7 @@ public void UpdateModelSpacePoses()
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
var rootAngularVel = GetAngularVelocityAt(0);

for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
Expand All @@ -240,8 +290,11 @@ public void UpdateModelSpacePoses()
m_ModelSpacePoses[i] = currentModelSpacePose;

var currentBodyLinearVel = GetLinearVelocityAt(i);
var relativeVelocity = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
var relativeLinearVel = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeLinearVel;
var currentBodyAngularVel = GetAngularVelocityAt(i);
var relativeAngularVel = currentBodyAngularVel - rootAngularVel;
m_ModelSpaceAngularVelocities[i] = worldToModel.rotation * relativeAngularVel;
}
}
}
Expand Down Expand Up @@ -272,11 +325,15 @@ public void UpdateLocalSpacePoses()
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
var parentAngularVel = GetAngularVelocityAt(m_ParentIndices[i]);
var currentAngularVel = GetAngularVelocityAt(i);
m_LocalSpaceAngularVelocities[i] = invParent.rotation * (currentAngularVel - parentAngularVel);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
m_LocalSpaceAngularVelocities[i] = Vector3.zero;
}
}
}
Expand All @@ -296,7 +353,9 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings)
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;

obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseModelSpaceAngularVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceAngularVelocity ? 3 : 0;

return NumEnabledPoses * obsPerPose;
}
Expand Down Expand Up @@ -363,6 +422,7 @@ internal IList<DisplayNode> GetDisplayNodes()
{
return Array.Empty<DisplayNode>();
}

var nodesOut = new List<DisplayNode>(NumPoses);

// List of children for each node
Expand All @@ -379,6 +439,7 @@ internal IList<DisplayNode> GetDisplayNodes()
{
tree[parent] = new List<int>();
}

tree[parent].Add(i);
}

Expand Down Expand Up @@ -422,7 +483,6 @@ internal IList<DisplayNode> GetDisplayNodes()

return nodesOut;
}

}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

namespace Unity.MLAgents.Extensions.Sensors
{

/// <summary>
/// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
/// and child nodes are connect to their parents via Joints.
Expand Down Expand Up @@ -129,9 +128,22 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
// No velocity on the virtual root
return Vector3.zero;
}

return m_Bodies[index].velocity;
}

/// <inheritdoc/>
protected internal override Vector3 GetAngularVelocityAt(int index)
{
if (index == 0 && m_VirtualRoot != null)
{
// No velocity on the virtual root
return Vector3.zero;
}

return m_Bodies[index].angularVelocity;
}

/// <inheritdoc/>
protected internal override Pose GetPoseAt(int index)
{
Expand All @@ -156,6 +168,7 @@ protected internal override Object GetObjectAt(int index)
{
return m_VirtualRoot;
}

return m_Bodies[index];
}

Expand All @@ -167,6 +180,11 @@ protected internal override Object GetObjectAt(int index)
/// <returns></returns>
internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
{
if (m_Bodies == null)
{
return new Dictionary<Rigidbody, bool>();
}

var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
for (var i = 0; i < m_Bodies.Length; i++)
{
Expand Down Expand Up @@ -205,5 +223,4 @@ internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}

protected internal override Vector3 GetAngularVelocityAt(int index)
{
return Vector3.zero;
}
}

class UselessPoseExtractor : BasicPoseExtractor
Expand Down Expand Up @@ -114,6 +119,10 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
return Vector3.zero;
}

protected internal override Vector3 GetAngularVelocityAt(int index)
{
return Vector3.zero;
}
}

[Test]
Expand Down