Skip to content

Commit

Permalink
Add factory feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
caelum02 authored Oct 29, 2023
1 parent afd3344 commit e410fa8
Show file tree
Hide file tree
Showing 3 changed files with 890 additions and 356 deletions.
31 changes: 21 additions & 10 deletions src/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from jux.env import JuxEnv
from jux.config import JuxBufferConfig
from jux.unit import UnitType
from jux.state import State

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -39,20 +40,16 @@ def to_board(x, y, unit_info):
return out

@jit
def get_unit_feature(unit_mask, unit_type, cargo, power, x, y):
def get_unit_feature(state: State)->jnp.ndarray:
'''
unit_mask : ShapedArray(bool[2, MAX_N_UNITS])
unit_type : ShapedArray(bool[2, MAX_N_UNITS])
cargo: ShapedArray(int32[2, MAX_N_UNITS, 4])
power: ShapedArray(int32[2, MAX_N_UNITS])
x : ShapedArray(int8[2, MAX_N_UNITS])
y : ShapedArray(int8[2, MAX_N_UNITS])
state: State
output: ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, 12])
feature: [light_existence, heavy_existence, (current) ice, ore, water, metal, power, (cargo empty space) ice, ore, water, metal, power]
'''

unit_mask, unit_type, cargo, power, x, y = state.unit_mask, state.units.unit_type, state.units.cargo.stock, state.units.power, state.units.pos.x, state.units.pos.y

light_mask = unit_mask & (unit_type==UnitType.LIGHT)
heavy_mask = unit_mask & (unit_type==UnitType.HEAVY)
unit_mask_per_type = jnp.stack((light_mask, heavy_mask), axis=-1)
Expand All @@ -65,9 +62,23 @@ def get_unit_feature(unit_mask, unit_type, cargo, power, x, y):

feature = jnp.concatenate((unit_mask_per_type, cargo, power[...,None], cargo_left, battery_left[...,None]), axis=-1)

unit_resource_map = to_board(x, y, feature)
unit_feature_map = to_board(x, y, feature)

return unit_feature_map

@jit
def get_factory_feature(state: State, power_previous: jnp.ndarray)->jnp.ndarray:
"""
state: State
output: ShapedArray(int8[2, MAP_SIZE, MAP_SIZE, 7])
"""
factory_mask, cargo, power, x, y = state.factory_mask, state.factories.cargo.stock, state.factories.power, state.factories.pos.x, state.factories.pos.y
feature = jnp.concatenate((factory_mask[..., None], cargo, power[...,None], power_previous[...,None]), axis=-1)

factory_feature_map = to_board(x, y, feature)

return factory_feature_map

return unit_resource_map

if __name__=="__main__":

Expand Down
Loading

0 comments on commit e410fa8

Please sign in to comment.