From f551aa529592f790ccc10b7a7f3a8939fb05d202 Mon Sep 17 00:00:00 2001 From: Haneul Choi Date: Fri, 27 Oct 2023 10:03:36 +0000 Subject: [PATCH] Debugging action_queue update --- .gitignore | 3 +- debug.py | 14 +++++-- download_replay.sh | 2 +- utils.py | 96 +++++++++++++++++++++++++++++++++++++++------- 4 files changed, 96 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index 6894c3d..f906614 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ cython_debug/ .vscode/ # Data -data/ \ No newline at end of file +data/ +jux diff --git a/debug.py b/debug.py index 569d10c..0562bcc 100644 --- a/debug.py +++ b/debug.py @@ -5,12 +5,18 @@ from utils import replay_run_early_phase, replay_run_n_late_game_step MAX_N_UNITS = 200 -n = 376 +n = 281 + +# real_env_step 278에서 갑자기 245번 유닛의 정보가 사라지고 그 자리가 246번으로 대체됨 +# id2idx가 바뀌어서 발생하는 것으로 예상이 됨 # replay_id = 52900827 replay_id = 52958192 -lux_env, lux_actions = jux.utils.load_replay(f'replays/{replay_id}.json') +lux_env, lux_actions = jux.utils.load_replay(f'data/{replay_id}.json') jux_env, state = JuxEnv.from_lux(lux_env, buf_cfg=JuxBufferConfig(MAX_N_UNITS=MAX_N_UNITS)) -state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions) -state, lux_actions = replay_run_n_late_game_step(n, jux_env, state, lux_actions) \ No newline at end of file +# state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions) +# state, lux_actions = replay_run_n_late_game_step(n, jux_env, state, lux_actions) + +state, lux_actions, lux_env = replay_run_early_phase(jux_env, state, lux_actions, lux_env=lux_env) +state, lux_actions, lux_env = replay_run_n_late_game_step(n, jux_env, state, lux_actions, lux_env=lux_env) \ No newline at end of file diff --git a/download_replay.sh b/download_replay.sh index 263e3db..f60441c 100755 --- a/download_replay.sh +++ b/download_replay.sh @@ -1,7 +1,7 @@ #!/bin/bash #Please write ids of replay to download -ids=(52965522 51283939) +ids=(52965522 52958192 52900827) mkdir -p data && cd data diff --git a/utils.py b/utils.py index b4c6dc1..c9c548c 100644 --- a/utils.py +++ b/utils.py @@ -1,11 +1,24 @@ -import jux +from jax import jit +from jax.tree_util import tree_map +import jax.numpy as jnp + +from luxai_s2.env import LuxAI_S2 +from luxai_s2.actions import format_action_vec +import jux from jux.env import JuxEnv from jux.config import JuxBufferConfig from jux.state import State from jux.actions import JuxAction -def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions): +def get_unit_idx(state: State, id: int): + unit_idx = state.unit_id2idx[id] + return unit_idx[...,0], unit_idx[...,1] + +def get_unit_pos(state: State, id: int): + return state.units.pos.pos[get_unit_idx(state, id)] + +def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions, lux_env=None): """ Util function | skip game until late_game stage @@ -16,34 +29,91 @@ def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions): print(f"[Replay Util] Replaying early steps") + with_lux_env = lux_env is not None + # Bid Step - bid, faction = jux.actions.bid_action_from_lux(next(lux_actions)) + lux_action = next(lux_actions) + bid, faction = jux.actions.bid_action_from_lux(lux_action) state, (obs, rwd, dones, infos) = jux_env.step_bid(state, bid, faction) + if with_lux_env: + lux_env.step(lux_action) # Factory Placement Step while state.real_env_steps < 0: - lux_act = next(lux_actions) - spawn, water, metal = jux.actions.factory_placement_action_from_lux(lux_act) + lux_action = next(lux_actions) + spawn, water, metal = jux.actions.factory_placement_action_from_lux(lux_action) state, _ = jux_env.step_factory_placement(state, spawn, water, metal) + if with_lux_env: + lux_env.step(lux_action) print(f"[Replay Util] Replaying early steps - Done") + if with_lux_env: + return state, lux_actions, lux_env return state, lux_actions -def replay_run_n_late_game_step(n: int, jux_env: JuxEnv, state: State, lux_actions): +def replay_run_n_late_game_step(n: int, jux_env: JuxEnv, state: State, lux_actions, lux_env=None): + with_lux_env = lux_env is not None + for i in range(n): - print(f"[Replay Util] Replaying {i+1}/{n} steps") + print(f"[Replay Util] Replaying {i}/{n} steps") lux_act = next(lux_actions) jux_act = JuxAction.from_lux(state, lux_act) + # assert jux_act.to_lux(state) == lux_act, f"JuxAction.to_lux() is not reversible: {jux_act.to_lux(state)} != {lux_act}" # step - state, _ = jux_env.step_late_game(state, jux_act) + state, (obs, rwd, dones, infos) = jux_env.step_late_game(state, jux_act) + + assert ((state.units.action_queue.count == 0) | + ((0 <= state.units.action_queue.front < 20) & + (0 <= state.units.action_queue.rear < 20))).all() + + if with_lux_env: + obs = lux_env.step(lux_act)[0] + obs = obs['player_0'] + + if i >= 8: # unit 88 action_queue mismatch + run_check_action_queue(obs, state) + if i % 20 == 10: + run_check_pos(obs, state) + + if dones[0]: + print(f"[Replay Util] Replaying {i+1}/{n} steps - Done") + break + if with_lux_env: + return state, lux_actions, lux_env return state, lux_actions -if __name__ == "__main__": - lux_env, lux_actions = jux.utils.load_replay('replays/52958192.json') - jux_env, state = JuxEnv.from_lux(lux_env, buf_cfg=JuxBufferConfig(MAX_N_UNITS=200)) - state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions) - state, lux_actions = replay_run_n_late_game_step(100, jux_env, state, lux_actions) \ No newline at end of file +def run_check_pos(obs, state): + for player_id, player_units in obs['units'].items(): + for unit_id, unit in player_units.items(): + assert unit_id.startswith('unit_') + int_id = int(unit_id.split('_')[1]) + pos = get_unit_pos(state, int_id) + assert (pos == unit['pos']).all() + +def get_action_queue_from_id(state, id): + idx = get_unit_idx(state, id) + from_jux = tree_map(lambda x: x[idx], state.units.action_queue).to_lux() + return from_jux + +def get_unit_from_id(state, id): + idx = get_unit_idx(state, id) + return tree_map(lambda x: x[idx], state.units) + +def run_check_action_queue(obs, state): + for player_id, player_units in obs['units'].items(): + for unit_id, unit in player_units.items(): + int_id = int(unit_id.split('_')[1]) + from_jux = get_action_queue_from_id(state, int_id) + if len(unit['action_queue']) == 0: + continue + from_lux = jnp.stack(unit['action_queue']) + assert (from_lux == from_jux).all() + + +def print_action_queue(array): + for i, a in enumerate(array): + print(f"{i}: {format_action_vec(a)}") \ No newline at end of file