Skip to content

Commit

Permalink
Merge pull request oxwhirl#75 from oxwhirl/douglasrizzo-state_attr_names
Browse files Browse the repository at this point in the history
Provide state in structured form as a dict and features names in a list
  • Loading branch information
samvelyan authored Jul 5, 2021
2 parents 013cf27 + 2c94fd1 commit fcdd0d3
Showing 1 changed file with 92 additions and 33 deletions.
125 changes: 92 additions & 33 deletions smac/env/starcraft2/starcraft2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from smac.env.starcraft2.maps import get_map_params

import atexit
from warnings import warn
from operator import attrgetter
from copy import deepcopy
import numpy as np
Expand Down Expand Up @@ -249,11 +250,26 @@ def __init__(
self.shield_bits_enemy = 1 if self._bot_race == "P" else 0
self.unit_type_bits = map_params["unit_type_bits"]
self.map_type = map_params["map_type"]
self._unit_types = None

self.max_reward = (
self.n_enemies * self.reward_death_value + self.reward_win
)

# create lists containing the names of attributes returned in states
self.ally_state_attr_names = ['health', 'energy/cooldown', 'rel_x', 'rel_y']
self.enemy_state_attr_names = ['health', 'rel_x', 'rel_y']

if self.shield_bits_ally > 0:
self.ally_state_attr_names += ['shield']
if self.shield_bits_enemy > 0:
self.enemy_state_attr_names += ['shield']

if self.unit_type_bits > 0:
bit_attr_names = ['type_{}'.format(bit) for bit in range(self.unit_type_bits)]
self.ally_state_attr_names += bit_attr_names
self.enemy_state_attr_names += bit_attr_names

self.agents = {}
self.enemies = {}
self._episode_count = 0
Expand Down Expand Up @@ -1036,8 +1052,45 @@ def get_state(self):
)
return obs_concat

nf_al = 4 + self.shield_bits_ally + self.unit_type_bits
nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits
state_dict = self.get_state_dict()

state = np.append(state_dict['allies'].flatten(), state_dict['enemies'].flatten())
if 'last_action' in state_dict:
state = np.append(state, state_dict['last_action'].flatten())
if 'timestep' in state_dict:
state = np.append(state, state_dict['timestep'])

state = state.astype(dtype=np.float32)

if self.debug:
logging.debug("STATE".center(60, "-"))
logging.debug("Ally state {}".format(state_dict['allies']))
logging.debug("Enemy state {}".format(state_dict['enemies']))
if self.state_last_action:
logging.debug("Last actions {}".format(self.last_action))

return state

def get_ally_num_attributes(self):
return len(self.ally_state_attr_names)

def get_enemy_num_attributes(self):
return len(self.enemy_state_attr_names)

def get_state_dict(self):
"""Returns the global state as a dictionary.
- allies: numpy array containing agents and their attributes
- enemies: numpy array containing enemies and their attributes
- last_action: numpy array of previous actions for each agent
- timestep: current no. of steps divided by total no. of steps
NOTE: This function should not be used during decentralised execution.
"""

# number of features equals the number of attribute names
nf_al = self.get_ally_num_attributes()
nf_en = self.get_enemy_num_attributes()

ally_state = np.zeros((self.n_agents, nf_al))
enemy_state = np.zeros((self.n_enemies, nf_en))
Expand Down Expand Up @@ -1070,17 +1123,15 @@ def get_state(self):
y - center_y
) / self.max_distance_y # relative Y

ind = 4
if self.shield_bits_ally > 0:
max_shield = self.unit_max_shield(al_unit)
ally_state[al_id, ind] = (
al_unit.shield / max_shield
) # shield
ind += 1
ally_state[al_id, 4] = (
al_unit.shield / max_shield) # shield

if self.unit_type_bits > 0:
type_id = self.get_unit_type_id(al_unit, True)
ally_state[al_id, ind + type_id] = 1
type_id = self.get_unit_type_id(
al_unit, True)
ally_state[al_id, type_id - self.unit_type_bits] = 1

for e_id, e_unit in self.enemies.items():
if e_unit.health > 0:
Expand All @@ -1097,33 +1148,22 @@ def get_state(self):
y - center_y
) / self.max_distance_y # relative Y

ind = 3
if self.shield_bits_enemy > 0:
max_shield = self.unit_max_shield(e_unit)
enemy_state[e_id, ind] = (
e_unit.shield / max_shield
) # shield
ind += 1
enemy_state[e_id, 3] = (
e_unit.shield / max_shield) # shield

if self.unit_type_bits > 0:
type_id = self.get_unit_type_id(e_unit, False)
enemy_state[e_id, ind + type_id] = 1
type_id = self.get_unit_type_id(
e_unit, False)
enemy_state[e_id, type_id - self.unit_type_bits] = 1

state = {'allies': ally_state, 'enemies': enemy_state}

state = np.append(ally_state.flatten(), enemy_state.flatten())
if self.state_last_action:
state = np.append(state, self.last_action.flatten())
state['last_action'] = self.last_action
if self.state_timestep_number:
state = np.append(state,
self._episode_steps / self.episode_limit)

state = state.astype(dtype=np.float32)

if self.debug:
logging.debug("STATE".center(60, "-"))
logging.debug("Ally state {}".format(ally_state))
logging.debug("Enemy state {}".format(enemy_state))
if self.state_last_action:
logging.debug("Last actions {}".format(self.last_action))
state['timestep'] = self._episode_steps / self.episode_limit

return state

Expand Down Expand Up @@ -1207,12 +1247,12 @@ def get_state_size(self):
return size

def get_visibility_matrix(self):
"""Returns a boolean numpy array of dimensions
"""Returns a boolean numpy array of dimensions
(n_agents, n_agents + n_enemies) indicating which units
are visible to each agent.
"""
arr = np.zeros(
(self.n_agents, self.n_agents + self.n_enemies),
(self.n_agents, self.n_agents + self.n_enemies),
dtype=np.bool,
)

Expand All @@ -1235,7 +1275,7 @@ def get_visibility_matrix(self):

# The matrix for allies is filled symmetrically
al_ids = [
al_id for al_id in range(self.n_agents)
al_id for al_id in range(self.n_agents)
if al_id > agent_id
]
for i, al_id in enumerate(al_ids):
Expand All @@ -1244,7 +1284,7 @@ def get_visibility_matrix(self):
al_y = al_unit.pos.y
dist = self.distance(x, y, al_x, al_y)

if (dist < sight_range and al_unit.health > 0):
if (dist < sight_range and al_unit.health > 0):
# visible and alive
arr[agent_id, al_id] = arr[al_id, agent_id] = 1

Expand Down Expand Up @@ -1403,6 +1443,13 @@ def init_units(self):
all_agents_created = (len(self.agents) == self.n_agents)
all_enemies_created = (len(self.enemies) == self.n_enemies)

self._unit_types = [unit.unit_type
for unit in ally_units_sorted] + [
unit.unit_type
for unit in self._obs.observation.raw_data.
units if unit.owner == 2
]

if all_agents_created and all_enemies_created: # all good
return

Expand All @@ -1413,6 +1460,12 @@ def init_units(self):
self.full_restart()
self.reset()

def get_unit_types(self):
if self._unit_types is None:
warn('unit types have not been initialized yet, please call env.reset() to populate this and call the method again.')

return self._unit_types

def update_units(self):
"""Update units after an environment step.
This function assumes that self._obs is up-to-date.
Expand Down Expand Up @@ -1527,3 +1580,9 @@ def get_stats(self):
"restarts": self.force_restarts,
}
return stats

def get_env_info(self):
env_info = super().get_env_info()
env_info["agent_features"] = self.ally_state_attr_names
env_info["enemy_features"] = self.enemy_state_attr_names
return env_info

0 comments on commit fcdd0d3

Please sign in to comment.