Skip to content

Commit

Permalink
convert obs to class
Browse files Browse the repository at this point in the history
  • Loading branch information
glvov-bdai committed Oct 9, 2024
1 parent d907803 commit bf2238a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 52 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ known_third_party = [
"warp",
"carb",
"Semantics",
"torchvision"
]
# Imports from this repository
known_first_party = "omni.isaac.lab"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers.manager_base import ManagerTermBase
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
from omni.isaac.lab.sensors import Camera, RayCaster, RayCasterCamera, TiledCamera

if TYPE_CHECKING:
Expand Down Expand Up @@ -233,61 +235,69 @@ def image(
return images.clone()


def image_features(
env: ManagerBasedEnv,
sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
data_type: str = "rgb",
convert_perspective_to_orthogonal: bool = True,
model_name: str = "Theia",
model_zoo_cfg: dict | None = None,
) -> torch.Tensor:
"""Extracted image features with a frozen encoder from Images of a specific datatype from the camera sensor.
Args:
env: The environment the cameras are placed within.
sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera").
data_type: The data type to pull from the desired camera. Defaults to "rgb".
model_name: The name of which model to use from the model_zoo_cfg to use to extract features.
model_zoo_cfg: A dictionary with string keys and callable values. Should include "model",
(mapped to a callable with no arguments to return the model), "preprocess" (mapped to
a callable which consumes the images and returns the preprocessed images),
and "inference" (mapped to a callable that, given the model and the preprocessed images,
returns the features.)
class image_features(ManagerTermBase):
"""Extracted image features with a frozen encoder from images of a specific datatype from the camera sensor.
Returns:
The features from the images produced at the last timestep
Calls :meth:`image` to get the images, then performs inference. On initialization,
for a model zoo different from the default, define model_zoo_cfg: A dictionary with string keys and callable values.
Should include "model", (mapped to a callable with no arguments to return the model), "preprocess" (mapped to
a callable which consumes the images and returns the preprocessed images),
and "inference" (mapped to a callable that, given the model and the preprocessed images, returns the features).
"""
if not hasattr(image_features, "model_zoo"):
image_features.model_zoo = {}

if model_zoo_cfg is None:
model_zoo_cfg = {
"ResNet18": {
"model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
"preprocess": lambda img: (
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
"inference": lambda model, images: model(images),
},
}

if model_name not in image_features.model_zoo:
print(f"[INFO]: Adding {model_name} to persistent frozen feature extraction model zoo...")
image_features.model_zoo[model_name] = model_zoo_cfg[model_name]["model"]()

images = image(
env=env,
sensor_cfg=sensor_cfg,
data_type=data_type,
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
normalize=True, # want this for training stability
)

proc_images = model_zoo_cfg[model_name]["preprocess"](images)
features = model_zoo_cfg[model_name]["inference"](image_features.model_zoo[model_name], proc_images)

return features
def __init__(
    self,
    cfg: ObservationTermCfg,
    env: ManagerBasedEnv,
    model_zoo_cfg: dict | None = None,
    initialize_all: bool = False,
):
    """Set up the feature-extraction term and its model zoo configuration.

    Args:
        cfg: The observation term configuration, forwarded to the base class.
        env: The environment, forwarded to the base class.
        model_zoo_cfg: Mapping from a model name to a dictionary with the keys "model"
            (callable with no arguments returning the model), "preprocess" (callable taking
            the images and returning the preprocessed images) and "inference" (callable taking
            the model and the preprocessed images and returning the features). When None,
            a default zoo containing only "ResNet18" is used.
        initialize_all: Whether to instantiate every model of the zoo immediately rather than
            leaving the zoo empty. Defaults to False.
    """
    super().__init__(cfg, env)
    if model_zoo_cfg is None:
        model_zoo_cfg = {
            "ResNet18": {
                # NOTE(review): the device is hard-coded to "cuda:0" -- confirm this matches the env device.
                "model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
                "preprocess": lambda img: (
                    img.permute(0, 3, 1, 2)  # Convert [batch, height, width, 3] -> [batch, 3, height, width]
                    # Normalize in the format expected by pytorch; https://pytorch.org/hub/pytorch_vision_resnet/
                    - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
                ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
                "inference": lambda model, images: model(images),
            },
        }
    # Store the configuration unconditionally. Assigning it only in the ``None`` branch dropped a
    # user-supplied zoo and left the attribute undefined for ``reset_model`` and ``__call__``.
    self.model_zoo_cfg = model_zoo_cfg
    self.reset_model(initialize_all=initialize_all)

# The following is named reset_model instead of reset as otherwise it's called at the end of every episode
def reset_model(self, initialize_all=False):
    """Discard every instantiated model and optionally rebuild the whole zoo.

    Args:
        initialize_all: When True, call the "model" factory of each entry in
            ``self.model_zoo_cfg`` and store the result under the entry's name.
            When False, the zoo is simply left empty. Defaults to False.
    """
    zoo = {}
    self.model_zoo = zoo
    if not initialize_all:
        return
    # Fill the already-published dictionary entry by entry, in configuration order.
    for name in self.model_zoo_cfg:
        zoo[name] = self.model_zoo_cfg[name]["model"]()

def __call__(
    self,
    env: ManagerBasedEnv,
    sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
    data_type: str = "rgb",
    convert_perspective_to_orthogonal: bool = False,
    model_name: str = "ResNet18",
):
    """Extract features from the camera images with one of the models of the zoo.

    Args:
        env: The environment the cameras are placed within.
        sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera").
        data_type: The data type to pull from the desired camera. Defaults to "rgb".
        convert_perspective_to_orthogonal: Forwarded unchanged to :meth:`image`. Defaults to False.
        model_name: The key of the model in ``self.model_zoo_cfg`` used to extract the features.
            Defaults to "ResNet18".

    Returns:
        The output of the model's "inference" callable applied to the preprocessed images.

    Raises:
        KeyError: If ``model_name`` is not a key of ``self.model_zoo_cfg``.
    """
    # Validate up front so an unknown name fails with an actionable message instead of a bare
    # ``KeyError`` raised from a nested lookup.
    if model_name not in self.model_zoo_cfg:
        raise KeyError(
            f"Model '{model_name}' is not defined in the model zoo configuration."
            f" Available models: {list(self.model_zoo_cfg)}"
        )
    model_callables = self.model_zoo_cfg[model_name]

    # Models are instantiated on first use and then kept in the zoo.
    if model_name not in self.model_zoo:
        print(f"[INFO]: Adding {model_name} to the model zoo")
        self.model_zoo[model_name] = model_callables["model"]()

    images = image(
        env=env,
        sensor_cfg=sensor_cfg,
        data_type=data_type,
        convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
        normalize=True,  # want this for training stability
    )

    proc_images = model_callables["preprocess"](images)
    features = model_callables["inference"](self.model_zoo[model_name], proc_images)

    return features


"""
Expand Down

0 comments on commit bf2238a

Please sign in to comment.