Skip to content

Commit

Permalink
feat: add clip range of JointAction
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Nov 28, 2024
1 parent 4d99147 commit 1238d7a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class JointActionCfg(ActionTermCfg):
"""Scale factor for the action (float or dict of regex expressions). Defaults to 1.0."""
offset: float | dict[str, float] = 0.0
"""Offset factor for the action (float or dict of regex expressions). Defaults to 0.0."""
clip: dict[str, tuple] | None = None
"""Clip range for the action (dict of regex expressions). Defaults to None."""
preserve_order: bool = False
"""Whether to preserve the order of the joint names in the action output. Defaults to False."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class JointAction(ActionTerm):
"""The scaling factor applied to the input action."""
_offset: torch.Tensor | float
"""The offset applied to the input action."""
_clip: dict[str, tuple] | None = None
"""The clip applied to the input action."""

def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None:
# initialize the action term
Expand Down Expand Up @@ -94,6 +96,12 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non
self._offset[:, index_list] = torch.tensor(value_list, device=self.device)
else:
raise ValueError(f"Unsupported offset type: {type(cfg.offset)}. Supported types are float and dict.")
# parse clip
if cfg.clip is not None:
if isinstance(cfg.clip, dict):
self._clip = cfg.clip
else:
raise ValueError(f"Unsupported clip type: {type(cfg.scale)}. Supported types are dict.")

"""
Properties.
Expand All @@ -120,6 +128,13 @@ def process_actions(self, actions: torch.Tensor):
self._raw_actions[:] = actions
# apply the affine transformations
self._processed_actions = self._raw_actions * self._scale + self._offset
# clip actions
if self._clip is not None:
# resolve the dictionary config
index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
for index in range(len(index_list)):
min_value, max_value = value_list[index]
self._processed_actions[:, index_list[index]].clip_(min_value, max_value)

def reset(self, env_ids: Sequence[int] | None = None) -> None:
self._raw_actions[env_ids] = 0.0
Expand Down

0 comments on commit 1238d7a

Please sign in to comment.