From 7dc25b9ec10f66e51da12717156e0dc4e0277ff7 Mon Sep 17 00:00:00 2001 From: glvov-bdai Date: Mon, 28 Oct 2024 12:58:50 -0400 Subject: [PATCH] Adds image extracted features observation term and cartpole examples for it (#1191) # Description This adds an observation term to be able to easily extract features from the images, and adds a cartpole example of using this new term. The new ResNet18 cartpole converges in less than 100 epochs. ## Type of change - New feature (non-breaking change which adds functionality) - This change requires a documentation update ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there I will update the version in the changelog and extension.toml after approval prior to merging in due to it causing merge conflicts when main updates --------- Signed-off-by: glvov-bdai Signed-off-by: garylvov <67614381+garylvov@users.noreply.github.com> Co-authored-by: garylvov <67614381+garylvov@users.noreply.github.com> Co-authored-by: garylvov Co-authored-by: David Hoeller Co-authored-by: James Smith <142246516+jsmith-bdai@users.noreply.github.com> --- CONTRIBUTORS.md | 1 + docs/source/overview/environments.rst | 11 +- pyproject.toml | 3 + .../omni.isaac.lab/config/extension.toml | 3 +- .../omni.isaac.lab/docs/CHANGELOG.rst | 10 ++ .../omni/isaac/lab/envs/mdp/observations.py | 131 ++++++++++++++++++ .../config/extension.toml | 2 +- .../omni.isaac.lab_tasks/docs/CHANGELOG.rst | 17 +++ .../direct/shadow_hand/feature_extractor.py | 1 - .../classic/cartpole/__init__.py | 31 ++++- .../agents/rl_games_feature_ppo_cfg.yaml | 79 +++++++++++ .../cartpole/cartpole_camera_env_cfg.py | 56 +++++++- 12 files changed, 336 insertions(+), 9 deletions(-) create mode 100644 source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/agents/rl_games_feature_ppo_cfg.yaml diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 49497a9909..e90992ca29 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -43,6 +43,7 @@ Guidelines for modifications: * Chenyu Yang * David Yang * Dorsa Rohani +* Felix Yu * Gary Lvov * Giulio Romualdi * HoJin Jeon diff --git a/docs/source/overview/environments.rst b/docs/source/overview/environments.rst index f42f3c34a5..531588a8de 100644 --- a/docs/source/overview/environments.rst +++ b/docs/source/overview/environments.rst @@ -61,6 +61,10 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty | | | | | | |cartpole-depth-direct-link|| | +------------------+-----------------------------+-------------------------------------------------------------------------+ + | |cartpole| | |cartpole-resnet-link| | Move the cart to keep the pole upwards in the classic cartpole control | + | | | based off of features extracted from perceptive inputs with pre-trained | + | | |cartpole-theia-link| | frozen vision encoders | + +------------------+-----------------------------+-------------------------------------------------------------------------+ .. |humanoid| image:: ../_static/tasks/classic/humanoid.jpg .. |ant| image:: ../_static/tasks/classic/ant.jpg @@ -69,8 +73,11 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty .. |humanoid-link| replace:: `Isaac-Humanoid-v0 `__ .. |ant-link| replace:: `Isaac-Ant-v0 `__ .. |cartpole-link| replace:: `Isaac-Cartpole-v0 `__ -.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-Camera-v0 `__ -.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-Camera-v0 `__ +.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-v0 `__ +.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-v0 `__ +.. |cartpole-resnet-link| replace:: `Isaac-Cartpole-RGB-ResNet18-v0 `__ +.. |cartpole-theia-link| replace:: `Isaac-Cartpole-RGB-TheiaTiny-v0 `__ + .. |humanoid-direct-link| replace:: `Isaac-Humanoid-Direct-v0 `__ .. |ant-direct-link| replace:: `Isaac-Ant-Direct-v0 `__ diff --git a/pyproject.toml b/pyproject.toml index 51d4375907..63ec9afd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ extra_standard_library = [ "toml", "trimesh", "tqdm", + "torchvision", + "transformers", + "einops" # Needed for transformers, doesn't always auto-install ] # Imports from Isaac Sim and Omniverse known_third_party = [ diff --git a/source/extensions/omni.isaac.lab/config/extension.toml b/source/extensions/omni.isaac.lab/config/extension.toml index 1d78d22a46..4a6faf6114 100644 --- a/source/extensions/omni.isaac.lab/config/extension.toml +++ b/source/extensions/omni.isaac.lab/config/extension.toml @@ -1,7 +1,8 @@ [package] # Note: Semantic Versioning is used: https://semver.org/ -version = "0.27.6" + +version = "0.27.7" # Description title = "Isaac Lab framework for Robot Learning" diff --git a/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst b/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst index 16bc34b7ed..ecc0471794 100644 --- a/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst +++ b/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst @@ -1,6 +1,16 @@ Changelog --------- + +0.27.7 (2024-10-28) +~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added frozen encoder feature extraction observation space with ResNet and Theia + + 0.27.6 (2024-10-25) ~~~~~~~~~~~~~~~~~~~ diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py index fbae1d21cc..aca0f579ce 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py @@ -17,11 +17,14 @@ import omni.isaac.lab.utils.math as math_utils from omni.isaac.lab.assets import Articulation, RigidObject from omni.isaac.lab.managers import SceneEntityCfg +from omni.isaac.lab.managers.manager_base import ManagerTermBase +from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg from omni.isaac.lab.sensors import Camera, Imu, RayCaster, RayCasterCamera, TiledCamera if TYPE_CHECKING: from omni.isaac.lab.envs import ManagerBasedEnv, ManagerBasedRLEnv + """ Root state. """ @@ -273,6 +276,134 @@ def image( return images.clone() +class image_features(ManagerTermBase): + """Extracted image features from a pre-trained frozen encoder. + + This method calls the :meth:`image` function to retrieve images, and then performs + inference on those images. + """ + + def __init__(self, cfg: ObservationTermCfg, env: ManagerBasedEnv): + super().__init__(cfg, env) + from torchvision import models + from transformers import AutoModel + + def create_theia_model(model_name): + return { + "model": ( + lambda: AutoModel.from_pretrained(f"theaiinstitute/{model_name}", trust_remote_code=True) + .eval() + .to("cuda:0") + ), + "preprocess": lambda img: (img - torch.amin(img, dim=(1, 2), keepdim=True)) / ( + torch.amax(img, dim=(1, 2), keepdim=True) - torch.amin(img, dim=(1, 2), keepdim=True) + ), + "inference": lambda model, images: model.forward_feature( + images, do_rescale=False, interpolate_pos_encoding=True + ), + } + + def create_resnet_model(resnet_name): + return { + "model": lambda: getattr(models, resnet_name)(pretrained=True).eval().to("cuda:0"), + "preprocess": lambda img: ( + img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width] + - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1) + ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1), + "inference": lambda model, images: model(images), + } + + # List of Theia models + theia_models = [ + "theia-tiny-patch16-224-cddsv", + "theia-tiny-patch16-224-cdiv", + "theia-small-patch16-224-cdiv", + "theia-base-patch16-224-cdiv", + "theia-small-patch16-224-cddsv", + "theia-base-patch16-224-cddsv", + ] + + # List of ResNet models + resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"] + + self.default_model_zoo_cfg = {} + + # Add Theia models to the zoo + for model_name in theia_models: + self.default_model_zoo_cfg[model_name] = create_theia_model(model_name) + + # Add ResNet models to the zoo + for resnet_name in resnet_models: + self.default_model_zoo_cfg[resnet_name] = create_resnet_model(resnet_name) + + self.model_zoo_cfg = self.default_model_zoo_cfg + self.model_zoo = {} + + def __call__( + self, + env: ManagerBasedEnv, + sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"), + data_type: str = "rgb", + convert_perspective_to_orthogonal: bool = False, + model_zoo_cfg: dict | None = None, + model_name: str = "ResNet18", + model_device: str | None = "cuda:0", + reset_model: bool = False, + ) -> torch.Tensor: + """Extracted image features from a pre-trained frozen encoder. + + Args: + env: The environment. + sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera"). + data_type: THe sensor configuration datatype. Defaults to "rgb". + convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images. + This is used only when the data type is "distance_to_camera". Defaults to False. + model_zoo_cfg: Map from model name to model configuration dictionary. Each model + configuration dictionary should include the following entries: + - "model": A callable that returns the model when invoked without arguments. + - "preprocess": A callable that processes the images and returns the preprocessed results. + - "inference": A callable that, when given the model and preprocessed images, + returns the extracted features. + model_name: The name of the model to use for inference. Defaults to "ResNet18". + model_device: The device to store and infer models on. This can be used help offload + computation from the main environment GPU. Defaults to "cuda:0". + reset_model: Initialize the model even if it already exists. Defaults to False. + + Returns: + torch.Tensor: the image features, on the same device as the image + """ + if model_zoo_cfg is not None: # use other than default + self.model_zoo_cfg.update(model_zoo_cfg) + + if model_name not in self.model_zoo or reset_model: + # The following allows to only load a desired subset of a model zoo into GPU memory + # as it becomes needed, in a "lazy" evaluation. + print(f"[INFO]: Adding {model_name} to the model zoo") + self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]() + + if model_device is not None and self.model_zoo[model_name].device != model_device: + # want to offload vision model inference to another device + self.model_zoo[model_name] = self.model_zoo[model_name].to(model_device) + + images = image( + env=env, + sensor_cfg=sensor_cfg, + data_type=data_type, + convert_perspective_to_orthogonal=convert_perspective_to_orthogonal, + normalize=True, # want this for training stability + ) + + image_device = images.device + + if model_device is not None: + images = images.to(model_device) + + proc_images = self.model_zoo_cfg[model_name]["preprocess"](images) + features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images) + + return features.to(image_device).clone() + + """ Actions. """ diff --git a/source/extensions/omni.isaac.lab_tasks/config/extension.toml b/source/extensions/omni.isaac.lab_tasks/config/extension.toml index a0b2718609..544cd97377 100644 --- a/source/extensions/omni.isaac.lab_tasks/config/extension.toml +++ b/source/extensions/omni.isaac.lab_tasks/config/extension.toml @@ -1,7 +1,7 @@ [package] # Note: Semantic Versioning is used: https://semver.org/ -version = "0.10.10" +version = "0.10.12" # Description title = "Isaac Lab Environments" diff --git a/source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst b/source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst index a0a2b39346..c194591492 100644 --- a/source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst +++ b/source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst @@ -1,6 +1,23 @@ Changelog --------- +0.10.12 (2024-10-28) +~~~~~~~~~~~~~~~~~~~~ + +Changed +^^^^^^^ + +* Changed manager-based vision cartpole environment names from Isaac-Cartpole-RGB-Camera-v0 + and Isaac-Cartpole-Depth-Camera-v0 to Isaac-Cartpole-RGB-v0 and Isaac-Cartpole-Depth-v0 + +0.10.11 (2024-10-28) +~~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added feature extracted observation cartpole examples. + 0.10.10 (2024-10-25) ~~~~~~~~~~~~~~~~~~~~ diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/feature_extractor.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/feature_extractor.py index 1dbfb39b1a..fc92bbfb1b 100644 --- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/feature_extractor.py +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/feature_extractor.py @@ -7,7 +7,6 @@ import os import torch import torch.nn as nn - import torchvision from omni.isaac.lab.sensors import save_images_to_file diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py index 7a3070d775..43040be70a 100644 --- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py @@ -10,7 +10,12 @@ import gymnasium as gym from . import agents -from .cartpole_camera_env_cfg import CartpoleDepthCameraEnvCfg, CartpoleRGBCameraEnvCfg +from .cartpole_camera_env_cfg import ( + CartpoleDepthCameraEnvCfg, + CartpoleResNet18CameraEnvCfg, + CartpoleRGBCameraEnvCfg, + CartpoleTheiaTinyCameraEnvCfg, +) from .cartpole_env_cfg import CartpoleEnvCfg ## @@ -31,7 +36,7 @@ ) gym.register( - id="Isaac-Cartpole-RGB-Camera-v0", + id="Isaac-Cartpole-RGB-v0", entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv", disable_env_checker=True, kwargs={ @@ -41,7 +46,7 @@ ) gym.register( - id="Isaac-Cartpole-Depth-Camera-v0", + id="Isaac-Cartpole-Depth-v0", entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv", disable_env_checker=True, kwargs={ @@ -49,3 +54,23 @@ "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_camera_ppo_cfg.yaml", }, ) + +gym.register( + id="Isaac-Cartpole-RGB-ResNet18-v0", + entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": CartpoleResNet18CameraEnvCfg, + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Cartpole-RGB-TheiaTiny-v0", + entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": CartpoleTheiaTinyCameraEnvCfg, + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml", + }, +) diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/agents/rl_games_feature_ppo_cfg.yaml b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/agents/rl_games_feature_ppo_cfg.yaml new file mode 100644 index 0000000000..18e0ffd022 --- /dev/null +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/agents/rl_games_feature_ppo_cfg.yaml @@ -0,0 +1,79 @@ +params: + seed: 42 + + # environment wrapper clipping + env: + # added to the wrapper + clip_observations: 5.0 + # can make custom wrapper? + clip_actions: 1.0 + + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + # doesn't have this fine grained control but made it close + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + mlp: + units: [256] + activation: elu + d2rl: False + + initializer: + name: default + regularizer: + name: None + + load_checkpoint: False # flag which sets whether to load the checkpoint + load_path: '' # path to the checkpoint to load + + config: + name: cartpole_features + env_name: rlgpu + device: 'cuda:0' + device_name: 'cuda:0' + multi_gpu: False + ppo: True + mixed_precision: False + normalize_input: True + normalize_value: True + value_bootstraop: True + num_actors: -1 # configured from the script (based on num_envs) + reward_shaper: + scale_value: 1.0 + normalize_advantage: True + gamma: 0.99 + tau : 0.95 + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + score_to_win: 20000 + max_epochs: 5000 + save_best_after: 50 + save_frequency: 25 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + horizon_length: 16 + minibatch_size: 2048 + mini_epochs: 8 + critic_coef: 4 + clip_value: True + seq_length: 4 + bounds_loss_coef: 0.0001 diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py index ce5a6c90b8..f767a21962 100644 --- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py @@ -78,7 +78,7 @@ class DepthObservationsCfg: """Observation specifications for the MDP.""" @configclass - class DepthCameraPolicyCfg(RGBObservationsCfg.RGBCameraPolicyCfg): + class DepthCameraPolicyCfg(ObsGroup): """Observations for policy group with depth images.""" image = ObsTerm( @@ -88,6 +88,43 @@ class DepthCameraPolicyCfg(RGBObservationsCfg.RGBCameraPolicyCfg): policy: ObsGroup = DepthCameraPolicyCfg() +@configclass +class ResNet18ObservationCfg: + """Observation specifications for the MDP.""" + + @configclass + class ResNet18FeaturesCameraPolicyCfg(ObsGroup): + """Observations for policy group with features extracted from RGB images with a frozen ResNet18.""" + + image = ObsTerm( + func=mdp.image_features, + params={"sensor_cfg": SceneEntityCfg("tiled_camera"), "data_type": "rgb", "model_name": "resnet18"}, + ) + + policy: ObsGroup = ResNet18FeaturesCameraPolicyCfg() + + +@configclass +class TheiaTinyObservationCfg: + """Observation specifications for the MDP.""" + + @configclass + class TheiaTinyFeaturesCameraPolicyCfg(ObsGroup): + """Observations for policy group with features extracted from RGB images with a frozen Theia-Tiny Transformer""" + + image = ObsTerm( + func=mdp.image_features, + params={ + "sensor_cfg": SceneEntityCfg("tiled_camera"), + "data_type": "rgb", + "model_name": "theia-tiny-patch16-224-cddsv", + "model_device": "cuda:0", + }, + ) + + policy: ObsGroup = TheiaTinyFeaturesCameraPolicyCfg() + + ## # Environment configuration ## @@ -107,3 +144,20 @@ class CartpoleDepthCameraEnvCfg(CartpoleEnvCfg): scene: CartpoleSceneCfg = CartpoleDepthCameraSceneCfg(num_envs=1024, env_spacing=20) observations: DepthObservationsCfg = DepthObservationsCfg() + + +@configclass +class CartpoleResNet18CameraEnvCfg(CartpoleRGBCameraEnvCfg): + observations: ResNet18ObservationCfg = ResNet18ObservationCfg() + + +@configclass +class CartpoleTheiaTinyCameraEnvCfg(CartpoleRGBCameraEnvCfg): + """ + Due to TheiaTiny's size in GPU memory, we reduce the number of environments by default. + This helps reduce the possibility of crashing on more modest hardware. + The following configuration uses ~12gb VRAM at peak. + """ + + scene: CartpoleSceneCfg = CartpoleRGBCameraSceneCfg(num_envs=128, env_spacing=20) + observations: TheiaTinyObservationCfg = TheiaTinyObservationCfg()