Skip to content

Commit

Permalink
refactor: add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
caelum02 authored Dec 9, 2023
1 parent b6c69a6 commit a63ab6c
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 29 deletions.
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Constants for default values of LUX AI S2 env config
"""
MAP_SIZE = 64
LIGHT_BATTERY_CAPACITY = 150
HEAVY_BATTERY_CAPACITY = 3000
Expand Down
21 changes: 14 additions & 7 deletions src/ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
"""Implement PPO algorithm for JuxEnvBatch
"""
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import Array, jit, tree_map
import optax
from typing import NamedTuple
from flax.training.train_state import TrainState
import distrax

from jux.env import JuxEnv, JuxEnvBatch
from jux.config import JuxBufferConfig, EnvConfig
from jux.state import State
from jux.actions import JuxAction
from jux.unit_cargo import ResourceType

from preprocess import get_feature
from constants import *
from space import ObsSpace, ActionSpace
from utils import StateSkeleton, get_seeds
from constants import MAX_EPISODE_LENGTH
from space import ObsSpace
from utils import get_seeds


def calculate_gae()



class PPOConfig(NamedTuple):
Expand All @@ -26,7 +31,7 @@ class PPOConfig(NamedTuple):
N_UPDATES: int = 1000
N_EPISODES_PER_ENV: int = 16
UPDATE_EPOCHS: int = 4
NUM_MINIBATCHES: int = 32
NUM_MINIBATCHES: int = 32

GAMMA: float = 0.99
GAE_LAMBDA: float = 0.96
Expand All @@ -53,6 +58,8 @@ class UpdateState(NamedTuple):
train_state: TrainState
rng: jax.Array



def make_train(env_config: EnvConfig, buf_config: JuxBufferConfig, ppo_config: PPOConfig,
actor_critic, bid_agent, factory_placement_agent, rng):
"""
Expand Down
95 changes: 73 additions & 22 deletions src/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Defines utility functions for preprocessing Jux environment state
"""
import jux
from jux.env import JuxEnv
from jux.config import JuxBufferConfig
Expand All @@ -6,7 +8,7 @@

import jax
import jax.numpy as jnp
from jax import tree_map, vmap, jit
from jax import jit, vmap
from jax import Array

from functools import partial
Expand All @@ -20,42 +22,50 @@
# First vectorize on the feature axis, and then on the team axis
@partial(vmap, in_axes=0, out_axes=0) # team axis
@partial(vmap, in_axes=(None, -1), out_axes=-1) # flax follows channels-last convention
def to_board(pos, unit_info):
'''
n_info: number of features to embed in the board (vectorized axis)
def to_board(pos, unit_info)->Array:
'''Embed unit-related features in the board.
Args:
pos: ShapedArray(int8[2, MAX_N_UNITS, 2])
unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info])
unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info])
pos: ShapedArray(int8[2, MAX_N_UNITS, 2])
out: ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info])
Returns:
ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info])
'''

zeros = jnp.zeros((MAP_SIZE, MAP_SIZE))

# `mode=drop` prevents unexpected index-out-of-bound behavior
out = zeros.at[pos.x, pos.y].set(unit_info, mode='drop')


return out


@partial(vmap, in_axes=0, out_axes=0) # batch axis
@partial(vmap, in_axes=0, out_axes=-2) # team axis
def to_board_for(pos: Array, unit_info: Array):
map = jnp.zeros((MAP_SIZE, MAP_SIZE, unit_info.shape[-1]))
def _to_board_i(i, map):
'''Embed unit-related features in the board.
Args:
pos: ShapedArray(int8[2, MAX_N_UNITS, 2])
unit_info: ShapedArray(int8[2, MAX_N_UNITS, n_info])
Returns:
ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, n_info])
'''
features = jnp.zeros((MAP_SIZE, MAP_SIZE, unit_info.shape[-1]))
def _to_board_i(i, features):
loc = pos[i]
return map.at[loc[0], loc[1]].set(unit_info[i], mode='drop')
map = jax.lax.fori_loop(0, pos.shape[0], _to_board_i, map)
return map
return features.at[loc[0], loc[1]].set(unit_info[i], mode='drop')
features = jax.lax.fori_loop(0, pos.shape[0], _to_board_i, features)
return features

def get_unit_feature(states: State)->Array:
'''
state: State
output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, 24])
feature: [light_existence, heavy_existence, (current) ice, ore, water, metal, power, (cargo empty space) ice, ore, water, metal, power]
'''
'''

unit_mask, unit_type, cargo, power, pos = states.unit_mask, states.units.unit_type, states.units.cargo.stock, states.units.power, states.units.pos

Expand Down Expand Up @@ -90,14 +100,32 @@ def get_factory_feature(states: State)->Array:
return factory_feature_map

def get_board_feature(states: State) -> Array:
"""
"""Return board-related features from batched
state: State
output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, 4])
"""
board_feature_map = jnp.stack([states.board.lichen, states.board.map.rubble, states.board.map.ice, states.board.map.ore], axis=-1)
return board_feature_map

def get_global_feature(states: State) -> Array:
"""Return global-features from batched Jux state.
Global-features include
1) One-hot encoded features
- cycle in episode
- turn in cycle
- day/night
2) Real-valued features
- lichen score of each team
Args:
states (jux.state.State): A batched Jux environment state
Returns:
Array: an array of global-features
"""
real_env_steps = states.real_env_steps
cycle, turn_in_cycle = jnp.divmod(real_env_steps, CYCLE_LENGTH)
is_day = (real_env_steps % CYCLE_LENGTH) < DAY_LENGTH
Expand All @@ -110,15 +138,24 @@ def lichen_score(state: State):
vmap(lichen_score, in_axes=(StateSkeleton,))(states),
], axis=-1)

@jit
def get_feature(states: State) -> ObsSpace:
"""
state: State
output: ShapedArray(int8[MAP_SIZE, MAP_SIZE, C])
"""Return features for 'player_0' from batched Jux state.
The returned features include both the global feature and the local feature.
Args:
states: State (batched)
Returns:
ObsSpace: feature map to be used in a deep rl model
"""
unit_feature_map = get_unit_feature(states)
unit_feature_map = unit_feature_map.reshape((unit_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1))
unit_feature_map = unit_feature_map.reshape(
(unit_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1))
factory_feature_map = get_factory_feature(states)
factory_feature_map = factory_feature_map.reshape((factory_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1))
factory_feature_map = factory_feature_map.reshape(
(factory_feature_map.shape[0], MAP_SIZE, MAP_SIZE, -1))
board_feature_map = get_board_feature(states)
global_feature = get_global_feature(states)

Expand All @@ -130,6 +167,20 @@ def get_feature(states: State) -> ObsSpace:
return ObsSpace(local_feature, global_feature, states)

def get_feature_teams(states: State)-> tuple[ObsSpace, ObsSpace]:
"""Return feature for both 'player_0' and 'player_1'
The features comprise ally-related, enemy-related and neutral features.
Whether a specific feature is ally-related or enemy-related depends on the
player id, so this function provides the features for both players, letting
each player use models that exploit our set of features.
Args:
states: State (batched)
Returns:
ObsSpace: feature map to be used in a deep rl model (for player_0)
ObsSpace: feature map to be used in a deep rl model (for player_1)
"""
unit_feature_map = get_unit_feature(states)
factory_feature_map = get_factory_feature(states)
board_feature_map = get_board_feature(states)
Expand Down

0 comments on commit a63ab6c

Please sign in to comment.