diff --git a/habitat-baselines/habitat_baselines/rl/ddppo/policy/resnet_policy.py b/habitat-baselines/habitat_baselines/rl/ddppo/policy/resnet_policy.py index 110e9ad730..a346c217cd 100644 --- a/habitat-baselines/habitat_baselines/rl/ddppo/policy/resnet_policy.py +++ b/habitat-baselines/habitat_baselines/rl/ddppo/policy/resnet_policy.py @@ -51,6 +51,7 @@ def __init__( rnn_type: str = "GRU", resnet_baseplanes: int = 32, backbone: str = "resnet18", + normalize_visual_inputs: bool = False, force_blind_policy: bool = False, policy_config: "DictConfig" = None, aux_loss_config: Optional["DictConfig"] = None, @@ -77,6 +78,7 @@ def __init__( rnn_type=rnn_type, backbone=backbone, resnet_baseplanes=resnet_baseplanes, + normalize_visual_inputs=normalize_visual_inputs, fuse_keys=fuse_keys, force_blind_policy=force_blind_policy, discrete_actions=discrete_actions, @@ -115,6 +117,7 @@ def from_config( rnn_type=config.habitat_baselines.rl.ddppo.rnn_type, num_recurrent_layers=config.habitat_baselines.rl.ddppo.num_recurrent_layers, backbone=config.habitat_baselines.rl.ddppo.backbone, + normalize_visual_inputs="rgb" in observation_space.spaces, force_blind_policy=config.habitat_baselines.force_blind_policy, policy_config=config.habitat_baselines.rl.policy, aux_loss_config=config.habitat_baselines.rl.auxiliary_losses, @@ -130,6 +133,7 @@ def __init__( ngroups: int = 32, spatial_size: int = 128, make_backbone=None, + normalize_visual_inputs: bool = False, ): super().__init__() @@ -149,7 +153,7 @@ def __init__( observation_space.spaces[k].shape[2] for k in self.visual_keys ) - if self._n_input_channels > 0: + if normalize_visual_inputs: self.running_mean_and_var: nn.Module = RunningMeanAndVar( self._n_input_channels ) @@ -252,6 +256,7 @@ def __init__( rnn_type: str, backbone, resnet_baseplanes, + normalize_visual_inputs: bool, fuse_keys: Optional[List[str]], force_blind_policy: bool = False, discrete_actions: bool = True, @@ -376,6 +381,7 @@ def __init__( baseplanes=resnet_baseplanes, ngroups=resnet_baseplanes // 2, make_backbone=getattr(resnet, backbone), + normalize_visual_inputs=normalize_visual_inputs, ) setattr(self, f"{uuid}_encoder", goal_visual_encoder) @@ -408,6 +414,7 @@ def __init__( baseplanes=resnet_baseplanes, ngroups=resnet_baseplanes // 2, make_backbone=getattr(resnet, backbone), + normalize_visual_inputs=normalize_visual_inputs, ) if not self.visual_encoder.is_blind: diff --git a/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py index 816ce11c28..449beb8f21 100644 --- a/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py @@ -171,27 +171,27 @@ def _setup_actor_critic_agent(self, ppo_cfg: "DictConfig") -> None: self.config.habitat_baselines.rl.ddppo.pretrained_encoder or self.config.habitat_baselines.rl.ddppo.pretrained ): - pretrained_state = torch.load( + pretrained_state = self.load_checkpoint( self.config.habitat_baselines.rl.ddppo.pretrained_weights, map_location="cpu", ) - - if self.config.habitat_baselines.rl.ddppo.pretrained: - self.actor_critic.load_state_dict( - { # type: ignore - k[len("actor_critic.") :]: v - for k, v in pretrained_state["state_dict"].items() - } - ) - elif self.config.habitat_baselines.rl.ddppo.pretrained_encoder: - prefix = "actor_critic.net.visual_encoder." - self.actor_critic.net.visual_encoder.load_state_dict( - { - k[len(prefix) :]: v - for k, v in pretrained_state["state_dict"].items() - if k.startswith(prefix) - } - ) + if self.config.habitat_baselines.rl.ddppo.pretrained: + prefix = "actor_critic." + self.actor_critic.load_state_dict( + { # type: ignore + k[len(prefix) :]: v + for k, v in pretrained_state["state_dict"].items() + } + ) + else: + prefix = "actor_critic.net.visual_encoder." + self.actor_critic.net.visual_encoder.load_state_dict( + { + k[len(prefix) :]: v + for k, v in pretrained_state["state_dict"].items() + if k.startswith(prefix) + } + ) if not self.config.habitat_baselines.rl.ddppo.train_encoder: self._static_encoder = True diff --git a/habitat-lab/habitat/gym/gym_wrapper.py b/habitat-lab/habitat/gym/gym_wrapper.py index ed91fae7a0..dbb1fb704d 100644 --- a/habitat-lab/habitat/gym/gym_wrapper.py +++ b/habitat-lab/habitat/gym/gym_wrapper.py @@ -30,6 +30,12 @@ HabGymWrapperObsType = Union[np.ndarray, Dict[str, np.ndarray]] +def filter_observation_space(obs_space, limit_keys): + return spaces.Dict( + {k: v for k, v in obs_space.spaces.items() if k in limit_keys} + ) + + def smash_observation_space(obs_space, limit_keys): obs_shapes = [obs_space.spaces[k].shape for k in limit_keys] @@ -47,9 +53,7 @@ def transform_shape(shape): return spaces.Box( shape=(total_dim,), low=-1.0, high=1.0, dtype=np.float32 ) - return spaces.Dict( - {k: v for k, v in obs_space.spaces.items() if k in limit_keys} - ) + return filter_observation_space(obs_space, limit_keys) def _is_continuous(original_space: gym.Space) -> bool: @@ -211,7 +215,7 @@ def __init__( self.action_space = create_action_space(action_space) - self.observation_space = smash_observation_space( + self.observation_space = filter_observation_space( env.observation_space, self._gym_obs_keys ) diff --git a/test/test_gym_wrapper.py b/test/test_gym_wrapper.py index af7ca92389..c1fdcaf66a 100644 --- a/test/test_gym_wrapper.py +++ b/test/test_gym_wrapper.py @@ -30,7 +30,7 @@ "benchmark/rearrange/reach_state.yaml", [], 7, - np.ndarray, + dict, ), ( "benchmark/rearrange/pick.yaml", diff --git a/test/test_pointnav_resnet_policy.py b/test/test_pointnav_resnet_policy.py new file mode 100644 index 0000000000..bb0d85bc5a --- /dev/null +++ b/test/test_pointnav_resnet_policy.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os.path +import shlex +import subprocess + +import numpy as np +import pytest +import torch +from gym import spaces + +from habitat import read_write +from habitat_baselines.config.default import get_config +from habitat_baselines.rl.ddppo.policy import PointNavResNetPolicy + +ACTION_SPACE = spaces.Discrete(4) + +OBSERVATION_SPACES = { + "depth_model": spaces.Dict( + { + "depth": spaces.Box( + low=0, + high=1, + shape=(256, 256, 1), + dtype=np.float32, + ), + "pointgoal_with_gps_compass": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ), + } + ), + "rgb_model": spaces.Dict( + { + "rgb": spaces.Box( + low=0, + high=255, + shape=(256, 256, 3), + dtype=np.uint8, + ), + "pointgoal_with_gps_compass": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ), + } + ), + "blind_model": spaces.Dict( + { + "pointgoal_with_gps_compass": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ), + } + ), +} + +MODELS_DEST_DIR = "data/ddppo-models" +MODELS_BASE_URL = "https://dl.fbaipublicfiles.com/habitat/data/baselines/v1/ddppo/ddppo-models" +MODELS_TO_TEST = { + "gibson-2plus-resnet50.pth": { + "backbone": "resnet50", + "observation_space": OBSERVATION_SPACES["depth_model"], + "action_space": ACTION_SPACE, + }, + "gibson-2plus-mp3d-train-val-test-se-resneXt50-rgb.pth": { + "backbone": "se_resneXt50", + "observation_space": OBSERVATION_SPACES["rgb_model"], + "action_space": ACTION_SPACE, + }, + "gibson-0plus-mp3d-train-val-test-blind.pth": { + "backbone": None, + "observation_space": OBSERVATION_SPACES["blind_model"], + "action_space": ACTION_SPACE, + }, +} + + +def _get_model_url(model_name): + return f"{MODELS_BASE_URL}/{model_name}" + + +def _get_model_path(model_name): + return f"{MODELS_DEST_DIR}/{model_name}" + + +@pytest.fixture(scope="module", autouse=True) +def download_data(): + for model_name in MODELS_TO_TEST: + model_url = _get_model_url(model_name) + model_path = _get_model_path(model_name) + if not os.path.exists(model_path): + print(f"Downloading {model_name}.") + download_command = ( + "wget --continue " + model_url + " -P " + MODELS_DEST_DIR + ) + subprocess.check_call(shlex.split(download_command)) + assert os.path.exists( + model_path + ), "Download failed, no package found." + + +@pytest.mark.parametrize( + "pretrained_weights_path,backbone,observation_space,action_space", + [ + ( + _get_model_path(model_name), + params["backbone"], + params["observation_space"], + params["action_space"], + ) + for model_name, params in MODELS_TO_TEST.items() + ], +) +def test_pretrained_models( + pretrained_weights_path, backbone, observation_space, action_space +): + config = get_config("test/ddppo_pointnav_test.yaml") + with read_write(config): + ddppo_config = config.habitat_baselines.rl.ddppo + ddppo_config.pretrained = True + ddppo_config.pretrained_weights = pretrained_weights_path + if backbone is not None: + ddppo_config.backbone = backbone + + policy = PointNavResNetPolicy.from_config( + config=config, + observation_space=observation_space, + action_space=action_space, + ) + + pretrained_state = torch.load(pretrained_weights_path, map_location="cpu") + + prefix = "actor_critic." + policy.load_state_dict( + { # type: ignore + k[len(prefix) :]: v + for k, v in pretrained_state["state_dict"].items() + } + )