diff --git a/conf/config.yaml b/conf/config.yaml deleted file mode 100644 index 83cfa46..0000000 --- a/conf/config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -params: - num_agents: 5 - max_agents: 10 - num_obs: 3 - grid_size: 20 - num_steps: 50 - random_seed: 0 - visualize: True - step_delay: 0.01 - sim_type: "two_d" \ No newline at end of file diff --git a/simulate.py b/simulate.py index 6c1cd3a..badfd52 100644 --- a/simulate.py +++ b/simulate.py @@ -1,56 +1,55 @@ import time +import argparse -import hydra -from omegaconf import DictConfig, OmegaConf from jax import random from simulationsandbox.two_d_simulation import SimpleSimulation from simulationsandbox.three_d_simulation import ThreeDSimulation -@hydra.main(version_base=None, config_path="conf", config_name="config") -def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num_agents", type=int, default=5) + parser.add_argument("--max_agents", type=int, default=10) + parser.add_argument("--num_obs", type=int, default=3) + parser.add_argument("--grid_size", type=int, default=20) + parser.add_argument("--num_steps", type=int, default=60) + parser.add_argument("--step_delay", type=float, default=0.01) + parser.add_argument("--sim_type", type=str, default="two_d") + parser.add_argument("--random_seed", type=int, default=0) + parser.add_argument("--visualize", action="store_false") + args = parser.parse_args() - num_agents = cfg.params.num_agents - max_agents = cfg.params.max_agents - num_obs = cfg.params.num_obs - grid_size = cfg.params.grid_size - num_steps = cfg.params.num_steps - visualize = cfg.params.visualize - step_delay = cfg.params.step_delay - sim_type = cfg.params.sim_type - - key = random.PRNGKey(cfg.params.random_seed) + key = random.PRNGKey(args.random_seed) # Choose a simulation type - if sim_type == "two_d": + if args.sim_type == "two_d": Simulation = SimpleSimulation - elif sim_type == "three_d": + elif args.sim_type == "three_d": Simulation = ThreeDSimulation else: - raise(ValueError(f"Unknown sim type {sim_type}")) + raise(ValueError(f"Unknown sim type {args.sim_type}")) - sim = Simulation(max_agents, grid_size) - state = sim.init_state(num_agents, num_obs, key) + sim = Simulation(args.max_agents, args.grid_size) + state = sim.init_state(args.num_agents, args.num_obs, key) # Launch a simulation print("Simulation started") - for timestep in range(num_steps): - time.sleep(step_delay) + for timestep in range(args.num_steps): + time.sleep(args.step_delay) key, a_key, step_key = random.split(key, 3) if timestep % 10 == 0: print(f"\nstep {timestep}") - if timestep == (num_steps // 3): + if timestep == (args.num_steps // 3): # Add 3 agents and change the color of an agent state = sim.add_agent(state, 7) state = sim.add_agent(state, 9) state = sim.add_agent(state, 5) state = state.replace(colors=state.colors.at[0, 2].set(1.0)) - if timestep == 2* (num_steps // 3): + if timestep == 2* (args.num_steps // 3): # Remove 3 other agents and change the color of another agent state = sim.remove_agent(state, 2) state = sim.remove_agent(state, 1) @@ -61,7 +60,7 @@ def main(cfg: DictConfig): actions = sim.choose_action(state.obs, a_key) state = sim.step(state, actions, step_key) - if visualize: + if args.visualize: Simulation.visualize_sim(state) print("\nSimulation ended")