Skip to content

Commit

Permalink
Add collision stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
gyusang committed Nov 17, 2023
1 parent 0e280dc commit 2e5287b
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
4 changes: 4 additions & 0 deletions new_src/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from matplotlib import pyplot as plt
from lux.kit import obs_to_game_state, GameState, EnvConfig
from lux.utils import direction_to, my_turn_to_place_factory
from lux.forward_sim import stop_movement_collisions
import numpy as np
import sys
class Agent():
Expand Down Expand Up @@ -127,6 +128,9 @@ def act(self, step: int, obs, remainingOverageTime: int = 60):
move_cost = unit.move_cost(game_state, direction)
if move_cost is not None and unit.power >= move_cost + unit.action_queue_cost(game_state):
actions[unit_id] = [unit.move(direction, repeat=0, n=1)]

actions = stop_movement_collisions(obs, game_state, self.env_cfg, self.player, actions)

return actions


Expand Down
116 changes: 115 additions & 1 deletion new_src/lux/forward_sim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from collections import defaultdict
from typing import Dict, List, Set, Tuple, Union

from lux.unit import Unit
import numpy as np
import sys

def forward_sim(full_obs, env_cfg, n=2):
"""
Forward sims for `n` steps given the current full observation and env_cfg
Expand All @@ -24,4 +31,111 @@ def forward_sim(full_obs, env_cfg, n=2):
return [full_obs]
obs, _, _, _, _ = env.step(empty_actions)
forward_obs.append(obs[agent])
return forward_obs
return forward_obs

def forward_sim_act(full_obs, env_cfg, player, action):
from luxai_s2 import LuxAI_S2
# from luxai_s2.config import UnitConfig
# import copy
env = LuxAI_S2(collect_stats=False, verbose=0)
env.reset(seed=0)
env.state = env.state.from_obs(full_obs, env_cfg)
env.env_cfg = env.state.env_cfg
env.env_cfg.verbose = 0
env.env_steps = env.state.env_steps
forward_obs = [full_obs]
empty_actions = dict()
for agent in env.agents:
empty_actions[agent] = dict()
empty_actions[player] = action
obs, _, _, _, _ = env.step(empty_actions)
return obs[player]

move_deltas = np.array([[0, 0], [0, -1], [1, 0], [0, 1], [-1, 0]])

def stop_movement_collisions(obs, game_state, env_cfg, agent, actions):
units_map = defaultdict(list)
actions_by_type = defaultdict(list)
for unit in game_state.units[agent].values():
units_map[tuple(unit.pos)].append(unit)
if unit.unit_id in actions:
unit_a = actions[unit.unit_id][0]
elif unit.action_queue:
unit_a = unit.action_queue[0]
else:
unit_a = None
if unit_a is not None:
actions_by_type[unit_a[0]].append((unit, unit_a))
#for factory in game_state.factories[agent].values():
# if len(factory.action_queue) > 0:
# unit_a: Action = factory.action_queue.pop(0)
# actions_by_type[unit_a.act_type].append((factory, unit_a))
new_units_map: Dict[str, List[Unit]] = defaultdict(list)
heavy_entered_pos: Dict[str, List[Unit]] = defaultdict(list)
light_entered_pos: Dict[str, List[Unit]] = defaultdict(list)

for unit, move_action in actions_by_type[0]:
# skip move center
if move_action[1] != 0:
old_pos_hash = tuple(unit.pos)
target_pos = (
unit.pos + move_deltas[move_action[1]]
)
# power_required = move_action.power_cost
# unit.pos = target_pos
new_pos_hash = tuple(target_pos)
if len(units_map[old_pos_hash]) == 1:
del units_map[old_pos_hash]
else:
units_map[old_pos_hash].remove(unit)
new_units_map[new_pos_hash].append(unit)

if unit.unit_type == "HEAVY":
heavy_entered_pos[new_pos_hash].append(unit)
else:
light_entered_pos[new_pos_hash].append(unit)


for pos_hash, units in units_map.items():
# add in all the stationary units
new_units_map[pos_hash] += units

all_stopped_units: Set[Unit] = set()
# new_units_map_after_collision: Dict[str, List[Unit]] = defaultdict(list)
for pos_hash, units in new_units_map.items():
stopped_units: Set[Unit] = set()
if len(units) <= 1:
# no collision
continue
if len(units_map[pos_hash]) > 0:
# There is a stationary unit, avoid.
surviving_unit = units_map[pos_hash][0]
for u in units:
if u.unit_id != surviving_unit.unit_id:
stopped_units.add(u)
elif len(heavy_entered_pos[pos_hash]) > 1:
# more than two heavy collide while moving, less powerful unit yields.
most_power_unit = units[0]
for u in units:
if u.unit_type == "HEAVY":
if u.power > most_power_unit.power:
most_power_unit = u
surviving_unit = most_power_unit
for u in units:
if u.unit_id != surviving_unit.unit_id:
stopped_units.add(u)
elif len(heavy_entered_pos[pos_hash]) > 0:
# one heavy and other light collide while moving, light yields.
surviving_unit = heavy_entered_pos[pos_hash][0]
for u in units:
if u.unit_id != surviving_unit.unit_id:
stopped_units.add(u)
# new_units_map_after_collision[pos_hash].append(surviving_unit)
else:
...
# this is for factory spawn collision checking, which is skipped for now
all_stopped_units.update(stopped_units)

for u in all_stopped_units:
actions[u.unit_id] = [u.move(0)]
return actions
5 changes: 4 additions & 1 deletion new_src/lux/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,7 @@ def recharge(self, x, repeat=0, n=1):

def __str__(self) -> str:
out = f"[{self.team_id}] {self.unit_id} {self.unit_type} at {self.pos}"
return out
return out

def __hash__(self):
return hash(self.unit_id)

0 comments on commit 2e5287b

Please sign in to comment.