Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing a bot for sc2 collectables minigame #102

Draft
wants to merge 9 commits into
base: v2
Choose a base branch
from
49 changes: 49 additions & 0 deletions sc2_collectables_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pathlib
import sys

sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))

from absl import app
from pysc2.env import sc2_env

from urnai.models.dqn_pytorch import DQNPytorch
from urnai.sc2.actions.collectables import CollectablesActionSpace
from urnai.sc2.agents.sc2_agent import SC2Agent
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.rewards.collectables import CollectablesReward
from urnai.sc2.states.collectables import CollectablesState
from urnai.trainers.trainer import Trainer


def declare_trainer():
players = [sc2_env.Agent(sc2_env.Race.terran)]
env = SC2Env(map_name='CollectMineralShards', visualize=False,
step_mul=16, players=players)


action_space = CollectablesActionSpace()
state_builder = CollectablesState()
reward_builder = CollectablesReward()

model = DQNPytorch(action_space, state_builder)

agent = SC2Agent(action_space, state_builder, model, reward_builder)

trainer = Trainer(env, agent,
max_training_episodes=200, max_steps_training=100000,
max_playing_episodes=200, max_steps_playing=100000)
return trainer

def main(unused_argv):
try:
trainer = declare_trainer()
#trainer.load("saves/")
trainer.train()
# trainer.play()

except KeyboardInterrupt:
pass


if __name__ == '__main__':
app.run(main)
15 changes: 10 additions & 5 deletions tests/units/base/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class FakePersistence(Persistence):
def __init__(self, threaded_saving=False):
super().__init__(threaded_saving)

class SimpleClass:
def __init__(self):
pass

class TestPersistence(unittest.TestCase):

def test_abstract_methods(self):
Expand Down Expand Up @@ -41,7 +45,7 @@ def test_get_default_save_stamp(self):

# THEN
self.assertEqual(_get_def_save_stamp_return,
fake_persistence.__class__.__name__ + '_')
fake_persistence.object_to_save.__class__.__name__ + '_')

def test_get_full_persistance_path(self):

Expand Down Expand Up @@ -90,14 +94,15 @@ def test_threaded_save(self):
def test_restore_attributes(self):

# GIVEN
fake_persistence = FakePersistence()
obj_to_save = SimpleClass()
fake_persistence = FakePersistence(obj_to_save)
dict_to_restore = {"TestAttribute1": 314, "TestAttribute2": "string"}

# WHEN
fake_persistence.attr_block_list = ["TestAttribute2"]
fake_persistence._restore_attributes(dict_to_restore)

# THEN
assert hasattr(fake_persistence, "TestAttribute1") is True
assert hasattr(fake_persistence, "TestAttribute2") is False
self.assertEqual(fake_persistence.TestAttribute1, 314)
assert hasattr(fake_persistence.object_to_save, "TestAttribute1") is True
assert hasattr(fake_persistence.object_to_save, "TestAttribute2") is False
self.assertEqual(fake_persistence.object_to_save.TestAttribute1, 314)
14 changes: 10 additions & 4 deletions tests/units/base/test_persistence_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ class FakePersistencePickle(PersistencePickle):
def __init__(self, threaded_saving=False):
super().__init__(threaded_saving)

class SimpleClass:
def __init__(self):
pass

class TestPersistence(unittest.TestCase):

@patch('urnai.base.persistence_pickle.PersistencePickle._simple_save')
Expand Down Expand Up @@ -46,21 +50,23 @@ def test_load(self, mock_load):
def test_get_attributes(self):

# GIVEN
fake_persistence_pickle = FakePersistencePickle()
obj_to_save = SimpleClass()
fake_persistence_pickle = FakePersistencePickle(obj_to_save)

# WHEN
return_list = fake_persistence_pickle._get_attributes()

# THEN
self.assertEqual(return_list, ['threaded_saving'])
self.assertEqual(return_list, [])

def test_get_dict(self):

# GIVEN
fake_persistence_pickle = FakePersistencePickle()
obj_to_save = SimpleClass()
fake_persistence_pickle = FakePersistencePickle(obj_to_save)

# WHEN
return_dict = fake_persistence_pickle._get_dict()

# THEN
self.assertEqual(return_dict, {"threaded_saving": False})
self.assertEqual(return_dict, {})
4 changes: 2 additions & 2 deletions tests/units/models/test_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __init__(self):
learn_return = fake_model.learn("current_state",
"action", "reward", "next_state")
predict_return = fake_model.predict("state")
learning_dict = fake_model.learning_dict
learning_dict = fake_model.learning_data

assert isinstance(ModelBase, ABCMeta)
assert learn_return is None
assert predict_return is None
assert learning_dict is None
self.assertEqual(learning_dict, {})
2 changes: 1 addition & 1 deletion tests/units/sc2/actions/test_sc2_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pysc2.lib import actions

from urnai.actions.sc2_action import SC2Action
from urnai.sc2.actions.sc2_action import SC2Action

_BUILD_REFINERY = actions.RAW_FUNCTIONS.Build_Refinery_pt
_NO_OP = actions.FUNCTIONS.no_op
Expand Down
8 changes: 4 additions & 4 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod

from urnai.actions.action_base import ActionBase
from urnai.actions.action_space_base import ActionSpaceBase
from urnai.models.model_base import ModelBase
from urnai.rewards.reward_base import RewardBase
from urnai.states.state_base import StateBase


Expand All @@ -11,7 +11,7 @@ class AgentBase(ABC):
def __init__(self, action_space : ActionSpaceBase,
state_space : StateBase,
model : ModelBase,
reward):
reward : RewardBase):

self.action_space = action_space
self.state_space = state_space
Expand All @@ -29,7 +29,7 @@ def step(self) -> None:
...

@abstractmethod
def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
def choose_action(self, action_space : ActionSpaceBase) -> int:
"""
Method that contains the agent's strategy for choosing actions
"""
Expand All @@ -56,4 +56,4 @@ def learn(self, obs, reward, done) -> None:
if self.previous_state is not None:
next_state = self.state_space.update(obs)
self.model.learn(self.previous_state, self.previous_action,
reward, next_state, done)
reward, next_state, done)
7 changes: 4 additions & 3 deletions urnai/base/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@ class Persistence(ABC):
to save on disk.
"""

def __init__(self, threaded_saving=False):
def __init__(self, object_to_save, threaded_saving=False):
self.threaded_saving = threaded_saving
self.attr_block_list = []
self.processes = []
self.object_to_save = object_to_save

def _get_default_save_stamp(self):
"""
This method returns the default
file name that should be used while
persisting the object.
"""
return self.__class__.__name__ + '_'
return self.object_to_save.__class__.__name__ + '_'

def get_full_persistance_path(self, persist_path):
"""This method returns the default persistance path."""
Expand Down Expand Up @@ -75,4 +76,4 @@ def _get_attributes(self):
def _restore_attributes(self, dict_to_restore):
for key in dict_to_restore:
if key not in self.attr_block_list:
setattr(self, key, dict_to_restore[key])
setattr(self.object_to_save, key, dict_to_restore[key])
16 changes: 9 additions & 7 deletions urnai/base/persistence_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class PersistencePickle(Persistence):
to save on disk.
"""

def __init__(self, threaded_saving=False):
super().__init__(threaded_saving)
def __init__(self, object_to_save, threaded_saving=False):
super().__init__(object_to_save, threaded_saving)

def _simple_save(self, persist_path):
"""
Expand Down Expand Up @@ -56,13 +56,15 @@ def _get_attributes(self):
If you wish to block one particular pickleable attribute, put it
in self.attr_block_list as a string.
"""
if not hasattr(self, 'attr_block_list') or self.attr_block_list is None:
if not hasattr(self.object_to_save, 'attr_block_list') \
or self.attr_block_list is None:
self.attr_block_list = []

attr_block_list = self.attr_block_list + ['attr_block_list', 'processes']

full_attr_list = [attr for attr in dir(self) if not attr.startswith('__')
and not callable(getattr(self, attr))
full_attr_list = [attr for attr in dir(self.object_to_save) \
if not attr.startswith('__')
and not callable(getattr(self.object_to_save, attr))
and attr not in attr_block_list
and 'abc' not in attr]

Expand All @@ -71,7 +73,7 @@ def _get_attributes(self):
for key in full_attr_list:
try:
with tempfile.NamedTemporaryFile() as tmp_file:
pickle.dump(getattr(self, key), tmp_file)
pickle.dump(getattr(self.object_to_save, key), tmp_file)
tmp_file.flush()

pickleable_list.append(key)
Expand Down Expand Up @@ -109,6 +111,6 @@ def _get_dict(self):
pickleable_attr_dict = {}

for attr in self._get_attributes():
pickleable_attr_dict[attr] = getattr(self, attr)
pickleable_attr_dict[attr] = getattr(self.object_to_save, attr)

return pickleable_attr_dict
Loading