Skip to content

Commit

Permalink
Replace hydra with argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 16, 2024
1 parent a0ebbeb commit b0c0b16
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 35 deletions.
10 changes: 0 additions & 10 deletions conf/config.yaml

This file was deleted.

49 changes: 24 additions & 25 deletions simulate.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,55 @@
import time
import argparse

import hydra
from omegaconf import DictConfig, OmegaConf
from jax import random

from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
print(OmegaConf.to_yaml(cfg))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num_agents", type=int, default=5)
parser.add_argument("--max_agents", type=int, default=10)
parser.add_argument("--num_obs", type=int, default=3)
parser.add_argument("--grid_size", type=int, default=20)
parser.add_argument("--num_steps", type=int, default=60)
parser.add_argument("--step_delay", type=float, default=0.01)
parser.add_argument("--sim_type", type=str, default="two_d")
parser.add_argument("--random_seed", type=int, default=0)
parser.add_argument("--visualize", action="store_false")
args = parser.parse_args()

num_agents = cfg.params.num_agents
max_agents = cfg.params.max_agents
num_obs = cfg.params.num_obs
grid_size = cfg.params.grid_size
num_steps = cfg.params.num_steps
visualize = cfg.params.visualize
step_delay = cfg.params.step_delay
sim_type = cfg.params.sim_type

key = random.PRNGKey(cfg.params.random_seed)
key = random.PRNGKey(args.random_seed)

# Choose a simulation type
if sim_type == "two_d":
if args.sim_type == "two_d":
Simulation = SimpleSimulation
elif sim_type == "three_d":
elif args.sim_type == "three_d":
Simulation = ThreeDSimulation
else:
raise(ValueError(f"Unknown sim type {sim_type}"))
raise(ValueError(f"Unknown sim type {args.sim_type}"))

sim = Simulation(max_agents, grid_size)
state = sim.init_state(num_agents, num_obs, key)
sim = Simulation(args.max_agents, args.grid_size)
state = sim.init_state(args.num_agents, args.num_obs, key)

# Launch a simulation
print("Simulation started")
for timestep in range(num_steps):
time.sleep(step_delay)
for timestep in range(args.num_steps):
time.sleep(args.step_delay)
key, a_key, step_key = random.split(key, 3)

if timestep % 10 == 0:
print(f"\nstep {timestep}")

if timestep == (num_steps // 3):
if timestep == (args.num_steps // 3):
# Add 3 agents and change the color of an agent
state = sim.add_agent(state, 7)
state = sim.add_agent(state, 9)
state = sim.add_agent(state, 5)
state = state.replace(colors=state.colors.at[0, 2].set(1.0))

if timestep == 2* (num_steps // 3):
if timestep == 2* (args.num_steps // 3):
# Remove 3 other agents and change the color of another agent
state = sim.remove_agent(state, 2)
state = sim.remove_agent(state, 1)
Expand All @@ -61,7 +60,7 @@ def main(cfg: DictConfig):
actions = sim.choose_action(state.obs, a_key)
state = sim.step(state, actions, step_key)

if visualize:
if args.visualize:
Simulation.visualize_sim(state)
print("\nSimulation ended")

Expand Down

0 comments on commit b0c0b16

Please sign in to comment.