From a63ab6c3d2fa76f81fc357cda3e62bf92d8f85e1 Mon Sep 17 00:00:00 2001 From: Haneul Choi Date: Sat, 9 Dec 2023 18:05:50 +0000 Subject: [PATCH] refactor: add docstrings --- src/constants.py | 3 ++ src/ppo.py | 21 +++++++---- src/preprocess.py | 95 ++++++++++++++++++++++++++++++++++++----------- 3 files changed, 90 insertions(+), 29 deletions(-) diff --git a/src/constants.py b/src/constants.py index 3d7e4a5..0439b81 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,3 +1,6 @@ +""" +Constants for default values of LUX AI S2 env config +""" MAP_SIZE = 64 LIGHT_BATTERY_CAPACITY = 150 HEAVY_BATTERY_CAPACITY = 3000 diff --git a/src/ppo.py b/src/ppo.py index ae4fb3c..d3688f5 100644 --- a/src/ppo.py +++ b/src/ppo.py @@ -1,21 +1,26 @@ +"""Implement PPO algorithm for JuxEnvBatch +""" +from typing import NamedTuple + import jax import jax.numpy as jnp from jax import Array, jit, tree_map import optax -from typing import NamedTuple from flax.training.train_state import TrainState -import distrax from jux.env import JuxEnv, JuxEnvBatch from jux.config import JuxBufferConfig, EnvConfig from jux.state import State from jux.actions import JuxAction -from jux.unit_cargo import ResourceType from preprocess import get_feature -from constants import * -from space import ObsSpace, ActionSpace -from utils import StateSkeleton, get_seeds +from constants import MAX_EPISODE_LENGTH +from space import ObsSpace +from utils import get_seeds + + +def calculate_gae() + class PPOConfig(NamedTuple): @@ -26,7 +31,7 @@ class PPOConfig(NamedTuple): N_UPDATES: int = 1000 N_EPISODES_PER_ENV: int = 16 UPDATE_EPOCHS: int = 4 - NUM_MINIBATCHES: int = 32 + NUM_MINIBATCHES: int = 32 GAMMA: float = 0.99 GAE_LAMBDA: float = 0.96 @@ -53,6 +58,8 @@ class UpdateState(NamedTuple): train_state: TrainState rng: jax.Array + + def make_train(env_config: EnvConfig, buf_config: JuxBufferConfig, ppo_config: PPOConfig, actor_critic, bid_agent, factory_placement_agent, rng): """ diff --git a/src/preprocess.py b/src/preprocess.py index d9ccbe6..3a0ddbd 100644 --- a/src/preprocess.py +++ b/src/preprocess.py @@ -1,3 +1,5 @@ +"""Defines utility functions for preprocessing Jux environment state +""" import jux from jux.env import JuxEnv from jux.config import JuxBufferConfig @@ -6,7 +8,7 @@ import jax import jax.numpy as jnp -from jax import tree_map, vmap, jit +from jax import jit, vmap from jax import Array from functools import partial @@ -20,34 +22,42 @@ # First vectorize on the feature axis, and then on the team axis @partial(vmap, in_axes=0, out_axes=0) # team axis @partial(vmap, in_axes=(None, -1), out_axes=-1) # flax follows channels-last convention -def to_board(pos, unit_info): - ''' - n_info: number of features to embed in the board (vectorized axis) +def to_board(pos, unit_info)->Array: + '''Embed unit-related features in the board. + + Args: + pos: ShapedArray(int8[2, MAX_N_UNITS, 2]) + unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info]) - unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info]) - pos: ShapedArray(int8[2, MAX_N_UNITS, 2]) - - out: ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info]) + Returns: + ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info]) ''' zeros = jnp.zeros((MAP_SIZE, MAP_SIZE)) # `mode=drop` prevents unexpected index-out-of-bound behavior out = zeros.at[pos.x, pos.y].set(unit_info, mode='drop') - - return out @partial(vmap, in_axes=0, out_axes=0) # batch axis @partial(vmap, in_axes=0, out_axes=-2) # team axis def to_board_for(pos: Array, unit_info: Array): - map = jnp.zeros((MAP_SIZE, MAP_SIZE, unit_info.shape[-1])) - def _to_board_i(i, map): + '''Embed unit-related features in the board. + + Args: + pos: ShapedArray(int8[2, MAX_N_UNITS, 2]) + unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info]) + + Returns: + ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info]) + ''' + features = jnp.zeros((MAP_SIZE, MAP_SIZE, unit_info.shape[-1])) + def _to_board_i(i, features): loc = pos[i] - return map.at[loc[0], loc[1]].set(unit_info[i], mode='drop') - map = jax.lax.fori_loop(0, pos.shape[0], _to_board_i, map) - return map + return features.at[loc[0], loc[1]].set(unit_info[i], mode='drop') + features = jax.lax.fori_loop(0, pos.shape[0], _to_board_i, features) + return features def get_unit_feature(states: State)->Array: ''' @@ -55,7 +65,7 @@ def get_unit_feature(states: State)->Array: output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, 24]) feature: [light_existence, heavy_existence, (current) ice, ore, water, metal, power, (cargo empty space) ice, ore, water, metal, power] - ''' + ''' unit_mask, unit_type, cargo, power, pos = states.unit_mask, states.units.unit_type, states.units.cargo.stock, states.units.power, states.units.pos @@ -90,7 +100,8 @@ def get_factory_feature(states: State)->Array: return factory_feature_map def get_board_feature(states: State) -> Array: - """ + """Return board-related features from batched + state: State output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, 4]) """ @@ -98,6 +109,23 @@ def get_board_feature(states: State) -> Array: return board_feature_map def get_global_feature(states: State) -> Array: + """Return global-features from batched Jux state. + + Global-features include + 1) One-hot encoded features + - cycle in episode + - turn in cycle + - day/night + + 2) Real-valued features0 + - lichen score of each team + + Args: + state (jux.state.State): A batched Jux environment state + + Returns: + Array: an array of global-features + """ real_env_steps = states.real_env_steps cycle, turn_in_cycle = jnp.divmod(real_env_steps, CYCLE_LENGTH) is_day = (real_env_steps % CYCLE_LENGTH) < DAY_LENGTH @@ -110,15 +138,24 @@ def lichen_score(state: State): vmap(lichen_score, in_axes=(StateSkeleton,))(states), ], axis=-1) +@jit def get_feature(states: State) -> ObsSpace: - """ - state: State - output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, C]) + """Return features for 'player_0' from batched Jux state. + + Returned features including global-feature and local-feature + + Args: + state : State (Batched) + + Returns: + ObsSpace: feature map to be used in a deep rl model """ unit_feature_map = get_unit_feature(states) - unit_feature_map = unit_feature_map.reshape((unit_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1)) + unit_feature_map = unit_feature_map.reshape( + (unit_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1)) factory_feature_map = get_factory_feature(states) - factory_feature_map = factory_feature_map.reshape((factory_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1)) + factory_feature_map = factory_feature_map.reshape( + (factory_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1)) board_feature_map = get_board_feature(states) global_feature = get_global_feature(states) @@ -130,6 +167,20 @@ def get_feature(states: State) -> ObsSpace: return ObsSpace(local_feature, global_feature, states) def get_feature_teams(states: State)-> tuple[ObsSpace, ObsSpace]: + """Return feature for both 'player_0' and 'player_1' + + Among features are ally-related features, enemy-related features and neutr- + al features. As whether a specific feature is related to ally or enemy depe + nds on the id of player, features for both players are provided in this fun + ction so that both player can use models that exploit our set of features. + + Args: + state: State (Batched) + + Returns: + ObsSpace: feature map to be used in a deep rl model (for player_0) + ObsSpace: feature map to be used in a deep rl model (for player_1) + """ unit_feature_map = get_unit_feature(states) factory_feature_map = get_factory_feature(states) board_feature_map = get_board_feature(states)