Skip to content

Commit

Permalink
Debugging action_queue update
Browse files Browse the repository at this point in the history
  • Loading branch information
caelum02 authored Oct 27, 2023
1 parent b256aac commit f551aa5
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,5 @@ cython_debug/
.vscode/

# Data
data/
data/
jux
14 changes: 10 additions & 4 deletions debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from utils import replay_run_early_phase, replay_run_n_late_game_step

MAX_N_UNITS = 200
n = 376
n = 281

# At real_env_step 278, the information for unit 245 suddenly disappears and its slot is taken over by unit 246.
# This is presumably caused by id2idx changing.

# replay_id = 52900827
replay_id = 52958192
lux_env, lux_actions = jux.utils.load_replay(f'replays/{replay_id}.json')
lux_env, lux_actions = jux.utils.load_replay(f'data/{replay_id}.json')
jux_env, state = JuxEnv.from_lux(lux_env, buf_cfg=JuxBufferConfig(MAX_N_UNITS=MAX_N_UNITS))

state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions)
state, lux_actions = replay_run_n_late_game_step(n, jux_env, state, lux_actions)
# state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions)
# state, lux_actions = replay_run_n_late_game_step(n, jux_env, state, lux_actions)

state, lux_actions, lux_env = replay_run_early_phase(jux_env, state, lux_actions, lux_env=lux_env)
state, lux_actions, lux_env = replay_run_n_late_game_step(n, jux_env, state, lux_actions, lux_env=lux_env)
2 changes: 1 addition & 1 deletion download_replay.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

#Please write ids of replay to download
ids=(52965522 51283939)
ids=(52965522 52958192 52900827)

mkdir -p data && cd data

Expand Down
96 changes: 83 additions & 13 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import jux
from jax import jit
from jax.tree_util import tree_map
import jax.numpy as jnp

from luxai_s2.env import LuxAI_S2
from luxai_s2.actions import format_action_vec

import jux
from jux.env import JuxEnv
from jux.config import JuxBufferConfig
from jux.state import State
from jux.actions import JuxAction

def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions):
def get_unit_idx(state: State, id: int):
    """Return the two index components stored for a unit id.

    Indexes ``state.unit_id2idx`` with ``id`` and returns, as a tuple, the
    entries at positions 0 and 1 of the last axis of the result.

    NOTE(review): presumably these are (team index, slot inside that team's
    unit buffer) -- confirm against the Jux ``State`` definition.
    """
    index_pair = state.unit_id2idx[id]
    first_component = index_pair[..., 0]
    second_component = index_pair[..., 1]
    return first_component, second_component

def get_unit_pos(state: State, id: int):
    """Return the position entry stored in ``state`` for the unit ``id``.

    The index tuple produced by ``get_unit_idx`` is used directly to index
    ``state.units.pos.pos``.
    """
    unit_index = get_unit_idx(state, id)
    all_positions = state.units.pos.pos
    return all_positions[unit_index]

def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions, lux_env=None):
"""
Util function | skip game until late_game stage
Expand All @@ -16,34 +29,91 @@ def replay_run_early_phase(jux_env: JuxEnv, state: State, lux_actions):

print(f"[Replay Util] Replaying early steps")

with_lux_env = lux_env is not None

# Bid Step
bid, faction = jux.actions.bid_action_from_lux(next(lux_actions))
lux_action = next(lux_actions)
bid, faction = jux.actions.bid_action_from_lux(lux_action)
state, (obs, rwd, dones, infos) = jux_env.step_bid(state, bid, faction)
if with_lux_env:
lux_env.step(lux_action)

# Factory Placement Step
while state.real_env_steps < 0:
lux_act = next(lux_actions)
spawn, water, metal = jux.actions.factory_placement_action_from_lux(lux_act)
lux_action = next(lux_actions)
spawn, water, metal = jux.actions.factory_placement_action_from_lux(lux_action)
state, _ = jux_env.step_factory_placement(state, spawn, water, metal)
if with_lux_env:
lux_env.step(lux_action)

print(f"[Replay Util] Replaying early steps - Done")

if with_lux_env:
return state, lux_actions, lux_env
return state, lux_actions

def replay_run_n_late_game_step(n: int, jux_env: JuxEnv, state: State, lux_actions, lux_env=None):
    """Replay up to ``n`` late-game steps in Jux, optionally mirrored in Lux.

    For every step the next action is pulled from ``lux_actions``, converted
    with ``JuxAction.from_lux`` and applied through
    ``jux_env.step_late_game``. After each step the action-queue pointers are
    sanity-checked. When ``lux_env`` is supplied, the same Lux action is also
    applied to it and the resulting ``player_0`` observation is compared
    against the Jux state.

    Args:
        n: Maximum number of late-game steps to replay.
        jux_env: Jux environment used for stepping.
        state: Jux state to start from.
        lux_actions: Iterator yielding one Lux action per step.
        lux_env: Optional Lux environment stepped in lockstep for comparison.

    Returns:
        ``(state, lux_actions, lux_env)`` when ``lux_env`` was given,
        otherwise ``(state, lux_actions)``.

    Raises:
        AssertionError: If a queue pointer is out of range for a non-empty
            queue, or if one of the Jux/Lux comparisons fails.
    """
    with_lux_env = lux_env is not None
    # Upper bound applied to front/rear by the check below; presumably the
    # action-queue buffer length -- TODO confirm against the buffer config.
    queue_bound = 20

    for i in range(n):
        print(f"[Replay Util] Replaying {i}/{n} steps")
        lux_act = next(lux_actions)
        jux_act = JuxAction.from_lux(state, lux_act)

        # Exactly one Jux step per replayed action.
        state, (_, _, dones, _) = jux_env.step_late_game(state, jux_act)

        queue = state.units.action_queue
        # Elementwise range tests: a chained `0 <= x < bound` would call
        # bool() on an array, which is ambiguous for more than one element.
        front_in_range = (queue.front >= 0) & (queue.front < queue_bound)
        rear_in_range = (queue.rear >= 0) & (queue.rear < queue_bound)
        assert ((queue.count == 0) | (front_in_range & rear_in_range)).all(), (
            f"action_queue front/rear out of range after step {i}"
        )

        if with_lux_env:
            # The checkers read the Lux observation dict, so they only run
            # when a Lux environment is stepped alongside Jux.
            obs = lux_env.step(lux_act)[0]['player_0']

            if i >= 8:  # unit 88 action_queue mismatch
                run_check_action_queue(obs, state)
            if i % 20 == 10:
                run_check_pos(obs, state)

        if dones[0]:
            print(f"[Replay Util] Replaying {i+1}/{n} steps - Done")
            break

    if with_lux_env:
        return state, lux_actions, lux_env
    return state, lux_actions

if __name__ == "__main__":
    # Script entry: load one recorded replay, build the Jux environment from
    # it, run the early phase, then replay 100 late-game steps.
    replay_path = 'replays/52958192.json'
    lux_env, lux_actions = jux.utils.load_replay(replay_path)

    buffer_config = JuxBufferConfig(MAX_N_UNITS=200)
    jux_env, state = JuxEnv.from_lux(lux_env, buf_cfg=buffer_config)

    state, lux_actions = replay_run_early_phase(jux_env, state, lux_actions)
    state, lux_actions = replay_run_n_late_game_step(100, jux_env, state, lux_actions)
def run_check_pos(obs, state):
    """Assert that every unit in a Lux observation has the same position in Jux.

    Iterates ``obs['units']`` (a mapping of player -> {unit key -> unit dict}),
    parses the integer id out of each ``unit_<n>`` key, looks the unit up in
    ``state`` via ``get_unit_pos`` and compares it with ``unit['pos']``.

    Args:
        obs: Lux observation dict containing a ``'units'`` entry.
        state: Jux state to compare against.

    Raises:
        AssertionError: If a unit key does not start with ``unit_`` or a
            position differs; the message names the unit and both values.
    """
    for player_units in obs['units'].values():
        for unit_id, unit in player_units.items():
            assert unit_id.startswith('unit_'), f"unexpected unit key: {unit_id!r}"
            int_id = int(unit_id.split('_')[1])
            jux_pos = get_unit_pos(state, int_id)
            lux_pos = unit['pos']
            assert (jux_pos == lux_pos).all(), (
                f"position mismatch for {unit_id}: jux={jux_pos}, lux={lux_pos}"
            )

def get_action_queue_from_id(state, id):
    """Return the Lux-format action queue of the unit ``id`` taken from ``state``.

    Every leaf of ``state.units.action_queue`` is indexed with the unit's
    index tuple, and ``to_lux()`` is called on the resulting structure.
    """
    unit_index = get_unit_idx(state, id)

    def select_unit(leaf):
        return leaf[unit_index]

    unit_queue = tree_map(select_unit, state.units.action_queue)
    return unit_queue.to_lux()

def get_unit_from_id(state, id):
    """Return the slice of ``state.units`` belonging to the unit ``id``.

    Every leaf of ``state.units`` is indexed with the index tuple returned by
    ``get_unit_idx``.
    """
    unit_index = get_unit_idx(state, id)

    def select_unit(leaf):
        return leaf[unit_index]

    return tree_map(select_unit, state.units)

def run_check_action_queue(obs, state):
    """Assert that non-empty Lux action queues match the Jux state.

    Iterates ``obs['units']`` (player -> {unit key -> unit dict}). For each
    unit with a non-empty ``unit['action_queue']``, the queue is stacked into
    one array and compared elementwise with the queue obtained from ``state``
    through ``get_action_queue_from_id``.

    NOTE(review): units whose Lux queue is empty are skipped entirely, so a
    leftover non-empty queue on the Jux side is not detected here.

    Args:
        obs: Lux observation dict containing a ``'units'`` entry.
        state: Jux state to compare against.

    Raises:
        AssertionError: If any compared queue differs; the message names the
            unit and shows both queues.
    """
    for player_units in obs['units'].values():
        for unit_id, unit in player_units.items():
            lux_queue = unit['action_queue']
            # len() instead of truthiness: the queue may be an array, whose
            # truth value is ambiguous. Skip before doing the Jux lookup.
            if len(lux_queue) == 0:
                continue
            int_id = int(unit_id.split('_')[1])
            from_jux = get_action_queue_from_id(state, int_id)
            from_lux = jnp.stack(lux_queue)
            assert (from_lux == from_jux).all(), (
                f"action_queue mismatch for {unit_id}: lux={from_lux}, jux={from_jux}"
            )


def print_action_queue(array):
    """Print one line per entry of ``array`` as ``<index>: <formatted action>``.

    Each entry is rendered with ``format_action_vec``; indices start at 0.
    """
    rendered_lines = (
        f"{position}: {format_action_vec(action_vec)}"
        for position, action_vec in enumerate(array)
    )
    for line in rendered_lines:
        print(line)

0 comments on commit f551aa5

Please sign in to comment.