Add stop_movement_collisions: cancel moves that would leave two friendly units on one tile.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Leading whitespace on context lines and the
hunk lengths were reconstructed from the visible lines and may not match the
working tree exactly; the blob hashes on the "index" lines predate the edits
to forward_sim.py. Verify against the tree before applying.

diff --git a/new_src/agent.py b/new_src/agent.py
index d42de0c..a04a81d 100644
--- a/new_src/agent.py
+++ b/new_src/agent.py
@@ -1,6 +1,7 @@
 from matplotlib import pyplot as plt
 from lux.kit import obs_to_game_state, GameState, EnvConfig
 from lux.utils import direction_to, my_turn_to_place_factory
+from lux.forward_sim import stop_movement_collisions
 import numpy as np
 import sys
 class Agent():
@@ -127,4 +128,8 @@ def act(self, step: int, obs, remainingOverageTime: int = 60):
                         move_cost = unit.move_cost(game_state, direction)
                         if move_cost is not None and unit.power >= move_cost + unit.action_queue_cost(game_state):
                             actions[unit_id] = [unit.move(direction, repeat=0, n=1)]
+
+        # cancel moves that would put two of our own units on the same tile
+        actions = stop_movement_collisions(obs, game_state, self.env_cfg, self.player, actions)
+
         return actions
diff --git a/new_src/lux/forward_sim.py b/new_src/lux/forward_sim.py
index 0e10da8..843ffbe 100644
--- a/new_src/lux/forward_sim.py
+++ b/new_src/lux/forward_sim.py
@@ -1,3 +1,10 @@
+from collections import defaultdict
+from typing import Dict, List, Set, Tuple, Union
+
+from lux.unit import Unit
+import numpy as np
+import sys
+
 def forward_sim(full_obs, env_cfg, n=2):
     """
     Forward sims for `n` steps given the current full observation and env_cfg
@@ -24,4 +31,92 @@ def forward_sim(full_obs, env_cfg, n=2):
             return [full_obs]
         obs, _, _, _, _ = env.step(empty_actions)
         forward_obs.append(obs[agent])
-    return forward_obs
\ No newline at end of file
+    return forward_obs
+
+def forward_sim_act(full_obs, env_cfg, player, action):
+    """
+    Forward sims one step in which `player` submits `action` and every other
+    agent submits no actions, and returns `player`'s resulting observation.
+    """
+    from luxai_s2 import LuxAI_S2
+    env = LuxAI_S2(collect_stats=False, verbose=0)
+    env.reset(seed=0)
+    env.state = env.state.from_obs(full_obs, env_cfg)
+    env.env_cfg = env.state.env_cfg
+    env.env_cfg.verbose = 0
+    env.env_steps = env.state.env_steps
+    step_actions = {agent: dict() for agent in env.agents}
+    step_actions[player] = action
+    obs, _, _, _, _ = env.step(step_actions)
+    return obs[player]
+
+# offset added to unit.pos for each move direction; direction 0 is "move center"
+move_deltas = np.array([[0, 0], [0, -1], [1, 0], [0, 1], [-1, 0]])
+
+def _next_action(unit, actions):
+    """
+    Returns the action `unit` would execute next: the head of its entry in
+    `actions` if it has one, else the head of its action queue, else None.
+    """
+    queue = actions[unit.unit_id] if unit.unit_id in actions else unit.action_queue
+    # len() instead of truthiness so a numpy array queue does not raise
+    if queue is None or len(queue) == 0:
+        return None
+    return queue[0]
+
+def _units_to_stop(units, actions, stopped_ids):
+    """
+    Runs one collision pass over `units` and returns the moving units that must
+    stay put. Units whose id is in `stopped_ids` are treated as not moving.
+    """
+    staying = defaultdict(list)  # tile -> units that remain on it
+    heavy_entering = defaultdict(list)  # tile -> heavies moving onto it
+    light_entering = defaultdict(list)  # tile -> other units moving onto it
+    for unit in units:
+        action = None if unit.unit_id in stopped_ids else _next_action(unit, actions)
+        # only a move (type 0) in a direction other than 0 changes the unit's tile
+        if action is None or action[0] != 0 or action[1] == 0:
+            staying[tuple(unit.pos)].append(unit)
+            continue
+        target = tuple(unit.pos + move_deltas[action[1]])
+        entering = heavy_entering if unit.unit_type == "HEAVY" else light_entering
+        entering[target].append(unit)
+
+    to_stop = []
+    for tile in set(heavy_entering) | set(light_entering):
+        heavies = heavy_entering.get(tile, [])
+        lights = light_entering.get(tile, [])
+        if tile in staying:
+            # a unit that is not moving cannot yield, so every mover has to
+            to_stop += heavies + lights
+            continue
+        # a heavy beats any light; within the same class the most power wins
+        contenders = heavies if heavies else lights
+        keeper = max(contenders, key=lambda u: u.power)
+        to_stop += [u for u in heavies + lights if u is not keeper]
+    return to_stop
+
+def stop_movement_collisions(obs, game_state, env_cfg, agent, actions):
+    """
+    Cancels moves that would leave two of `agent`'s units on the same tile.
+
+    A unit's next action is the head of its entry in `actions` if present,
+    otherwise the head of its current action queue. For each contested tile
+    one unit keeps its action and the others get `[unit.move(0)]` in `actions`:
+    a unit that is not moving always keeps its tile; otherwise a moving heavy
+    beats any light, and within the same class the unit with the most power
+    keeps its move. Stopping a unit can block the unit moving onto its tile,
+    so passes repeat until nothing else has to be stopped.
+
+    Units that factories build this turn are not taken into account.
+    `obs` and `env_cfg` are unused. `actions` is updated in place and returned.
+    """
+    units = list(game_state.units[agent].values())
+    stopped_ids = set()
+    while True:
+        to_stop = _units_to_stop(units, actions, stopped_ids)
+        if not to_stop:
+            return actions
+        for unit in to_stop:
+            stopped_ids.add(unit.unit_id)
+            actions[unit.unit_id] = [unit.move(0)]
diff --git a/new_src/lux/unit.py b/new_src/lux/unit.py
index 8bba5ee..8f07b55 100644
--- a/new_src/lux/unit.py
+++ b/new_src/lux/unit.py
@@ -74,4 +74,7 @@ def recharge(self, x, repeat=0, n=1):
 
     def __str__(self) -> str:
         out = f"[{self.team_id}] {self.unit_id} {self.unit_type} at {self.pos}"
-        return out
\ No newline at end of file
+        return out
+
+    def __hash__(self):
+        return hash(self.unit_id)