Skip to content

Commit

Permalink
fix: all envs now construct entities correctly
Browse files (browse the repository at this point in the history)
  • Loading branch information
epignatelli committed Jun 7, 2024
1 parent dff090b commit db4310d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 19 deletions.
17 changes: 4 additions & 13 deletions baselines/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,17 @@ class Args:
wandb.init(project=args.project_name, config=config)

# init environment
- env = FlattenObsWrapper(nx.make(env_id))
+ env = nx.make(env_id)
+ env = FlattenObsWrapper(env)

# create agent
- network = nn.Sequential(
- [
- nn.Dense(
- 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
- ),
- nn.tanh,
- nn.Dense(
- 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
- ),
- nn.tanh,
- ]
- )
agent = PPO(
hparams=args.ppo,
network=ActorCritic(action_dim=len(env.action_set)),
env=env,
)

# run experiment
experiment = nx.Experiment(
name=args.project_name,
agent=agent,
Expand Down
4 changes: 2 additions & 2 deletions navix/agents/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from navix.environments.environment import Timestep
from navix.states import State

- from .models import ActorCriticRNN
+ from .models import ActorCritic


@dataclass
Expand Down Expand Up @@ -87,7 +87,7 @@ class TrainingState(TrainState):

class PPO(Agent):
hparams: PPOHparams = struct.field(pytree_node=False)
- network: ActorCriticRNN = struct.field(pytree_node=False)
+ network: ActorCritic = struct.field(pytree_node=False)
env: Environment

def collect_experience(
Expand Down
2 changes: 1 addition & 1 deletion navix/environments/dynamic_obstacles.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
)
# goal
goal_pos = jnp.asarray([self.height - 2, self.width - 2])
- goal = Goal(position=goal_pos, probability=jnp.asarray(1.0))
+ goal = Goal.create(position=goal_pos, probability=jnp.asarray(1.0))

# balls
exclude = jnp.stack([player_pos, goal_pos])
Expand Down
6 changes: 3 additions & 3 deletions navix/environments/lava_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times


register_env(
- "Navix-DoorKey-S5-v0",
+ "Navix-LavaGap-S5-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
Expand All @@ -106,7 +106,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
),
)
register_env(
- "Navix-DoorKey-S6-v0",
+ "Navix-LavaGap-S6-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
Expand All @@ -118,7 +118,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
),
)
register_env(
- "Navix-DoorKey-S7-v0",
+ "Navix-LavaGap-S7-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
Expand Down

0 comments on commit db4310d

Please sign in to comment.