diff --git a/README.md b/README.md index 4e80349c..2bebba27 100644 --- a/README.md +++ b/README.md @@ -1 +1,148 @@ -# smac \ No newline at end of file +```diff +- This is a BETA release. +``` + +# SMAC - StarCraft Multi-Agent Challenge + +[SMAC](https://github.com/oxwhirl/smac) is [WhiRL](http://whirl.cs.ox.ac.uk)'s environment for research in the field of collaborative multi-agent reinforcement learning (MARL) based on [Blizzard](http://blizzard.com)'s [StarCraft II](https://en.wikipedia.org/wiki/StarCraft_II:_Wings_of_Liberty) RTS game. SMAC makes use of Blizzard's [StarCraft II Machine Learning API](https://github.com/Blizzard/s2client-proto) and [DeepMind](https://deepmind.com)'s [PySC2](https://github.com/deepmind/pysc2) to provide a convenient interface for autonomous agents to interact with StarCraft II, getting observations and performing actions. Unlike the [PySC2](https://github.com/deepmind/pysc2), SMAC concentrates on *decentralised micromanamgent* scenarios, where each unit of the game is controlled by an individual RL agent. + +Please refer to the accompanying [paper](https://arxiv.org/abs/TODO) and [blogpost](http://whirl.cs.ox.ac.uk/blog/smac) for the outline of our motivation for using SMAC as a testbed for MARL research and the initial experimental results. + +## About + +Together with SMAC we also release [PyMARL](https://github.com/oxwhirl/pymarl) - our framework for MARL research, which includes implementations of several state-of-the-art algorithms, such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926). + +Should you have any question, please reach to [mikayel@samvelyan.com](mailto:[mikayel@samvelyan.com) or [tabish.rashid@cs.ox.ac.uk](mailto:[tabish.rashid@cs.ox.ac.uk). + + +# Quick Start + +## Installing SMAC + +You can install SMAC by using the following command: + +```shell +$ pip install git+https://github.com/oxwhirl/smac.git +``` + +Alternatively, you can clone the SMAC repository and then install `smac` with its dependencies: + +```shell +$ git clone https://github.com/oxwhirl/smac.git +$ pip install smac/ +``` + +SMAC uses features of PySC2 that are not included in the latest release yet. If you have PySC2-2.0.1 already installed, please uninstall it first. SMAC will install a newer version from the master branch. You may also need to upgrade pip: `pip install --upgrade pip` for the install to work. + +## Installing StarCraft II + +SMAC is based on the full game of StarCraft II (versions >= 3.16.1). To install the game, follow the commands bellow. + +### Linux + +Please use the Blizzard's [repository](https://github.com/Blizzard/s2client-proto#downloads) to download the Linux version of StarCraft II. By default, the game is expected to be in `~/StarCraftII/` directory. This can be changed by setting the environment variable `SC2PATH`. + +### MacOS/Windows + +Please install StarCraft II from [Battle.net](https://battle.net). The free [Starter Edition](http://battle.net/sc2/en/legacy-of-the-void/) also works. PySC2 will find the latest binary should you use the default install location. Otherwise, similar to the Linux version, you would need to set the `SC2PATH` environment variable with the correct location of the game. + +## SMAC maps + +SMAC is composed of many combat scenarios with pre-configured maps. Before SMAC can be used, these maps need to be downloaded into the `Maps` directory of StarCraft II. + +Download the [SMAC Maps](https://github.com/oxwhirl/smac/releases/download/v1.2/smac_maps.zip) and extract them to your `$SC2PATH/Maps` directory.(**TODO** fix the link) If you installed SMAC via git, simply copy the `SMAC_Maps` directory from `smac/env/starcraft2/maps/` into `$SC2PATH/Maps` directory. + +### List the maps + +To see the list of SMAC maps, together with the number of ally and enemy units and episode limit, run: + +```shell +$ python -m smac.bin.map_list +``` + +## Testing SMAC + +Please run the following command to make sure that `smac` and its maps are properly installed. + +```bash +$ python -m smac.examples.random +``` + +## Watch a replay + +You can watch saved replays by running: + +```shell +$ python -m pysc2.bin.play --norender --rgb_minimap_size 0 --replay +``` + +This works for any replay as long as the map can be found by the game. + +For more information, please refer to [PySC2](https://github.com/deepmind/pysc2) documentation. + +# Documentation + +For the detailed description of the environment, read the [SMAC documentation](docs/smac.md). The initial results of our experiments using SMAC can be found in the [accompanying paper](https://arxiv.org/abs/TODO). + +# Citing SMAC + +If you use SMAC in your research, please cite the [SMAC Paper](https://arxiv.org/abs/TODO). + +*M. Samvelyan, T. Rashid, C. Schroeder de Witt, G. Farquhar, N. Nardelli, T.G.J Rudner, CM Hung, P.H.S. Torr, J. Foerster, S. Whiteson. The StarCraft Multi-Agent Challenge, CoRR abs/TBD, 2018* + +In BibTeX format: + +```tex +@article{samvelyan19smac, + title = {{The} {StarCraft} {Multi}-{Agent} {Challenge}}, + author = {Mikayel Samvelyan and Tabish Rashid and Christian Schroeder de Witt and Gregory Farquhar and Nantas Nardelli and Tim G. J. Rudner and Chia-Man Hung and Philiph H. S. Torr and Jakob Foerster and Shimon Whiteson}, + journal = {CoRR}, + volume = {abs/TBD}, + year = "2019" +} +``` + +# Code Example + +Bellow is a small code example which illustrates how SMAC can be used. Here, individual agents execute random policies after receiving the observations and global state from the environment. + +If you want to try the state-of-the-art algorithms (such as [QMIX](https://arxiv.org/abs/1803.11485) and [COMA](https://arxiv.org/abs/1705.08926)) on SMAC, make use of [PyMARL](https://github.com/oxwhirl/smac) - our framework for MARL research. + +```python +from smac.env import StarCraft2Env +import numpy as np + + +def main(): + env = StarCraft2Env(map_name="8m") + env_info = env.get_env_info() + + n_actions = env_info["n_actions"] + n_agents = env_info["n_agents"] + + n_episodes = 10 + + for e in range(n_episodes): + env.reset() + terminated = False + episode_reward = 0 + + while not terminated: + obs = env.get_obs() + state = env.get_state() + + actions = [] + for agent_id in range(n_agents): + avail_actions = env.get_avail_agent_actions(agent_id) + avail_actions_ind = np.nonzero(avail_actions)[0] + action = np.random.choice(avail_actions_ind) + actions.append(action) + + reward, terminated, _ = env.step(actions) + episode_reward += reward + + print("Total reward in episode {} = {}".format(e, episode_reward)) + + env.close() + +``` diff --git a/docs/smac.md b/docs/smac.md new file mode 100644 index 00000000..2b7acdcf --- /dev/null +++ b/docs/smac.md @@ -0,0 +1,133 @@ +## Table of Contents + +- [StarCraft II](#starcraft-ii) + - [Micromanagement](#micromanagement) +- [SMAC](#smac) + - [Scenarios](#scenarios) + - [State and Observations](#state-and-observations) + - [Action Space](#action-space) + - [Rewards](#rewards) + - [Environment Settings](#environment-settings) + +## StarCraft II + +SMAC is based on the popular real-time strategy (RTS) game [StarCraft II](http://us.battle.net/sc2/en/game/guide/whats-sc2) written by [Blizzard](http://blizzard.com/). +In a regular full game of StarCraft II, one or more humans compete against each other or against a built-in game AI to gather resources, construct buildings, and build armies of units to defeat their opponents. + +Akin to most RTSs, StarCraft has two main gameplay components: macromanagement and micromanagement. _Macromanagement_ refers to high-level strategic considerations, such as economy and resource management. _Micromanagement_ (micro), on the other hand, refers to fine-grained control of individual units. + +### Micromanagement + +StarCraft has been used as a research platform for AI, and more recently, RL. Typically, the game is framed as a single-agent problem: an agent takes the role of a human player, making macromanagement decisions and performing micromanagement as a puppeteer that issues orders to individual units from a centralised controller. + +In order to build a rich multi-agent testbed, we instead focus solely on micromanagement. +Micro is a vital aspect of StarCraft gameplay with a high skill ceiling, and is practiced in isolation by amateur and professional players. +For SMAC, we leverage the natural multi-agent structure of micromanagement by proposing a modified version of the problem designed specifically for decentralised control. +In particular, we require that each unit be controlled by an independent agent that conditions only on local observations restricted to a limited field of view centred on that unit. +Groups of these agents must be trained to solve challenging combat scenarios, battling an opposing army under the centralised control of the game's built-in scripted AI. + +Proper micro of units during battles will maximise the damage dealt to enemy units while minimising damage received, and requires a range of skills. +For example, one important technique is _focus fire_, i.e., ordering units to jointly attack and kill enemy units one after another. When focusing fire, it is important to avoid _overkill_: inflicting more damage to units than is necessary to kill them. + +Other common micromanagement techniques include: assembling units into formations based on their armour types, making enemy units give chase while maintaining enough distance so that little or no damage is incurred (_kiting_), coordinating the positioning of units to attack from different directions or taking advantage of the terrain to defeat the enemy. + +Learning these rich cooperative behaviours under partial observability is challenging task, which can be used to evaluate the effectiveness of multi-agent reinforcement learning (MARL) algorithms. + +## SMAC + +SMAC uses the [StarCraft II Learning Environment](https://github.com/deepmind/pysc2) to introduce a cooperative MARL environment. + +### Scenarios + +SMAC consists of a set of StarCraft II micro scenarios which aim to evaluate how well independent agents are able to learn coordination to solve complex tasks. +These scenarios are carefully designed to necessitate the learning of one or more micromanagement techniques to defeat the enemy. +Each scenario is a confrontation between two armies of units. +The initial position, number, and type of units in each army varies from scenario to scenario, as does the presence or absence of elevated or impassable terrain. + +The first army is controlled by the learned allied agents. +The second army consists of enemy units controlled by the built-in game AI, which uses carefully handcrafted non-learned heuristics. +At the beginning of each episode, the game AI instructs its units to attack the allied agents using its scripted strategies. +An episode ends when all units of either army have died or when a pre-specified time limit is reached (in which case the game is counted as a defeat for the allied agents). +The goal for each scenario is to maximise the win rate of the learned policies, i.e., the expected ratio of games won to games played. +To speed up learning, the enemy AI units are ordered to attack the agents' spawning point in the beginning of each episode. + +Perhaps the simplest scenarios are _symmetric_ battle scenarios. +The most straightforward of these scenarios are _homogeneous_, i.e., each army is composed of only a single unit type (e.g., Marines). +A winning strategy in this setting would be to focus fire, ideally without overkill. +_Heterogeneous_ symmetric scenarios, in which there is more than a single unit type on each side (e.g., Stalkers and Zealots), are more difficult to solve. +Such challenges are particularly interesting when some of the units are extremely effective against others (this is known as _countering_), for example, by dealing bonus damage to a particular armour type. +In such a setting, allied agents must deduce this property of the game and design an intelligent strategy to protect teammates vulnerable to certain enemy attacks. + +SMAC also includes more challenging scenarios, for example, in which the enemy army outnumbers the allied army by one or more units. In such _asymmetric_ scenarios it is essential to consider the health of enemy units in order to effectively target the desired opponent. + +Lastly, SMAC offers a set of interesting _micro-trick_ challenges that require a higher-level of cooperation and a specific micromanagement trick to defeat the enemy. +An example of a challenge scenario is _2m_vs_1z_ (aka Marine Double Team), where two Terran Marines need to defeat an enemy Zealot. In this setting, the Marines must design a strategy which does not allow the Zealot to reach them, otherwise they will die almost immediately. +Another example is _so_many_banelings_ where 7 allied Zealots face 32 enemy Baneling units. Banelings attack by running against a target and exploding when reaching them, causing damage to a certain area around the target. Hence, if a large number of Banelings attack a handful of Zealots located close to each other, the Zealots will be defeated instantly. The optimal strategy, therefore, is to cooperatively spread out around the map far from each other so that the Banelings' damage is distributed as thinly as possible. +The _corridor_ scenario, in which 6 friendly Zealots face 24 enemy Zerglings, requires agents to make effective use of the terrain features. Specifically, agents should collectively wall off the choke point (the narrow region of the map) to block enemy attacks from different directions. Some of the micro-trick challenges are inspired by [StarCraft Master](http://us.battle.net/sc2/en/blog/4544189/new-blizzard-custom-game-starcraft-master-3-1-2012) challenge missions released by Blizzard. + +The complete list of challenges is presented bellow. The difficulty of the game AI is set to _very difficult_ (7). Our experiments, however, suggest that this setting does significantly impact the unit micromanagement of the built-in heuristics. + +| Name | Ally Units | Enemy Units | Type | +| :---: | :---: | :---: | :---:| +| 3m | 3 Marines | 3 Marines | homogeneous & symmetric | +| 8m | 8 Marines | 8 Marines | homogeneous & symmetric | +| 25m | 25 Marines | 25 Marines | homogeneous & symmetric | +| 2s3z | 2 Stalkers & 3 Zealots | 2 Stalkers & 3 Zealots | heterogeneous & symmetric | +| 3s5z | 3 Stalkers & 5 Zealots | 3 Stalkers & 5 Zealots | heterogeneous & symmetric | +| MMM | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 2 Marauders & 7 Marines | heterogeneous & symmetric | +| 5m_vs_6m | 5 Marines | 6 Marines | homogeneous & asymmetric | +| 8m_vs_9m | 8 Marines | 9 Marines | homogeneous & asymmetric | +| 10m_vs_11m | 10 Marines | 11 Marines | homogeneous & asymmetric | +| 27m_vs_30m | 27 Marines | 30 Marines | homogeneous & asymmetric | +| 3s5z_vs_3s6z | 3 Stalkers & 5 Zealots | 3 Stalkers & 6 Zealots | heterogeneous & asymmetric | +| MMM2 | 1 Medivac, 2 Marauders & 7 Marines | 1 Medivac, 3 Marauders & 8 Marines | heterogeneous & asymmetric | +| 2m_vs_1z | 2 Marines | 1 Zealot | micro-trick: alternating fire | +| 2s_vs_1sc| 2 Stalkers | 1 Spine Crawler | micro-trick: alternating fire | +| 3s_vs_3z | 3 Stalkers | 3 Zealots | micro-trick: kiting | +| 3s_vs_4z | 3 Stalkers | 4 Zealots | micro-trick: kiting | +| 3s_vs_5z | 3 Stalkers | 5 Zealots | micro-trick: kiting | +| 6h_vs_8z | 6 Hydralisks | 8 Zealots | micro-trick: focus fire | +| corridor | 6 Zealots | 24 Zerglings | micro-trick: wall off | +| bane_vs_bane | 20 Zerglings & 4 Banelings | 20 Zerglings & 4 Banelings | micro-trick: positioning | +| so_many_banelings| 7 Zealots | 32 Banelings | micro-trick: positioning | +| 2c_vs_64zg| 2 Colossi | 64 Zerglings | micro-trick: positioning | + +### State and Observations + +At each timestep, agents receive local observations drawn within their field of view. This encompasses information about the map within a circular area around each unit and with a radius equal to the _sight range_. The sight range makes the environment partially observable from the standpoint of each agent. Agents can only observe other agents if they are both alive and located within the sight range. Hence, there is no way for agents to determine whether their teammates are far away or dead. + +The feature vector observed by each agent contains the following attributes for both allied and enemy units within the sight range: _distance_, _relative x_, _relative y_, _health_, _shield_, and _unit\_type_.\footnote{_health_, _shield_ and _unit\_type_ of the unit the agent controls is also included in observations_. Shields serve as an additional source of protection that needs to be removed before any damage can be done to the health of units. +All Protos units have shields, which can regenerate if no new damage is dealt +(units of the other two races do not have this attribute). +In addition, agents have access to the last actions of allied units that are in the field of view. Lastly, agents can observe the terrain features surrounding them; particularly, the values of eight points at a fixed radius indicating height and walkability. + +The global state, which is only available to agents during centralised training, contains information about all units on the map. Specifically, the state vector includes the coordinates of all agents relative to the centre of the map, together with unit features present in the observations. Additionally, the state stores the _energy_ of Medivacs and _cooldown_ of the rest of allied units, which represents the minimum delay between attacks. Finally, the last actions of all agents are attached to the central state. + +All features, both in the state as well as in the observations of individual agents, are normalised by their maximum values. The sight range is set to 9 for all agents. + +### Action Space + +The discrete set of actions which agents are allowed to take consists of _move[direction]_ (four directions: north, south, east, or west._, _attack[enemy_id]_, _stop_ and _no-op_. Dead agents can only take _no-op_ action while live agents cannot. +As healer units, Medivacs must use _heal[agent\_id]_ actions instead of _attack[enemy\_id]_. + +To ensure decentralisation of the task, agents are restricted to use the _attack[enemy\_id]_ action only towards enemies in their _shooting range_. +This additionally constrains the unit's ability to use the built-in _attack-move_ macro-actions on the enemies far away. We set the shooting range equal to 6. Having a bigger sight than shooting range forces agents to make use of move commands before starting to fire. + +### Rewards + +The overall goal is to have the highest win rate for each battle scenario. +We provide a corresponding option for _sparse rewards_, which will cause the environment to return only a reward of +1 for winning and -1 for losing an episode. +However, we also provide a default setting for a shaped reward signal calculated from the hit-point damage dealt and received by agents, some positive (negative) reward after having enemy (allied) units killed and/or a positive (negative) bonus for winning (losing) the battle. +The exact values and scales of this shaped reward can be configured using a range of flags, but we strongly discourage disingenuous engineering of the reward function (e.g. tuning different reward functions for different scenarios). + +### Environment Settings + +SMAC, through the [PySC2](https://github.com/deepmind/pysc2) framework of SC2LE, is able to send commands and receive input from the [StarCraft II API](https://github.com/Blizzard/s2client-proto), which provides full control of the StarCraft II game. SC2LE uses a _feature layer_ interface where observations are abstractions of RGB images that maintain the spatial and graphical concepts of the game. However, SMAC uses the _raw API_ also supported by StarCraft II API. +Raw API observations do not have any graphical component and include information about the units on the map such as health, location coordinates, etc. The raw API also allows sending action commands to individual units using their unit IDs. This setting differs from how humans play the actual game, but is convenient for designing decentralised multi-agent learning tasks. + +Since our micro-scenarios are shorter than actual StarCraft II games, restarting the game after each episode presents a computational bottleneck. To overcome this issue, we make use of the API's debug commands. Specifically, when all units of either army have been killed, we kill all remaining units by sending a debug action. Having no units left launches a trigger programmed with the StarCraft II Editor that re-spawns all units in their original location with full health, thereby restarting the scenario quickly and efficiently. + +Furthermore, to encourage agents to explore interesting micro-strategies themselves, we limit the influence of the StarCraft AI on our agents. Specifically we disable the automatic unit attack against enemies that attack agents or that are located nearby. +To do so, we make use of new units created with the StarCraft II Editor that are exact copies of existing units with two attributes modified: _Combat: Default Acquire Level_ is set to _Passive_ (default _Offensive_) and _Behaviour: Response_ is set to _No Response_ (default _Acquire_). These fields are only modified for allied units; enemy units are unchanged. + +The sight and shooting range values might differ from the built-in _sight_ or _range_ attribute of some StarCraft II units. Our goal is not to master the original full StarCraft game, but rather to benchmark MARL methods for decentralised control. diff --git a/setup.py b/setup.py new file mode 100755 index 00000000..3513aa98 --- /dev/null +++ b/setup.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from setuptools import setup + +description = """SMAC - StarCraft Multi-Agent Challenge + +SMAC offers a diverse set of decentralised micromanagement challenges based on +StarCraft II game. In these challenges, each of the units is controlled by an +independent, learning agent that has to act based only on local observations, +while the opponent's units are controlled by the built-in StarCraft II AI. + +The accompanying paper which outlines the motivation for using SMAC as well as +results using the state-of-the-art deep multi-agent reinforcement learning +algorithms can be found at https://www.arxiv.link + +Read the README at https://github.com/oxwhirl/smac for more information. +""" + +setup( + name='SMAC', + version='0.1.0b1', + description='SMAC - StarCraft Multi-Agent Challenge.', + long_description=description, + author='WhiRL', + author_email='mikayel@samvelyan.com', + license='Apache License, Version 2.0', + keywords='StarCraft AI, Multi-Agent Reinforcement Learning', + url='https://github.com/oxwhirl/smac', + packages=[ + 'smac', + 'smac.env', + 'smac.env.starcraft2', + 'smac.env.starcraft2.maps', + 'smac.bin', + 'smac.examples' + ], + install_requires=[ + 's2clientprotocol>=4.6.0.67926.0', + 'absl-py>=0.1.0', + 'numpy>=1.10', + 'pysc2 @ git+https://github.com/deepmind/pysc2@a9f093493c4c77adb385602790a480e7f238b97d', + ], +) diff --git a/smac/__init__.py b/smac/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/smac/bin/__init__.py b/smac/bin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/smac/bin/map_list.py b/smac/bin/map_list.py new file mode 100644 index 00000000..ec1be8df --- /dev/null +++ b/smac/bin/map_list.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from smac.env.starcraft2.maps import smac_maps + +from pysc2 import maps as pysc2_maps + + +def main(): + smac_map_registry = smac_maps.get_smac_map_registry() + all_maps = pysc2_maps.get_maps() + print("{:<15} {:7} {:7} {:7}".format("Name", "Agents", "Enemies", "Limit")) + for map_name, map_params in smac_map_registry.items(): + map_class = all_maps[map_name] + if map_class.path: + print( + "{:<15} {:<7} {:<7} {:<7}".format( + map_name, + map_params["n_agents"], + map_params["n_enemies"], + map_params["limit"], + ) + ) + + +if __name__ == "__main__": + main() diff --git a/smac/env/__init__.py b/smac/env/__init__.py new file mode 100644 index 00000000..44697deb --- /dev/null +++ b/smac/env/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from smac.env.multiagentenv import MultiAgentEnv +from smac.env.starcraft2.starcraft2 import StarCraft2Env + +__all__ = ["MultiAgentEnv", "StarCraft2Env"] diff --git a/smac/env/multiagentenv.py b/smac/env/multiagentenv.py new file mode 100644 index 00000000..049ea614 --- /dev/null +++ b/smac/env/multiagentenv.py @@ -0,0 +1,67 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class MultiAgentEnv(object): + + def step(self, actions): + """Returns reward, terminated, info.""" + raise NotImplementedError + + def get_obs(self): + """Returns all agent observations in a list.""" + raise NotImplementedError + + def get_obs_agent(self, agent_id): + """Returns observation for agent_id.""" + raise NotImplementedError + + def get_obs_size(self): + """Returns the size of the observation.""" + raise NotImplementedError + + def get_state(self): + """Returns the global state.""" + raise NotImplementedError + + def get_state_size(self): + """Returns the size of the global state.""" + raise NotImplementedError + + def get_avail_actions(self): + """Returns the available actions of all agents in a list.""" + raise NotImplementedError + + def get_avail_agent_actions(self, agent_id): + """Returns the available actions for agent_id.""" + raise NotImplementedError + + def get_total_actions(self): + """Returns the total number of actions an agent could ever take.""" + raise NotImplementedError + + def reset(self): + """Returns initial observations and states.""" + raise NotImplementedError + + def render(self): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + def seed(self): + raise NotImplementedError + + def save_replay(self): + """Save a replay.""" + raise NotImplementedError + + def get_env_info(self): + env_info = {"state_shape": self.get_state_size(), + "obs_shape": self.get_obs_size(), + "n_actions": self.get_total_actions(), + "n_agents": self.n_agents, + "episode_limit": self.episode_limit} + return env_info diff --git a/smac/env/starcraft2/__init__.py b/smac/env/starcraft2/__init__.py new file mode 100644 index 00000000..94e651ec --- /dev/null +++ b/smac/env/starcraft2/__init__.py @@ -0,0 +1,7 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags +FLAGS = flags.FLAGS +FLAGS(['main.py']) diff --git a/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map new file mode 100755 index 00000000..1dc2286d Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map new file mode 100755 index 00000000..fcfdeb09 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/25m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map new file mode 100755 index 00000000..861c7f70 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map new file mode 100755 index 00000000..b740b6c3 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map new file mode 100755 index 00000000..f4c05c40 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map new file mode 100755 index 00000000..59846ccf Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/2s3z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map new file mode 100755 index 00000000..c03328db Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map new file mode 100755 index 00000000..b35ec100 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map new file mode 100755 index 00000000..e5a4313a Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3s5z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map new file mode 100755 index 00000000..3927ca4f Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map new file mode 100755 index 00000000..4de7cf80 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map new file mode 100755 index 00000000..8db2dfc6 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map new file mode 100755 index 00000000..70c99d29 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map new file mode 100755 index 00000000..f2ae42c2 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map new file mode 100755 index 00000000..df01eb64 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map new file mode 100755 index 00000000..6593c72f Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/8m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map new file mode 100755 index 00000000..5b8815f6 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map new file mode 100755 index 00000000..ed26fe44 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/MMM.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map new file mode 100755 index 00000000..ab25a02b Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/MMM2.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map new file mode 100755 index 00000000..bb81284c Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map new file mode 100755 index 00000000..90daed60 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/corridor.SC2Map differ diff --git a/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map b/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map new file mode 100755 index 00000000..6a184e35 Binary files /dev/null and b/smac/env/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map differ diff --git a/smac/env/starcraft2/maps/__init__.py b/smac/env/starcraft2/maps/__init__.py new file mode 100644 index 00000000..78acbe9d --- /dev/null +++ b/smac/env/starcraft2/maps/__init__.py @@ -0,0 +1,10 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from smac.env.starcraft2.maps import smac_maps + + +def get_map_params(map_name): + map_param_registry = smac_maps.get_smac_map_registry() + return map_param_registry[map_name] diff --git a/smac/env/starcraft2/maps/smac_maps.py b/smac/env/starcraft2/maps/smac_maps.py new file mode 100644 index 00000000..1e9e033e --- /dev/null +++ b/smac/env/starcraft2/maps/smac_maps.py @@ -0,0 +1,223 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from pysc2.maps import lib + + +class SMACMap(lib.Map): + directory = "SMAC_Maps" + download = "https://github.com/oxwhirl/smac#smac-maps" + players = 2 + step_mul = 8 + game_steps_per_episode = 0 + + +map_param_registry = { + "3m": { + "n_agents": 3, + "n_enemies": 3, + "limit": 60, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "8m": { + "n_agents": 8, + "n_enemies": 8, + "limit": 120, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "25m": { + "n_agents": 25, + "n_enemies": 25, + "limit": 150, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "5m_vs_6m": { + "n_agents": 5, + "n_enemies": 6, + "limit": 70, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "8m_vs_9m": { + "n_agents": 8, + "n_enemies": 9, + "limit": 120, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "10m_vs_11m": { + "n_agents": 10, + "n_enemies": 11, + "limit": 150, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "27m_vs_30m": { + "n_agents": 27, + "n_enemies": 30, + "limit": 180, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 0, + "map_type": "marines", + }, + "MMM": { + "n_agents": 10, + "n_enemies": 10, + "limit": 150, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 3, + "map_type": "MMM", + }, + "MMM2": { + "n_agents": 10, + "n_enemies": 12, + "limit": 180, + "a_race": "T", + "b_race": "T", + "unit_type_bits": 3, + "map_type": "MMM", + }, + "2s3z": { + "n_agents": 5, + "n_enemies": 5, + "limit": 120, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 2, + "map_type": "stalkers_and_zealots", + }, + "3s5z": { + "n_agents": 8, + "n_enemies": 8, + "limit": 150, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 2, + "map_type": "stalkers_and_zealots", + }, + "3s5z_vs_3s6z": { + "n_agents": 8, + "n_enemies": 9, + "limit": 170, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 2, + "map_type": "stalkers_and_zealots", + }, + "3s_vs_3z": { + "n_agents": 3, + "n_enemies": 3, + "limit": 150, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 0, + "map_type": "stalkers", + }, + "3s_vs_4z": { + "n_agents": 3, + "n_enemies": 4, + "limit": 200, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 0, + "map_type": "stalkers", + }, + "3s_vs_5z": { + "n_agents": 3, + "n_enemies": 5, + "limit": 250, + "a_race": "P", + "b_race": "P", + "unit_type_bits": 0, + "map_type": "stalkers", + }, + "2m_vs_1z": { + "n_agents": 2, + "n_enemies": 1, + "limit": 150, + "a_race": "T", + "b_race": "P", + "unit_type_bits": 0, + "map_type": "marines", + }, + "corridor": { + "n_agents": 6, + "n_enemies": 24, + "limit": 400, + "a_race": "P", + "b_race": "Z", + "unit_type_bits": 0, + "map_type": "zealots", + }, + "6h_vs_8z": { + "n_agents": 6, + "n_enemies": 8, + "limit": 150, + "a_race": "Z", + "b_race": "P", + "unit_type_bits": 0, + "map_type": "hydralisks", + }, + "2s_vs_1sc": { + "n_agents": 2, + "n_enemies": 1, + "limit": 300, + "a_race": "P", + "b_race": "Z", + "unit_type_bits": 0, + "map_type": "stalkers", + }, + "so_many_baneling": { + "n_agents": 7, + "n_enemies": 32, + "limit": 100, + "a_race": "P", + "b_race": "Z", + "unit_type_bits": 0, + "map_type": "zealots", + }, + "bane_vs_bane": { + "n_agents": 24, + "n_enemies": 24, + "limit": 200, + "a_race": "Z", + "b_race": "Z", + "unit_type_bits": 2, + "map_type": "bane", + }, + "2c_vs_64zg": { + "n_agents": 2, + "n_enemies": 64, + "limit": 400, + "a_race": "P", + "b_race": "Z", + "unit_type_bits": 0, + "map_type": "colossus", + }, +} + + +def get_smac_map_registry(): + return map_param_registry + + +for name in map_param_registry.keys(): + globals()[name] = type(name, (SMACMap,), dict(filename=name)) diff --git a/smac/env/starcraft2/starcraft2.py b/smac/env/starcraft2/starcraft2.py new file mode 100644 index 00000000..cc32673f --- /dev/null +++ b/smac/env/starcraft2/starcraft2.py @@ -0,0 +1,1266 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from smac.env.multiagentenv import MultiAgentEnv +from smac.env.starcraft2.maps import get_map_params + +from operator import attrgetter +from copy import deepcopy +import numpy as np +import enum +import math +from absl import logging + +from pysc2 import maps +from pysc2 import run_configs +from pysc2.lib import protocol + +from s2clientprotocol import common_pb2 as sc_common +from s2clientprotocol import sc2api_pb2 as sc_pb +from s2clientprotocol import raw_pb2 as r_pb +from s2clientprotocol import debug_pb2 as d_pb + +races = { + "R": sc_common.Random, + "P": sc_common.Protoss, + "T": sc_common.Terran, + "Z": sc_common.Zerg, +} + +difficulties = { + "1": sc_pb.VeryEasy, + "2": sc_pb.Easy, + "3": sc_pb.Medium, + "4": sc_pb.MediumHard, + "5": sc_pb.Hard, + "6": sc_pb.Harder, + "7": sc_pb.VeryHard, + "8": sc_pb.CheatVision, + "9": sc_pb.CheatMoney, + "A": sc_pb.CheatInsane, +} + +actions = { + "move": 16, # target: PointOrUnit + "attack": 23, # target: PointOrUnit + "stop": 4, # target: None + "heal": 386, # Unit +} + + +class Direction(enum.IntEnum): + NORTH = 0 + SOUTH = 1 + EAST = 2 + WEST = 3 + + +class StarCraft2Env(MultiAgentEnv): + """The StarCraft II environment for decentralised multi-agent + micromanagement scenarios. + """ + def __init__( + self, + map_name="8m", + step_mul=None, + move_amount=2, + difficulty="7", + game_version=None, + seed=None, + continuing_episode=False, + obs_all_health=True, + obs_own_health=True, + obs_last_action=False, + obs_pathing_grid=False, + obs_terrain_height=False, + obs_instead_of_state=False, + state_last_action=True, + reward_sparse=False, + reward_only_positive=True, + reward_death_value=10, + reward_win=200, + reward_defeat=0, + reward_negative_scale=0.5, + reward_scale=True, + reward_scale_rate=20, + replay_dir="", + replay_prefix="", + window_size_x=1920, + window_size_y=1200, + debug=False, + ): + """ + Create a StarCraftC2Env environment. + + Parameters + ---------- + map_name : str, optional + The name of the SC2 map to play (default is "8m"). The full list + can be found by running bin/map_list. + step_mul : int, optional + How many game steps per agent step (default is None). None + indicates to use the default map step_mul. + move_amount : float, optional + How far away units are ordered to move per step (default is 2). + difficulty : str, optional + The difficulty of built-in computer AI bot (default is "7"). + game_version : str, optional + StarCraft II game version (default is None). None indicates the + latest version. + seed : int, optional + Random seed used during game initialisation. This allows to + continuing_episode : bool, optional + Whether to consider episodes continuing or finished after time + limit is reached (default is False). + obs_all_health : bool, optional + Agents receive the health of all units (in the sight range) as part + of observations (default is True). + obs_own_health : bool, optional + Agents receive their own health as a part of observations (default + is False). This flag is ignored when obs_all_health == True. + obs_last_action : bool, optional + Agents receive the last actions of all units (in the sight range) + as part of observations (default is False). + obs_pathing_grid : bool, optional + Whether observations include pathing values surrounding the agent + (default is False). + obs_terrain_height : bool, optional + Whether observations include terrain height values surrounding the + agent (default is False). + obs_instead_of_state : bool, optional + Use combination of all agents' observations as the global state + (default is False). + state_last_action : bool, optional + Include the last actions of all agents as part of the global state + (default is True). + reward_sparse : bool, optional + Receive 1/-1 reward for winning/loosing an episode (default is + False). Whe rest of reward parameters are ignored if True. + reward_only_positive : bool, optional + Reward is always positive (default is True). + reward_death_value : float, optional + The amount of reward received for killing an enemy unit (default + is 10). This is also the negative penalty for having an allied unit + killed if reward_only_positive == False. + reward_win : float, optional + The reward for winning in an episode (default is 200). + reward_defeat : float, optional + The reward for loosing in an episode (default is 0). This value + should be nonpositive. + reward_negative_scale : float, optional + Scaling factor for negative rewards (default is 0.5). This + parameter is ignored when reward_only_positive == True. + reward_scale : bool, optional + Whether or not to scale the reward (default is True). + reward_scale_rate : float, optional + Reward scale rate (default is 20). When reward_scale == True, the + reward received by the agents is divided by (max_reward / + reward_scale_rate), where max_reward is the maximum possible + reward per episode without considering the shield regeneration + of Protoss units. + replay_dir : str, optional + The directory to save replays (default is None). If None, the + replay will be saved in Replays directory where StarCraft II is + installed. + replay_prefix : str, optional + The prefix of the replay to be saved (default is None). If None, + the name of the map will be used. + window_size_x : int, optional + The length of StarCraft II window size (default is 1920). + window_size_y: int, optional + The height of StarCraft II window size (default is 1200). + debug: bool, optional + Log messages about observations, state, actions and rewards for + debugging purposes (default is False). + """ + # Map arguments + self.map_name = map_name + map_params = get_map_params(self.map_name) + self.n_agents = map_params["n_agents"] + self.n_enemies = map_params["n_enemies"] + self.episode_limit = map_params["limit"] + self._move_amount = move_amount + self._step_mul = step_mul + self.difficulty = difficulty + + # Observations and state + self.obs_own_health = obs_own_health + self.obs_all_health = obs_all_health + self.obs_instead_of_state = obs_instead_of_state + self.obs_last_action = obs_last_action + self.obs_pathing_grid = obs_pathing_grid + self.obs_terrain_height = obs_terrain_height + self.state_last_action = state_last_action + if self.obs_all_health: + self.obs_own_health = True + self.n_obs_pathing = 8 + self.n_obs_height = 9 + + # Rewards args + self.reward_sparse = reward_sparse + self.reward_only_positive = reward_only_positive + self.reward_negative_scale = reward_negative_scale + self.reward_death_value = reward_death_value + self.reward_win = reward_win + self.reward_defeat = reward_defeat + self.reward_scale = reward_scale + self.reward_scale_rate = reward_scale_rate + + # Other + self.game_version = game_version + self.continuing_episode = continuing_episode + self.seed = seed + self.debug = debug + self.window_size = (window_size_x, window_size_y) + self.replay_dir = replay_dir + self.replay_prefix = replay_prefix + + # Actions + self.n_actions_no_attack = 6 + self.n_actions_move = 4 + self.n_actions = self.n_actions_no_attack + self.n_enemies + + # Map info + self._agent_race = map_params["a_race"] + self._bot_race = map_params["b_race"] + self.shield_bits_ally = 1 if self._agent_race == "P" else 0 + self.shield_bits_enemy = 1 if self._bot_race == "P" else 0 + self.unit_type_bits = map_params["unit_type_bits"] + self.map_type = map_params["map_type"] + + self._launch() + + self.max_reward = ( + self.n_enemies * self.reward_death_value + self.reward_win + ) + self._game_info = self.controller.game_info() + self._map_info = self._game_info.start_raw + self.map_x = self._map_info.map_size.x + self.map_y = self._map_info.map_size.y + self.map_play_area_min = self._map_info.playable_area.p0 + self.map_play_area_max = self._map_info.playable_area.p1 + self.max_distance_x = ( + self.map_play_area_max.x - self.map_play_area_min.x + ) + self.max_distance_y = ( + self.map_play_area_max.y - self.map_play_area_min.y + ) + self.terrain_height = np.flip( + np.transpose(np.array(list(self._map_info.terrain_height.data)) + .reshape(self.map_x, self.map_y)), 1) / 255 + self.pathing_grid = np.flip( + np.transpose(np.array(list(self._map_info.pathing_grid.data)) + .reshape(self.map_x, self.map_y)), 1) / 255 + + self.agents = {} + self.enemies = {} + self._episode_count = 0 + self._episode_steps = 0 + self._total_steps = 0 + self._obs = None + self.battles_won = 0 + self.battles_game = 0 + self.timeouts = 0 + self.force_restarts = 0 + self.last_stats = None + self.death_tracker_ally = np.zeros(self.n_agents) + self.death_tracker_enemy = np.zeros(self.n_enemies) + self.previous_ally_units = None + self.previous_enemy_units = None + self.last_action = np.zeros((self.n_agents, self.n_actions)) + self._min_unit_type = 0 + self.marine_id = self.marauder_id = self.medivac_id = 0 + self.hydralisk_id = self.zergling_id = self.baneling_id = 0 + self.stalker_id = self.colossus_id = self.zealot_id = 0 + + def _launch(self): + """Launch the StarCraft II game.""" + self._run_config = run_configs.get() + _map = maps.get(self.map_name) + + # Setting up the interface + interface_options = sc_pb.InterfaceOptions(raw=True, score=False) + + self._sc2_proc = self._run_config.start(game_version=self.game_version, + window_size=self.window_size) + self.controller = self._sc2_proc.controller + + # Request to create the game + create = sc_pb.RequestCreateGame( + local_map=sc_pb.LocalMap( + map_path=_map.path, + map_data=self._run_config.map_data(_map.path)), + realtime=False, + random_seed=self.seed) + create.player_setup.add(type=sc_pb.Participant) + create.player_setup.add(type=sc_pb.Computer, race=races[self._bot_race], + difficulty=difficulties[self.difficulty]) + self.controller.create_game(create) + + join = sc_pb.RequestJoinGame(race=races[self._agent_race], + options=interface_options) + self.controller.join_game(join) + + def reset(self): + """Reset the environment. Required after each full episode. + Returns initial observations and states. + """ + self._episode_steps = 0 + if self._episode_count > 0: + # No need to restart for the first episode. + self._restart() + + self._episode_count += 1 + + # Information kept for counting the reward + self.death_tracker_ally = np.zeros(self.n_agents) + self.death_tracker_enemy = np.zeros(self.n_enemies) + self.previous_ally_units = None + self.previous_enemy_units = None + + self.last_action = np.zeros((self.n_agents, self.n_actions)) + + try: + self._obs = self.controller.observe() + self.init_units() + except (protocol.ProtocolError, protocol.ConnectionError): + self.full_restart() + + if self.debug: + logging.debug("Started Episode {}" + .format(self._episode_count).center(60, "*")) + + return self.get_obs(), self.get_state() + + def _restart(self): + """Restart the environment by killing all units on the map. + There is a trigger in the SC2Map file, which restarts the + episode when there are no units left. + """ + try: + self.kill_all_units() + self.controller.step(2) + except (protocol.ProtocolError, protocol.ConnectionError): + self.full_restart() + + def full_restart(self): + """Full restart. Closes the SC2 process and launches a new one. """ + self._sc2_proc.close() + self._launch() + self.force_restarts += 1 + + def step(self, actions): + """A single environment step. Returns reward, terminated, info.""" + actions = [int(a) for a in actions] + + self.last_action = np.eye(self.n_actions)[np.array(actions)] + + # Collect individual actions + sc_actions = [] + if self.debug: + logging.debug("Actions".center(60, "-")) + + for a_id, action in enumerate(actions): + agent_action = self.get_agent_action(a_id, action) + if agent_action: + sc_actions.append(agent_action) + + # Send action request + req_actions = sc_pb.RequestAction(actions=sc_actions) + try: + self.controller.actions(req_actions) + # Make step in SC2, i.e. apply actions + self.controller.step(self._step_mul) + # Observe here so that we know if the episode is over. + self._obs = self.controller.observe() + except (protocol.ProtocolError, protocol.ConnectionError): + self.full_restart() + return 0, True, {} + + self._total_steps += 1 + self._episode_steps += 1 + + # Update units + game_end_code = self.update_units() + + terminated = False + reward = self.reward_battle() + info = {"battle_won": False} + + if game_end_code is not None: + # Battle is over + terminated = True + self.battles_game += 1 + if game_end_code == 1: + self.battles_won += 1 + info["battle_won"] = True + if not self.reward_sparse: + reward += self.reward_win + else: + reward = 1 + elif game_end_code == -1: + if not self.reward_sparse: + reward += self.reward_defeat + else: + reward = -1 + + elif self._episode_steps >= self.episode_limit: + # Episode limit reached + terminated = True + if self.continuing_episode: + info["episode_limit"] = True + self.battles_game += 1 + self.timeouts += 1 + + if self.debug: + logging.debug("Reward = {}".format(reward).center(60, '-')) + + if self.reward_scale: + reward /= self.max_reward / self.reward_scale_rate + + return reward, terminated, info + + def get_agent_action(self, a_id, action): + """Construct the action for agent a_id.""" + avail_actions = self.get_avail_agent_actions(a_id) + assert avail_actions[action] == 1, \ + "Agent {} cannot perform action {}".format(a_id, action) + + unit = self.get_unit_by_id(a_id) + tag = unit.tag + x = unit.pos.x + y = unit.pos.y + + if action == 0: + # no-op (valid only when dead) + assert unit.health == 0, "No-op only available for dead agents." + if self.debug: + logging.debug("Agent {}: Dead".format(a_id)) + return None + elif action == 1: + # stop + cmd = r_pb.ActionRawUnitCommand( + ability_id=actions["stop"], + unit_tags=[tag], + queue_command=False) + if self.debug: + logging.debug("Agent {}: Stop".format(a_id)) + + elif action == 2: + # move north + cmd = r_pb.ActionRawUnitCommand( + ability_id=actions["move"], + target_world_space_pos=sc_common.Point2D( + x=x, y=y + self._move_amount), + unit_tags=[tag], + queue_command=False) + if self.debug: + logging.debug("Agent {}: Move North".format(a_id)) + + elif action == 3: + # move south + cmd = r_pb.ActionRawUnitCommand( + ability_id=actions["move"], + target_world_space_pos=sc_common.Point2D( + x=x, y=y - self._move_amount), + unit_tags=[tag], + queue_command=False) + if self.debug: + logging.debug("Agent {}: Move South".format(a_id)) + + elif action == 4: + # move east + cmd = r_pb.ActionRawUnitCommand( + ability_id=actions["move"], + target_world_space_pos=sc_common.Point2D( + x=x + self._move_amount, y=y), + unit_tags=[tag], + queue_command=False) + if self.debug: + logging.debug("Agent {}: Move East".format(a_id)) + + elif action == 5: + # move west + cmd = r_pb.ActionRawUnitCommand( + ability_id=actions["move"], + target_world_space_pos=sc_common.Point2D( + x=x - self._move_amount, y=y), + unit_tags=[tag], + queue_command=False) + if self.debug: + logging.debug("Agent {}: Move West".format(a_id)) + else: + # attack/heal units that are in range + target_id = action - self.n_actions_no_attack + if self.map_type == "MMM" and unit.unit_type == self.medivac_id: + target_unit = self.agents[target_id] + action_name = "heal" + else: + target_unit = self.enemies[target_id] + action_name = "attack" + + action_id = actions[action_name] + target_tag = target_unit.tag + + cmd = r_pb.ActionRawUnitCommand( + ability_id=action_id, + target_unit_tag=target_tag, + unit_tags=[tag], + queue_command=False) + + if self.debug: + logging.debug("Agent {} {}s unit # {}".format( + a_id, action_name, target_id)) + + sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)) + return sc_action + + def reward_battle(self): + """Reward function when self.reward_spare==False. + Returns accumulative hit/shield point damage dealt to the enemy + + reward_death_value per enemy unit killed, and, in case + self.reward_only_positive == False, - (damage dealt to ally units + + reward_death_value per ally unit killed) * self.reward_negative_scale + """ + if self.reward_sparse: + return 0 + + reward = 0 + delta_deaths = 0 + delta_ally = 0 + delta_enemy = 0 + + neg_scale = self.reward_negative_scale + + # update deaths + for al_id, al_unit in self.agents.items(): + if not self.death_tracker_ally[al_id]: + # did not die so far + prev_health = ( + self.previous_ally_units[al_id].health + + self.previous_ally_units[al_id].shield + ) + if al_unit.health == 0: + # just died + self.death_tracker_ally[al_id] = 1 + if not self.reward_only_positive: + delta_deaths -= self.reward_death_value * neg_scale + delta_ally += prev_health * neg_scale + else: + # still alive + delta_ally += neg_scale * ( + prev_health - al_unit.health - al_unit.shield + ) + + for e_id, e_unit in self.enemies.items(): + if not self.death_tracker_enemy[e_id]: + prev_health = ( + self.previous_enemy_units[e_id].health + + self.previous_enemy_units[e_id].shield + ) + if e_unit.health == 0: + self.death_tracker_enemy[e_id] = 1 + delta_deaths += self.reward_death_value + delta_enemy += prev_health + else: + delta_enemy += prev_health - e_unit.health - e_unit.shield + + if self.reward_only_positive: + reward = abs(delta_enemy + delta_deaths) # shield regeneration + else: + reward = delta_enemy + delta_deaths - delta_ally + + return reward + + def get_total_actions(self): + """Returns the total number of actions an agent could ever take.""" + return self.n_actions + + @staticmethod + def distance(x1, y1, x2, y2): + """Distance between two points.""" + return math.hypot(x2 - x1, y2 - y1) + + def unit_shoot_range(self, agent_id): + """Returns the shooting range for an agent.""" + return 6 + + def unit_sight_range(self, agent_id): + """Returns the sight range for an agent.""" + return 9 + + def unit_max_cooldown(self, unit): + """Returns the maximal cooldown for a unit.""" + switcher = { + self.marine_id: 15, + self.marauder_id: 25, + self.medivac_id: 200, # max energy + self.stalker_id: 35, + self.zealot_id: 22, + self.colossus_id: 24, + self.hydralisk_id: 10, + self.zergling_id: 11, + self.baneling_id: 1 + } + return switcher.get(unit.unit_type, 15) + + def save_replay(self): + """Save a replay.""" + prefix = self.replay_prefix or self.map_name + replay_dir = self.replay_dir or "" + replay_path = self._run_config.save_replay( + self.controller.save_replay(), replay_dir=replay_dir, prefix=prefix) + logging.info("Replay saved at: %s" % replay_path) + + def unit_max_shield(self, unit): + """Returns maximal shield for a given unit.""" + if unit.unit_type == 74 or unit.unit_type == self.stalker_id: + return 80 # Protoss's Stalker + if unit.unit_type == 73 or unit.unit_type == self.zealot_id: + return 50 # Protoss's Zaelot + if unit.unit_type == 4 or unit.unit_type == self.colossus_id: + return 150 # Protoss's Colossus + + def can_move(self, unit, direction): + """Whether a unit can move in a given direction.""" + m = self._move_amount / 2 + + if direction == Direction.NORTH: + x, y = int(unit.pos.x), int(unit.pos.y + m) + elif direction == Direction.SOUTH: + x, y = int(unit.pos.x), int(unit.pos.y - m) + elif direction == Direction.EAST: + x, y = int(unit.pos.x + m), int(unit.pos.y) + else: + x, y = int(unit.pos.x - m), int(unit.pos.y) + + if self.check_bounds(x, y) and self.pathing_grid[x, y] == 0: + return True + + return False + + def get_surrounding_points(self, unit, include_self=False): + """Returns the surrounding points of the unit in 8 directions.""" + x = int(unit.pos.x) + y = int(unit.pos.y) + + ma = self._move_amount + + points = [ + (x, y + 2 * ma), + (x, y - 2 * ma), + (x + 2 * ma, y), + (x - 2 * ma, y), + (x + ma, y + ma), + (x - ma, y - ma), + (x + ma, y - ma), + (x - ma, y + ma), + ] + + if include_self: + points.append((x, y)) + + return points + + def check_bounds(self, x, y): + """Whether a point is within the map bounds.""" + return (0 <= x < self.map_x and 0 <= y < self.map_y) + + def get_surrounding_pathing(self, unit): + """Returns pathing values of the grid surrounding the given unit.""" + points = self.get_surrounding_points(unit, include_self=False) + vals = [ + self.pathing_grid[x, y] if self.check_bounds(x, y) else 1 + for x, y in points + ] + return vals + + def get_surrounding_height(self, unit): + """Returns height values of the grid surrounding the given unit.""" + points = self.get_surrounding_points(unit, include_self=True) + vals = [ + self.terrain_height[x, y] if self.check_bounds(x, y) else 1 + for x, y in points + ] + return vals + + def get_obs_agent(self, agent_id): + """Returns observation for agent_id. + NOTE: Agents should have access only to their local observations + during decentralised execution. + """ + unit = self.get_unit_by_id(agent_id) + + nf_al = 4 + self.unit_type_bits + nf_en = 4 + self.unit_type_bits + + if self.obs_all_health: + nf_al += 1 + self.shield_bits_ally + nf_en += 1 + self.shield_bits_enemy + + if self.obs_last_action: + nf_al += self.n_actions + + nf_own = self.unit_type_bits + if self.obs_own_health: + nf_own += 1 + self.shield_bits_ally + + move_feats_len = self.n_actions_move + if self.obs_pathing_grid: + move_feats_len += self.n_obs_pathing + if self.obs_terrain_height: + move_feats_len += self.n_obs_height + + move_feats = np.zeros(move_feats_len, dtype=np.float32) + enemy_feats = np.zeros((self.n_enemies, nf_en), dtype=np.float32) + ally_feats = np.zeros((self.n_agents - 1, nf_al), dtype=np.float32) + own_feats = np.zeros(nf_own, dtype=np.float32) + + if unit.health > 0: # otherwise dead, return all zeros + x = unit.pos.x + y = unit.pos.y + sight_range = self.unit_sight_range(agent_id) + + # Movement features + avail_actions = self.get_avail_agent_actions(agent_id) + for m in range(self.n_actions_move): + move_feats[m] = avail_actions[m + 2] + + ind = self.n_actions_move + + if self.obs_pathing_grid: + move_feats[ + ind : ind + self.n_obs_pathing + ] = self.get_surrounding_pathing(unit) + ind += self.n_obs_pathing + + if self.obs_terrain_height: + move_feats[ind:] = self.get_surrounding_height(unit) + + # Enemy features + for e_id, e_unit in self.enemies.items(): + e_x = e_unit.pos.x + e_y = e_unit.pos.y + dist = self.distance(x, y, e_x, e_y) + + if ( + dist < sight_range and e_unit.health > 0 + ): # visible and alive + # Sight range > shoot range + enemy_feats[e_id, 0] = avail_actions[ + self.n_actions_no_attack + e_id + ] # available + enemy_feats[e_id, 1] = dist / sight_range # distance + enemy_feats[e_id, 2] = ( + e_x - x + ) / sight_range # relative X + enemy_feats[e_id, 3] = ( + e_y - y + ) / sight_range # relative Y + + ind = 4 + if self.obs_all_health: + enemy_feats[e_id, ind] = ( + e_unit.health / e_unit.health_max + ) # health + ind += 1 + if self.shield_bits_enemy > 0: + max_shield = self.unit_max_shield(e_unit) + enemy_feats[e_id, ind] = ( + e_unit.shield / max_shield + ) # shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(e_unit, False) + enemy_feats[e_id, ind + type_id] = 1 # unit type + + # Ally features + al_ids = [ + al_id for al_id in range(self.n_agents) if al_id != agent_id + ] + for i, al_id in enumerate(al_ids): + + al_unit = self.get_unit_by_id(al_id) + al_x = al_unit.pos.x + al_y = al_unit.pos.y + dist = self.distance(x, y, al_x, al_y) + + if ( + dist < sight_range and al_unit.health > 0 + ): # visible and alive + ally_feats[i, 0] = 1 # visible + ally_feats[i, 1] = dist / sight_range # distance + ally_feats[i, 2] = (al_x - x) / sight_range # relative X + ally_feats[i, 3] = (al_y - y) / sight_range # relative Y + + ind = 4 + if self.obs_all_health: + ally_feats[i, ind] = ( + al_unit.health / al_unit.health_max + ) # health + ind += 1 + if self.shield_bits_ally > 0: + max_shield = self.unit_max_shield(al_unit) + ally_feats[i, ind] = ( + al_unit.shield / max_shield + ) # shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(al_unit, True) + ally_feats[i, ind + type_id] = 1 + ind += self.unit_type_bits + + if self.obs_last_action: + ally_feats[i, ind:] = self.last_action[al_id] + + # Own features + ind = 0 + if self.obs_own_health: + own_feats[ind] = unit.health / unit.health_max + ind += 1 + if self.shield_bits_ally > 0: + max_shield = self.unit_max_shield(unit) + own_feats[ind] = unit.shield / max_shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(unit, True) + own_feats[ind + type_id] = 1 + + agent_obs = np.concatenate( + ( + move_feats.flatten(), + enemy_feats.flatten(), + ally_feats.flatten(), + own_feats.flatten(), + ) + ) + + if self.debug: + logging.debug("Obs Agent: {}".format(agent_id).center(60, "-")) + logging.debug("Avail. actions {}".format( + self.get_avail_agent_actions(agent_id))) + logging.debug("Move feats {}".format(move_feats)) + logging.debug("Enemy feats {}".format(enemy_feats)) + logging.debug("Ally feats {}".format(ally_feats)) + logging.debug("Own feats {}".format(own_feats)) + + return agent_obs + + def get_obs(self): + """Returns all agent observations in a list. + NOTE: Agents should have access only to their local observations + during decentralised execution. + """ + agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)] + return agents_obs + + def get_state(self): + """Returns the global state. + NOTE: This functon should not be used during decentralised execution. + """ + if self.obs_instead_of_state: + obs_concat = np.concatenate(self.get_obs(), axis=0).astype( + np.float32 + ) + return obs_concat + + nf_al = 4 + self.shield_bits_ally + self.unit_type_bits + nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits + + ally_state = np.zeros((self.n_agents, nf_al)) + enemy_state = np.zeros((self.n_enemies, nf_en)) + + center_x = self.map_x / 2 + center_y = self.map_y / 2 + + for al_id, al_unit in self.agents.items(): + if al_unit.health > 0: + x = al_unit.pos.x + y = al_unit.pos.y + max_cd = self.unit_max_cooldown(al_unit) + + ally_state[al_id, 0] = ( + al_unit.health / al_unit.health_max + ) # health + if ( + self.map_type == "MMM" + and al_unit.unit_type == self.medivac_id + ): + ally_state[al_id, 1] = al_unit.energy / max_cd # energy + else: + ally_state[al_id, 1] = ( + al_unit.weapon_cooldown / max_cd + ) # cooldown + ally_state[al_id, 2] = ( + x - center_x + ) / self.max_distance_x # relative X + ally_state[al_id, 3] = ( + y - center_y + ) / self.max_distance_y # relative Y + + ind = 4 + if self.shield_bits_ally > 0: + max_shield = self.unit_max_shield(al_unit) + ally_state[al_id, ind] = ( + al_unit.shield / max_shield + ) # shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(al_unit, True) + ally_state[al_id, ind + type_id] = 1 + + for e_id, e_unit in self.enemies.items(): + if e_unit.health > 0: + x = e_unit.pos.x + y = e_unit.pos.y + + enemy_state[e_id, 0] = ( + e_unit.health / e_unit.health_max + ) # health + enemy_state[e_id, 1] = ( + x - center_x + ) / self.max_distance_x # relative X + enemy_state[e_id, 2] = ( + y - center_y + ) / self.max_distance_y # relative Y + + ind = 3 + if self.shield_bits_enemy > 0: + max_shield = self.unit_max_shield(e_unit) + enemy_state[e_id, ind] = ( + e_unit.shield / max_shield + ) # shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(e_unit, False) + enemy_state[e_id, ind + type_id] = 1 + + state = np.append(ally_state.flatten(), enemy_state.flatten()) + if self.state_last_action: + state = np.append(state, self.last_action.flatten()) + state = state.astype(dtype=np.float32) + + if self.debug: + logging.debug("STATE".center(60, "-")) + logging.debug("Ally state {}".format(ally_state)) + logging.debug("Enemy state {}".format(enemy_state)) + if self.state_last_action: + logging.debug("Last actions {}".format(self.last_action)) + + return state + + def get_obs_size(self): + """Returns the size of the observation.""" + nf_al = 4 + self.unit_type_bits + nf_en = 4 + self.unit_type_bits + + if self.obs_all_health: + nf_al += 1 + self.shield_bits_ally + nf_en += 1 + self.shield_bits_enemy + + own_feats = self.unit_type_bits + if self.obs_own_health: + own_feats += 1 + self.shield_bits_ally + + if self.obs_last_action: + nf_al += self.n_actions + + move_feats = self.n_actions_move + if self.obs_pathing_grid: + move_feats += self.n_obs_pathing + if self.obs_terrain_height: + move_feats += self.n_obs_height + + enemy_feats = self.n_enemies * nf_en + ally_feats = (self.n_agents - 1) * nf_al + + return move_feats + enemy_feats + ally_feats + own_feats + + def get_state_size(self): + """Returns the size of the global state.""" + if self.obs_instead_of_state: + return self.get_obs_size() * self.n_agents + + nf_al = 4 + self.shield_bits_ally + self.unit_type_bits + nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits + + enemy_state = self.n_enemies * nf_en + ally_state = self.n_agents * nf_al + + size = enemy_state + ally_state + + if self.state_last_action: + size += self.n_agents * self.n_actions + + return size + + def get_unit_type_id(self, unit, ally): + """Returns the ID of unit type in the given scenario.""" + if ally: # use new SC2 unit types + type_id = unit.unit_type - self._min_unit_type + else: # use default SC2 unit types + if self.map_type == "stalkers_and_zealots": + # id(Stalker) = 74, id(Zealot) = 73 + type_id = unit.unit_type - 73 + if self.map_type == "bane": + if unit.unit_type == 9: + type_id = 0 + else: + type_id = 1 + elif self.map_type == "MMM": + if unit.unit_type == 51: + type_id = 0 + elif unit.unit_type == 48: + type_id = 1 + else: + type_id = 2 + + return type_id + + def get_avail_agent_actions(self, agent_id): + """Returns the available actions for agent_id.""" + unit = self.get_unit_by_id(agent_id) + if unit.health > 0: + # cannot choose no-op when alive + avail_actions = [0] * self.n_actions + + # stop should be allowed + avail_actions[1] = 1 + + # see if we can move + if self.can_move(unit, Direction.NORTH): + avail_actions[2] = 1 + if self.can_move(unit, Direction.SOUTH): + avail_actions[3] = 1 + if self.can_move(unit, Direction.EAST): + avail_actions[4] = 1 + if self.can_move(unit, Direction.WEST): + avail_actions[5] = 1 + + # Can attack only alive units that are alive in the shooting range + shoot_range = self.unit_shoot_range(agent_id) + + target_items = self.enemies.items() + if self.map_type == "MMM" and unit.unit_type == self.medivac_id: + # Medivacs cannot heal themselves or other flying units + target_items = [ + (t_id, t_unit) + for (t_id, t_unit) in self.agents.items() + if t_unit.unit_type != self.medivac_id + ] + + for t_id, t_unit in target_items: + if t_unit.health > 0: + dist = self.distance( + unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y + ) + if dist <= shoot_range: + avail_actions[t_id + self.n_actions_no_attack] = 1 + + return avail_actions + + else: + # only no-op allowed + return [1] + [0] * (self.n_actions - 1) + + def get_avail_actions(self): + """Returns the available actions of all agents in a list.""" + avail_actions = [] + for agent_id in range(self.n_agents): + avail_agent = self.get_avail_agent_actions(agent_id) + avail_actions.append(avail_agent) + return avail_actions + + def close(self): + """Close StarCraft II.""" + self._sc2_proc.close() + + def seed(self): + """Returns the random seed used by the environment.""" + return self.seed + + def render(self): + """Not implemented.""" + pass + + def kill_all_units(self): + """Kill all units on the map.""" + units_alive = [ + unit.tag for unit in self.agents.values() if unit.health > 0 + ] + [unit.tag for unit in self.enemies.values() if unit.health > 0] + debug_command = [ + d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive)) + ] + self.controller.debug(debug_command) + + def init_units(self): + """Initialise the units.""" + while True: + # Sometimes not all units have yet been created by SC2 + self.agents = {} + self.enemies = {} + + ally_units = [ + unit + for unit in self._obs.observation.raw_data.units + if unit.owner == 1 + ] + ally_units_sorted = sorted( + ally_units, + key=attrgetter("unit_type", "pos.x", "pos.y"), + reverse=False, + ) + + for i in range(len(ally_units_sorted)): + self.agents[i] = ally_units_sorted[i] + if self.debug: + logging.debug( + "Unit {} is {}, x = {}, y = {}".format( + len(self.agents), + self.agents[i].unit_type, + self.agents[i].pos.x, + self.agents[i].pos.y, + ) + ) + + for unit in self._obs.observation.raw_data.units: + if unit.owner == 2: + self.enemies[len(self.enemies)] = unit + if self._episode_count == 1: + self.max_reward += unit.health_max + unit.shield_max + + if self._episode_count == 1: + min_unit_type = min( + unit.unit_type for unit in self.agents.values() + ) + self._init_ally_unit_types(min_unit_type) + + all_agents_created = (len(self.agents) == self.n_agents) + all_enemies_created = (len(self.enemies) == self.n_enemies) + + if all_agents_created and all_enemies_created: # all good + return + + try: + self.controller.step(1) + self._obs = self.controller.observe() + except (protocol.ProtocolError, protocol.ConnectionError): + self.full_restart() + self.reset() + + def update_units(self): + """Update units after an environment step. + This function assumes that self._obs is up-to-date. + """ + n_ally_alive = 0 + n_enemy_alive = 0 + + # Store previous state + self.previous_ally_units = deepcopy(self.agents) + self.previous_enemy_units = deepcopy(self.enemies) + + for al_id, al_unit in self.agents.items(): + updated = False + for unit in self._obs.observation.raw_data.units: + if al_unit.tag == unit.tag: + self.agents[al_id] = unit + updated = True + n_ally_alive += 1 + break + + if not updated: # dead + al_unit.health = 0 + + for e_id, e_unit in self.enemies.items(): + updated = False + for unit in self._obs.observation.raw_data.units: + if e_unit.tag == unit.tag: + self.enemies[e_id] = unit + updated = True + n_enemy_alive += 1 + break + + if not updated: # dead + e_unit.health = 0 + + if (n_ally_alive == 0 and n_enemy_alive > 0 + or self.only_medivac_left(ally=True)): + return -1 # lost + if (n_ally_alive > 0 and n_enemy_alive == 0 + or self.only_medivac_left(ally=False)): + return 1 # won + if n_ally_alive == 0 and n_enemy_alive == 0: + return 0 + + return None + + def _init_ally_unit_types(self, min_unit_type): + """Initialise ally unit types. Should be called once from the + init_units function. + """ + self._min_unit_type = min_unit_type + if self.map_type == "marines": + self.marine_id = min_unit_type + elif self.map_type == "stalkers_and_zealots": + self.stalker_id = min_unit_type + self.zealot_id = min_unit_type + 1 + elif self.map_type == "MMM": + self.marauder_id = min_unit_type + self.marine_id = min_unit_type + 1 + self.medivac_id = min_unit_type + 2 + elif self.map_type == "zealots": + self.zealot_id = min_unit_type + elif self.map_type == "hydralisks": + self.hydralisk_id = min_unit_type + elif self.map_type == "stalkers": + self.stalker_id = min_unit_type + elif self.map_type == "colossus": + self.colossus_id = min_unit_type + elif self.map_type == "bane": + self.baneling_id = min_unit_type + self.zergling_id = min_unit_type + 1 + + def only_medivac_left(self, ally): + """Check if only Medivac units are left.""" + if self.map_type != "MMM": + return False + + if ally: + units_alive = [ + a + for a in self.agents.values() + if (a.health > 0 and a.unit_type != self.medivac_id) + ] + if len(units_alive) == 0: + return True + return False + else: + units_alive = [ + a + for a in self.enemies.values() + if (a.health > 0 and a.unit_type != self.medivac_id) + ] + if len(units_alive) == 1 and units_alive[0].unit_type == 54: + return True + return False + + def get_unit_by_id(self, a_id): + """Get unit by ID.""" + return self.agents[a_id] + + def get_stats(self): + stats = { + "battles_won": self.battles_won, + "battles_game": self.battles_game, + "battles_draw": self.timeouts, + "win_rate": self.battles_won / self.battles_game, + "timeouts": self.timeouts, + "restarts": self.force_restarts, + } + return stats diff --git a/smac/examples/__init__.py b/smac/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/smac/examples/random.py b/smac/examples/random.py new file mode 100644 index 00000000..e39d1582 --- /dev/null +++ b/smac/examples/random.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from smac.env import StarCraft2Env +import numpy as np + + +def main(): + env = StarCraft2Env(map_name="8m") + env_info = env.get_env_info() + + n_actions = env_info["n_actions"] + n_agents = env_info["n_agents"] + + n_episodes = 10 + + for e in range(n_episodes): + env.reset() + terminated = False + episode_reward = 0 + + while not terminated: + obs = env.get_obs() + state = env.get_state() + + actions = [] + for agent_id in range(n_agents): + avail_actions = env.get_avail_agent_actions(agent_id) + avail_actions_ind = np.nonzero(avail_actions)[0] + action = np.random.choice(avail_actions_ind) + actions.append(action) + + reward, terminated, _ = env.step(actions) + episode_reward += reward + + print("Total reward in episode {} = {}".format(e, episode_reward)) + + env.close() + + +if __name__ == "__main__": + main()