Adds image extracted features observation term and cartpole examples for it #1191

Merged

Changes from all commits (33 commits)
5d09fcb
Add Built In Feature Extraction
glvov-bdai Oct 8, 2024
f7ef078
Update environments.rst
glvov-bdai Oct 8, 2024
12dadd1
Update environments.rst
glvov-bdai Oct 8, 2024
d907803
Update environments.rst
glvov-bdai Oct 8, 2024
bf2238a
convert obs to class
glvov-bdai Oct 9, 2024
11abbcf
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 9, 2024
502b882
Update observations.py
garylvov Oct 10, 2024
ab42141
add vision transformer
garylvov Oct 11, 2024
df2b1ff
simplify MLP
garylvov Oct 11, 2024
a490961
change model
garylvov Oct 11, 2024
2343e35
Update pyproject.toml
glvov-bdai Oct 11, 2024
f92957e
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 11, 2024
72a9fbf
update
glvov-bdai Oct 12, 2024
9447165
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 12, 2024
5336ad6
Update feature_extractor.py
glvov-bdai Oct 12, 2024
3744dde
update env names
glvov-bdai Oct 12, 2024
953716b
Merge branch 'feature/preprocess_observation_updated' of https://gith…
glvov-bdai Oct 12, 2024
abb8743
consistent import ordering
glvov-bdai Oct 12, 2024
1e81c87
Update observations.py
glvov-bdai Oct 15, 2024
c4cbb95
Update cartpole_camera_env_cfg.py
glvov-bdai Oct 15, 2024
a06033e
formatting
glvov-bdai Oct 21, 2024
4c43aca
format but actually good this time
glvov-bdai Oct 21, 2024
4f61f88
fix to inherit from obs group
glvov-bdai Oct 21, 2024
ecb65ce
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 22, 2024
27451f9
Merge branch 'main' into feature/preprocess_observation_updated
Dhoeller19 Oct 24, 2024
8829417
Merge branch 'main' into feature/preprocess_observation_updated
Dhoeller19 Oct 24, 2024
01bda48
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 24, 2024
e9de97f
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 28, 2024
cf79302
formatting
glvov-bdai Oct 28, 2024
ca775fa
Add changelog about cartpole renaming
glvov-bdai Oct 28, 2024
05ac736
Update source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/obser…
glvov-bdai Oct 28, 2024
c1f412f
address James' comments
glvov-bdai Oct 28, 2024
4bb7993
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 28, 2024
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -43,6 +43,7 @@ Guidelines for modifications:
* Chenyu Yang
* David Yang
* Dorsa Rohani
* Felix Yu
* Gary Lvov
* Giulio Romualdi
* HoJin Jeon
11 changes: 9 additions & 2 deletions docs/source/overview/environments.rst
@@ -61,6 +61,10 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
| | | |
| | |cartpole-depth-direct-link|| |
+------------------+-----------------------------+-------------------------------------------------------------------------+
| |cartpole| | |cartpole-resnet-link| | Move the cart to keep the pole upwards in the classic cartpole control |
| | | based on features extracted from perceptive inputs with pre-trained |
| | |cartpole-theia-link| | frozen vision encoders |
+------------------+-----------------------------+-------------------------------------------------------------------------+

.. |humanoid| image:: ../_static/tasks/classic/humanoid.jpg
.. |ant| image:: ../_static/tasks/classic/ant.jpg
@@ -69,8 +73,11 @@
.. |humanoid-link| replace:: `Isaac-Humanoid-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/humanoid/humanoid_env_cfg.py>`__
.. |ant-link| replace:: `Isaac-Ant-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/ant/ant_env_cfg.py>`__
.. |cartpole-link| replace:: `Isaac-Cartpole-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py>`__
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-resnet-link| replace:: `Isaac-Cartpole-RGB-ResNet18-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-theia-link| replace:: `Isaac-Cartpole-RGB-TheiaTiny-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__


.. |humanoid-direct-link| replace:: `Isaac-Humanoid-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/humanoid/humanoid_env.py>`__
.. |ant-direct-link| replace:: `Isaac-Ant-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/ant/ant_env.py>`__
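
The new feature-extraction tasks are registered like the other cartpole environments, so they can be
created through the usual gymnasium entry point. A minimal sketch of launching one of them is shown below,
assuming a working Isaac Lab installation; the AppLauncher/parse_env_cfg boilerplate mirrors the standard
workflow scripts and is not part of this change:

from omni.isaac.lab.app import AppLauncher

app_launcher = AppLauncher(headless=True)  # the simulation app must exist before importing tasks
simulation_app = app_launcher.app

import gymnasium as gym

import omni.isaac.lab_tasks  # noqa: F401  (registers the Isaac-Cartpole-* environments)
from omni.isaac.lab_tasks.utils import parse_env_cfg

env_cfg = parse_env_cfg("Isaac-Cartpole-RGB-ResNet18-v0", num_envs=16)
env = gym.make("Isaac-Cartpole-RGB-ResNet18-v0", cfg=env_cfg)
obs, _ = env.reset()  # policy observations now contain frozen ResNet-18 features instead of raw pixels
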
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -36,6 +36,9 @@ extra_standard_library = [
    "toml",
    "trimesh",
    "tqdm",
    "torchvision",
    "transformers",
    "einops",  # Needed for transformers, doesn't always auto-install
]
# Imports from Isaac Sim and Omniverse
known_third_party = [
3 changes: 2 additions & 1 deletion source/extensions/omni.isaac.lab/config/extension.toml
@@ -1,7 +1,8 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.6"

version = "0.27.7"

# Description
title = "Isaac Lab framework for Robot Learning"
10 changes: 10 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
@@ -1,6 +1,16 @@
Changelog
---------


0.27.7 (2024-10-28)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added a frozen-encoder feature-extraction observation term with ResNet and Theia.


0.27.6 (2024-10-25)
~~~~~~~~~~~~~~~~~~~

source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py
@@ -17,11 +17,14 @@
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers.manager_base import ManagerTermBase
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
from omni.isaac.lab.sensors import Camera, Imu, RayCaster, RayCasterCamera, TiledCamera

if TYPE_CHECKING:
    from omni.isaac.lab.envs import ManagerBasedEnv, ManagerBasedRLEnv


"""
Root state.
"""
@@ -273,6 +276,134 @@ def image(
    return images.clone()


class image_features(ManagerTermBase):
    """Extracted image features from a pre-trained frozen encoder.

    This term calls the :meth:`image` function to retrieve images, and then performs
    inference on those images.
    """

    def __init__(self, cfg: ObservationTermCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)
        from torchvision import models
        from transformers import AutoModel

        def create_theia_model(model_name):
            return {
                "model": (
                    lambda: AutoModel.from_pretrained(f"theaiinstitute/{model_name}", trust_remote_code=True)
                    .eval()
                    .to("cuda:0")
                ),
                # rescale each image to [0, 1] using its own minimum and maximum
                "preprocess": lambda img: (img - torch.amin(img, dim=(1, 2), keepdim=True)) / (
                    torch.amax(img, dim=(1, 2), keepdim=True) - torch.amin(img, dim=(1, 2), keepdim=True)
                ),
                "inference": lambda model, images: model.forward_feature(
                    images, do_rescale=False, interpolate_pos_encoding=True
                ),
            }

        def create_resnet_model(resnet_name):
            return {
                "model": lambda: getattr(models, resnet_name)(pretrained=True).eval().to("cuda:0"),
                # apply the ImageNet normalization expected by torchvision's pre-trained ResNets
                "preprocess": lambda img: (
                    img.permute(0, 3, 1, 2)  # Convert [batch, height, width, 3] -> [batch, 3, height, width]
                    - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
                ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
                "inference": lambda model, images: model(images),
            }

        # List of Theia models
        theia_models = [
            "theia-tiny-patch16-224-cddsv",
            "theia-tiny-patch16-224-cdiv",
            "theia-small-patch16-224-cdiv",
            "theia-base-patch16-224-cdiv",
            "theia-small-patch16-224-cddsv",
            "theia-base-patch16-224-cddsv",
        ]

        # List of ResNet models
        resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"]

        self.default_model_zoo_cfg = {}

        # Add Theia models to the zoo
        for model_name in theia_models:
            self.default_model_zoo_cfg[model_name] = create_theia_model(model_name)

        # Add ResNet models to the zoo
        for resnet_name in resnet_models:
            self.default_model_zoo_cfg[resnet_name] = create_resnet_model(resnet_name)

        self.model_zoo_cfg = self.default_model_zoo_cfg
        self.model_zoo = {}

    def __call__(
        self,
        env: ManagerBasedEnv,
        sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
        data_type: str = "rgb",
        convert_perspective_to_orthogonal: bool = False,
        model_zoo_cfg: dict | None = None,
        model_name: str = "resnet18",
        model_device: str | None = "cuda:0",
        reset_model: bool = False,
    ) -> torch.Tensor:
        """Extracted image features from a pre-trained frozen encoder.

        Args:
            env: The environment.
            sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera").
            data_type: The sensor data type to pull from the camera. Defaults to "rgb".
            convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images.
                This is used only when the data type is "distance_to_camera". Defaults to False.
            model_zoo_cfg: Map from model name to model configuration dictionary. Each model
                configuration dictionary should include the following entries:

                - "model": A callable that returns the model when invoked without arguments.
                - "preprocess": A callable that processes the images and returns the preprocessed results.
                - "inference": A callable that, when given the model and preprocessed images,
                  returns the extracted features.

            model_name: The name of the model to use for inference. Defaults to "resnet18".
            model_device: The device to store and infer models on. This can be used to help offload
                computation from the main environment GPU. Defaults to "cuda:0".
            reset_model: Re-initialize the model even if it already exists. Defaults to False.

        Returns:
            The image features, on the same device as the input image.
        """
        if model_zoo_cfg is not None:  # merge user-provided models into the default zoo
            self.model_zoo_cfg.update(model_zoo_cfg)

        if model_name not in self.model_zoo or reset_model:
            # The following allows to only load a desired subset of a model zoo into GPU memory
            # as it becomes needed, in a "lazy" evaluation.
            print(f"[INFO]: Adding {model_name} to the model zoo")
            self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()

        # Plain nn.Modules do not expose a ``.device`` attribute, so compare against the device
        # of the model's parameters to decide whether the model needs to be moved.
        current_device = next(self.model_zoo[model_name].parameters()).device
        if model_device is not None and current_device != torch.device(model_device):
            # offload vision model inference to the requested device
            self.model_zoo[model_name] = self.model_zoo[model_name].to(model_device)

        images = image(
            env=env,
            sensor_cfg=sensor_cfg,
            data_type=data_type,
            convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
            normalize=True,  # we want this for training stability
        )

        image_device = images.device

        if model_device is not None:
            images = images.to(model_device)

        proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
        features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)

        return features.to(image_device).clone()
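
For context, the new term plugs into a manager-based observation group like any other MDP term. A minimal
sketch of an observation group that requests frozen ResNet-18 features is shown below (the class names and
group layout are illustrative, not copied from the cartpole configurations in this PR):

import omni.isaac.lab.envs.mdp as mdp
from omni.isaac.lab.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.lab.managers import ObservationTermCfg as ObsTerm
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.utils import configclass


@configclass
class ObservationsCfg:
    """Observation specifications for a camera-based MDP (illustrative)."""

    @configclass
    class PolicyCfg(ObsGroup):
        # frozen ResNet-18 features of the tiled camera's RGB images
        image_features = ObsTerm(
            func=mdp.image_features,
            params={"sensor_cfg": SceneEntityCfg("tiled_camera"), "data_type": "rgb", "model_name": "resnet18"},
        )

        def __post_init__(self):
            self.concatenate_terms = True

    policy: PolicyCfg = PolicyCfg()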


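The model_zoo_cfg hook documented above lets a user register additional encoders without modifying this
term. A hypothetical entry for torchvision's ResNet-152, which is not in the default zoo, could be passed
in through the term's parameters (the "resnet152" key and this exact preprocessing are assumptions for
illustration, not part of this PR):

import torch
from torchvision import models

# Hypothetical extra encoder that follows the same "model"/"preprocess"/"inference" contract.
custom_zoo = {
    "resnet152": {
        "model": lambda: models.resnet152(pretrained=True).eval().to("cuda:0"),
        # Convert [batch, height, width, 3] -> [batch, 3, height, width] and apply ImageNet normalization.
        "preprocess": lambda img: (
            img.permute(0, 3, 1, 2)
            - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
        ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
        "inference": lambda model, images: model(images),
    }
}

# For example, inside an ObservationTermCfg:
#   params={"model_zoo_cfg": custom_zoo, "model_name": "resnet152"}
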
"""
Actions.
"""
source/extensions/omni.isaac.lab_tasks/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.10.10"
version = "0.10.12"

# Description
title = "Isaac Lab Environments"
17 changes: 17 additions & 0 deletions source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,23 @@
Changelog
---------

0.10.12 (2024-10-28)
~~~~~~~~~~~~~~~~~~~~

Changed
^^^^^^^

* Changed the manager-based vision cartpole environment names from Isaac-Cartpole-RGB-Camera-v0
  and Isaac-Cartpole-Depth-Camera-v0 to Isaac-Cartpole-RGB-v0 and Isaac-Cartpole-Depth-v0.

0.10.11 (2024-10-28)
~~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added cartpole examples that use extracted image features as observations.

0.10.10 (2024-10-25)
~~~~~~~~~~~~~~~~~~~~

@@ -7,7 +7,6 @@
import os
import torch
import torch.nn as nn

import torchvision

from omni.isaac.lab.sensors import save_images_to_file
source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py
Expand Up @@ -10,7 +10,12 @@
import gymnasium as gym

from . import agents
from .cartpole_camera_env_cfg import CartpoleDepthCameraEnvCfg, CartpoleRGBCameraEnvCfg
from .cartpole_camera_env_cfg import (
    CartpoleDepthCameraEnvCfg,
    CartpoleResNet18CameraEnvCfg,
    CartpoleRGBCameraEnvCfg,
    CartpoleTheiaTinyCameraEnvCfg,
)
from .cartpole_env_cfg import CartpoleEnvCfg

##
@@ -31,7 +36,7 @@
)

gym.register(
    id="Isaac-Cartpole-RGB-Camera-v0",
    id="Isaac-Cartpole-RGB-v0",
    entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
@@ -41,11 +46,31 @@
)

gym.register(
    id="Isaac-Cartpole-Depth-Camera-v0",
    id="Isaac-Cartpole-Depth-v0",
    entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": CartpoleDepthCameraEnvCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_camera_ppo_cfg.yaml",
    },
)

gym.register(
    id="Isaac-Cartpole-RGB-ResNet18-v0",
    entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": CartpoleResNet18CameraEnvCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml",
    },
)

gym.register(
    id="Isaac-Cartpole-RGB-TheiaTiny-v0",
    entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": CartpoleTheiaTinyCameraEnvCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml",
    },
)
source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/agents/rl_games_feature_ppo_cfg.yaml
@@ -0,0 +1,79 @@
params:
  seed: 42

  # environment wrapper clipping
  env:
    # added to the wrapper
    clip_observations: 5.0
    # can make custom wrapper?
    clip_actions: 1.0

  algo:
    name: a2c_continuous

  model:
    name: continuous_a2c_logstd

  # doesn't have this fine grained control but made it close
  network:
    name: actor_critic
    separate: False
    space:
      continuous:
        mu_activation: None
        sigma_activation: None

        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: True
    mlp:
      units: [256]
      activation: elu
      d2rl: False

      initializer:
        name: default
      regularizer:
        name: None

  load_checkpoint: False # flag which sets whether to load the checkpoint
  load_path: '' # path to the checkpoint to load

  config:
    name: cartpole_features
    env_name: rlgpu
    device: 'cuda:0'
    device_name: 'cuda:0'
    multi_gpu: False
    ppo: True
    mixed_precision: False
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    num_actors: -1  # configured from the script (based on num_envs)
    reward_shaper:
      scale_value: 1.0
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 3e-4
    lr_schedule: adaptive
    kl_threshold: 0.008
    score_to_win: 20000
    max_epochs: 5000
    save_best_after: 50
    save_frequency: 25
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    horizon_length: 16
    minibatch_size: 2048
    mini_epochs: 8
    critic_coef: 4
    clip_value: True
    seq_length: 4
    bounds_loss_coef: 0.0001