Skip to content

Commit

Permalink
Update observations.py
Browse files Browse the repository at this point in the history
Signed-off-by: garylvov <[email protected]>
  • Loading branch information
garylvov authored Oct 10, 2024
1 parent 11abbcf commit 502b882
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,13 @@ def __init__(
}
self.reset_model(initialize_all=initialize_all)

# The following is named reset_model instead of reset as otherwise it's called at the end of every episode
def reset_model(self, initialize_all=False):
self.model_zoo = {}
# The following is named reset_model instead of reset as otherwise, it's called at the end of every episode
def reset_model(self, model_name: str | None = None, initialize_all: bool = False):
if model_name is None:
print("[WARNING]: No model name supplied, emptying entire model zoo.")
self.model_zoo = {}
elif model_name is not None:
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
if initialize_all:
for model_name, model_callables in self.model_zoo_cfg.items():
self.model_zoo[model_name] = model_callables["model"]()
Expand All @@ -281,7 +285,7 @@ def __call__(
data_type: str = "rgb",
convert_perspective_to_orthogonal: bool = False,
model_name: str = "ResNet18",
):
) -> torch.Tensor:
if model_name not in self.model_zoo:
print(f"[INFO]: Adding {model_name} to the model zoo")
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
Expand Down

0 comments on commit 502b882

Please sign in to comment.