From 502b88239661ba5cc4fc60687ee4393af3db2dd3 Mon Sep 17 00:00:00 2001 From: garylvov <67614381+garylvov@users.noreply.github.com> Date: Wed, 9 Oct 2024 22:41:45 -0400 Subject: [PATCH] Update observations.py Signed-off-by: garylvov <67614381+garylvov@users.noreply.github.com> --- .../omni/isaac/lab/envs/mdp/observations.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py index 5973e13ae1..e39aa7c942 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py @@ -267,9 +267,13 @@ def __init__( } self.reset_model(initialize_all=initialize_all) - # The following is named reset_model instead of reset as otherwise it's called at the end of every episode - def reset_model(self, initialize_all=False): - self.model_zoo = {} + # The following is named reset_model instead of reset as otherwise, it's called at the end of every episode + def reset_model(self, model_name: str | None = None, initialize_all: bool = False): + if model_name is None: + print("[WARNING]: No model name supplied, emptying entire model zoo.") + self.model_zoo = {} + elif model_name is not None: + self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]() if initialize_all: for model_name, model_callables in self.model_zoo_cfg.items(): self.model_zoo[model_name] = model_callables["model"]() @@ -281,7 +285,7 @@ def __call__( data_type: str = "rgb", convert_perspective_to_orthogonal: bool = False, model_name: str = "ResNet18", - ): + ) -> torch.Tensor: if model_name not in self.model_zoo: print(f"[INFO]: Adding {model_name} to the model zoo") self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()