Fix loading pre-trained DD-PPO checkpoints (facebookresearch#1116)
* Fix loading of the pretrained depth and RGB DD-PPO checkpoints listed here: https://github.com/facebookresearch/habitat-lab/blob/6030bfdf0acac9b28e3251b4a83d81f176b8b3b6/habitat-baselines/habitat_baselines/rl/ddppo/README.md

* Fix loading of the pretrained blind model

* Add a unit test checking that the pretrained DD-PPO checkpoints can be loaded

* Fix the test/test_gym_wrapper.py test
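
The common thread in these fixes is a single prefix-stripping pattern: DD-PPO trainer checkpoints store the policy parameters under an "actor_critic." prefix, which must be removed before calling load_state_dict on a bare policy, and the policy must be built so its modules match the checkpoint (hence the new normalize_visual_inputs flag). A minimal sketch of loading one of the published depth checkpoints, mirroring the new unit test below; it assumes the checkpoint has already been downloaded to data/ddppo-models and that the script is run from the repository root:

import numpy as np
import torch
from gym import spaces

from habitat import read_write
from habitat_baselines.config.default import get_config
from habitat_baselines.rl.ddppo.policy import PointNavResNetPolicy

observation_space = spaces.Dict(
    {
        "depth": spaces.Box(low=0, high=1, shape=(256, 256, 1), dtype=np.float32),
        "pointgoal_with_gps_compass": spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2,),
            dtype=np.float32,
        ),
    }
)

config = get_config("test/ddppo_pointnav_test.yaml")
with read_write(config):
    # The depth checkpoint used below was trained with a resnet50 backbone.
    config.habitat_baselines.rl.ddppo.backbone = "resnet50"

policy = PointNavResNetPolicy.from_config(
    config=config,
    observation_space=observation_space,
    action_space=spaces.Discrete(4),
)

# Trainer checkpoints keep parameters under an "actor_critic." prefix;
# strip it before loading into the bare policy:
pretrained_state = torch.load(
    "data/ddppo-models/gibson-2plus-resnet50.pth", map_location="cpu"
)
prefix = "actor_critic."
policy.load_state_dict(
    {k[len(prefix) :]: v for k, v in pretrained_state["state_dict"].items()}
)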
rpartsey authored Mar 13, 2023
1 parent 87658d2 commit 7cc06be
Showing 5 changed files with 183 additions and 24 deletions.
9 changes: 8 additions & 1 deletion habitat-baselines/habitat_baselines/rl/ddppo/policy/resnet_policy.py
@@ -51,6 +51,7 @@ def __init__(
rnn_type: str = "GRU",
resnet_baseplanes: int = 32,
backbone: str = "resnet18",
normalize_visual_inputs: bool = False,
force_blind_policy: bool = False,
policy_config: "DictConfig" = None,
aux_loss_config: Optional["DictConfig"] = None,
@@ -77,6 +78,7 @@ def __init__(
rnn_type=rnn_type,
backbone=backbone,
resnet_baseplanes=resnet_baseplanes,
normalize_visual_inputs=normalize_visual_inputs,
fuse_keys=fuse_keys,
force_blind_policy=force_blind_policy,
discrete_actions=discrete_actions,
@@ -115,6 +117,7 @@ def from_config(
rnn_type=config.habitat_baselines.rl.ddppo.rnn_type,
num_recurrent_layers=config.habitat_baselines.rl.ddppo.num_recurrent_layers,
backbone=config.habitat_baselines.rl.ddppo.backbone,
normalize_visual_inputs="rgb" in observation_space.spaces,
force_blind_policy=config.habitat_baselines.force_blind_policy,
policy_config=config.habitat_baselines.rl.policy,
aux_loss_config=config.habitat_baselines.rl.auxiliary_losses,
@@ -130,6 +133,7 @@ def __init__(
ngroups: int = 32,
spatial_size: int = 128,
make_backbone=None,
normalize_visual_inputs: bool = False,
):
super().__init__()

@@ -149,7 +153,7 @@ def __init__(
observation_space.spaces[k].shape[2] for k in self.visual_keys
)

if self._n_input_channels > 0:
if normalize_visual_inputs:
self.running_mean_and_var: nn.Module = RunningMeanAndVar(
self._n_input_channels
)
@@ -252,6 +256,7 @@ def __init__(
rnn_type: str,
backbone,
resnet_baseplanes,
normalize_visual_inputs: bool,
fuse_keys: Optional[List[str]],
force_blind_policy: bool = False,
discrete_actions: bool = True,
@@ -376,6 +381,7 @@ def __init__(
baseplanes=resnet_baseplanes,
ngroups=resnet_baseplanes // 2,
make_backbone=getattr(resnet, backbone),
normalize_visual_inputs=normalize_visual_inputs,
)
setattr(self, f"{uuid}_encoder", goal_visual_encoder)

@@ -408,6 +414,7 @@ def __init__(
baseplanes=resnet_baseplanes,
ngroups=resnet_baseplanes // 2,
make_backbone=getattr(resnet, backbone),
normalize_visual_inputs=normalize_visual_inputs,
)

if not self.visual_encoder.is_blind:
36 changes: 18 additions & 18 deletions habitat-baselines/habitat_baselines/rl/ppo/ppo_trainer.py
@@ -171,27 +171,27 @@ def _setup_actor_critic_agent(self, ppo_cfg: "DictConfig") -> None:
self.config.habitat_baselines.rl.ddppo.pretrained_encoder
or self.config.habitat_baselines.rl.ddppo.pretrained
):
pretrained_state = torch.load(
pretrained_state = self.load_checkpoint(
self.config.habitat_baselines.rl.ddppo.pretrained_weights,
map_location="cpu",
)

if self.config.habitat_baselines.rl.ddppo.pretrained:
self.actor_critic.load_state_dict(
{ # type: ignore
k[len("actor_critic.") :]: v
for k, v in pretrained_state["state_dict"].items()
}
)
elif self.config.habitat_baselines.rl.ddppo.pretrained_encoder:
prefix = "actor_critic.net.visual_encoder."
self.actor_critic.net.visual_encoder.load_state_dict(
{
k[len(prefix) :]: v
for k, v in pretrained_state["state_dict"].items()
if k.startswith(prefix)
}
)
if self.config.habitat_baselines.rl.ddppo.pretrained:
prefix = "actor_critic."
self.actor_critic.load_state_dict(
{ # type: ignore
k[len(prefix) :]: v
for k, v in pretrained_state["state_dict"].items()
}
)
else:
prefix = "actor_critic.net.visual_encoder."
self.actor_critic.net.visual_encoder.load_state_dict(
{
k[len(prefix) :]: v
for k, v in pretrained_state["state_dict"].items()
if k.startswith(prefix)
}
)

if not self.config.habitat_baselines.rl.ddppo.train_encoder:
self._static_encoder = True
12 changes: 8 additions & 4 deletions habitat-lab/habitat/gym/gym_wrapper.py
@@ -30,6 +30,12 @@
HabGymWrapperObsType = Union[np.ndarray, Dict[str, np.ndarray]]


def filter_observation_space(obs_space, limit_keys):
return spaces.Dict(
{k: v for k, v in obs_space.spaces.items() if k in limit_keys}
)


def smash_observation_space(obs_space, limit_keys):
obs_shapes = [obs_space.spaces[k].shape for k in limit_keys]

@@ -47,9 +53,7 @@ def transform_shape(shape):
return spaces.Box(
shape=(total_dim,), low=-1.0, high=1.0, dtype=np.float32
)
return spaces.Dict(
{k: v for k, v in obs_space.spaces.items() if k in limit_keys}
)
return filter_observation_space(obs_space, limit_keys)


def _is_continuous(original_space: gym.Space) -> bool:
@@ -211,7 +215,7 @@ def __init__(

self.action_space = create_action_space(action_space)

self.observation_space = smash_observation_space(
self.observation_space = filter_observation_space(
env.observation_space, self._gym_obs_keys
)

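
A minimal sketch of the behavioral difference, assuming the simplified semantics visible in the hunk above (smash_observation_space flattens the selected keys into a single Box when the elided condition allows it, while filter_observation_space always preserves the Dict structure, which is why the wrapper's observations are now dict-valued); the key names here are made up for illustration:

import numpy as np
from gym import spaces

from habitat.gym.gym_wrapper import (
    filter_observation_space,
    smash_observation_space,
)

obs_space = spaces.Dict(
    {
        "joint": spaces.Box(low=-1.0, high=1.0, shape=(7,), dtype=np.float32),
        "is_holding": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
    }
)

# Keeps only the requested keys and stays a Dict:
filtered = filter_observation_space(obs_space, ["joint"])

# May instead collapse the requested keys into one flat Box of size 8:
smashed = smash_observation_space(obs_space, ["joint", "is_holding"])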
2 changes: 1 addition & 1 deletion test/test_gym_wrapper.py
@@ -30,7 +30,7 @@
"benchmark/rearrange/reach_state.yaml",
[],
7,
np.ndarray,
dict,
),
(
"benchmark/rearrange/pick.yaml",
148 changes: 148 additions & 0 deletions test/test_pointnav_resnet_policy.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os.path
import shlex
import subprocess

import numpy as np
import pytest
import torch
from gym import spaces

from habitat import read_write
from habitat_baselines.config.default import get_config
from habitat_baselines.rl.ddppo.policy import PointNavResNetPolicy

ACTION_SPACE = spaces.Discrete(4)

OBSERVATION_SPACES = {
"depth_model": spaces.Dict(
{
"depth": spaces.Box(
low=0,
high=1,
shape=(256, 256, 1),
dtype=np.float32,
),
"pointgoal_with_gps_compass": spaces.Box(
low=np.finfo(np.float32).min,
high=np.finfo(np.float32).max,
shape=(2,),
dtype=np.float32,
),
}
),
"rgb_model": spaces.Dict(
{
"rgb": spaces.Box(
low=0,
high=255,
shape=(256, 256, 3),
dtype=np.uint8,
),
"pointgoal_with_gps_compass": spaces.Box(
low=np.finfo(np.float32).min,
high=np.finfo(np.float32).max,
shape=(2,),
dtype=np.float32,
),
}
),
"blind_model": spaces.Dict(
{
"pointgoal_with_gps_compass": spaces.Box(
low=np.finfo(np.float32).min,
high=np.finfo(np.float32).max,
shape=(2,),
dtype=np.float32,
),
}
),
}

MODELS_DEST_DIR = "data/ddppo-models"
MODELS_BASE_URL = "https://dl.fbaipublicfiles.com/habitat/data/baselines/v1/ddppo/ddppo-models"
MODELS_TO_TEST = {
"gibson-2plus-resnet50.pth": {
"backbone": "resnet50",
"observation_space": OBSERVATION_SPACES["depth_model"],
"action_space": ACTION_SPACE,
},
"gibson-2plus-mp3d-train-val-test-se-resneXt50-rgb.pth": {
"backbone": "se_resneXt50",
"observation_space": OBSERVATION_SPACES["rgb_model"],
"action_space": ACTION_SPACE,
},
"gibson-0plus-mp3d-train-val-test-blind.pth": {
"backbone": None,
"observation_space": OBSERVATION_SPACES["blind_model"],
"action_space": ACTION_SPACE,
},
}


def _get_model_url(model_name):
return f"{MODELS_BASE_URL}/{model_name}"


def _get_model_path(model_name):
return f"{MODELS_DEST_DIR}/{model_name}"


@pytest.fixture(scope="module", autouse=True)
def download_data():
for model_name in MODELS_TO_TEST:
model_url = _get_model_url(model_name)
model_path = _get_model_path(model_name)
if not os.path.exists(model_path):
print(f"Downloading {model_name}.")
download_command = (
"wget --continue " + model_url + " -P " + MODELS_DEST_DIR
)
subprocess.check_call(shlex.split(download_command))
assert os.path.exists(
model_path
), "Download failed, no package found."


@pytest.mark.parametrize(
"pretrained_weights_path,backbone,observation_space,action_space",
[
(
_get_model_path(model_name),
params["backbone"],
params["observation_space"],
params["action_space"],
)
for model_name, params in MODELS_TO_TEST.items()
],
)
def test_pretrained_models(
pretrained_weights_path, backbone, observation_space, action_space
):
config = get_config("test/ddppo_pointnav_test.yaml")
with read_write(config):
ddppo_config = config.habitat_baselines.rl.ddppo
ddppo_config.pretrained = True
ddppo_config.pretrained_weights = pretrained_weights_path
if backbone is not None:
ddppo_config.backbone = backbone

policy = PointNavResNetPolicy.from_config(
config=config,
observation_space=observation_space,
action_space=action_space,
)

pretrained_state = torch.load(pretrained_weights_path, map_location="cpu")

prefix = "actor_critic."
policy.load_state_dict(
{ # type: ignore
k[len(prefix) :]: v
for k, v in pretrained_state["state_dict"].items()
}
)
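
The new test can be run from the repository root with pytest test/test_pointnav_resnet_policy.py; its module-scoped download_data fixture first fetches the three checkpoints into data/ddppo-models via wget if they are not already present.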
