Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Sep 13, 2021
1 parent aac44ab commit 72e4c0f
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 101 deletions.
4 changes: 3 additions & 1 deletion coltra/envs/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def _flatten_scalar(values: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
return {k: np.array([v[k] for v in values]) for k in keys}


def _flatten_info(infos: List[Dict[str, np.ndarray]]) -> Dict[str, Union[np.ndarray, List]]:
def _flatten_info(
infos: List[Dict[str, np.ndarray]]
) -> Dict[str, Union[np.ndarray, List]]:
all_metrics = {}

all_keys = set([k for dictionary in infos for k in dictionary])
Expand Down
4 changes: 3 additions & 1 deletion coltra/envs/unity_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def step(

return obs_dict, reward_dict, done_dict, info_dict

def reset(self, mode: Optional[Mode] = None, num_agents: Optional[int] = None, **kwargs) -> ObsDict:
def reset(
self, mode: Optional[Mode] = None, num_agents: Optional[int] = None, **kwargs
) -> ObsDict:
if mode:
self.param_channel.set_float_parameter("mode", mode.value)
if num_agents:
Expand Down
5 changes: 2 additions & 3 deletions coltra/scripts/train_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Parser(BaseParser):
}


if __name__ == '__main__':
if __name__ == "__main__":
CUDA = torch.cuda.is_available()

args = Parser()
Expand Down Expand Up @@ -83,7 +83,6 @@ class Parser(BaseParser):
model_cls = FancyMLPModel
agent_cls = CAgent if isinstance(action_space, gym.spaces.Box) else DAgent


if args.start_dir:
agent = agent_cls.load_agent(args.start_dir, weight_idx=args.start_idx)
else:
Expand All @@ -94,4 +93,4 @@ class Parser(BaseParser):
agent.cuda()

trainer = PPOCrowdTrainer(agent, env, config)
trainer.train(args.iters, disable_tqdm=False, save_path=trainer.path)
trainer.train(args.iters, disable_tqdm=False, save_path=trainer.path)
16 changes: 8 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from setuptools import setup

setup(
name='coltra-rl',
version='0.1.0',
packages=['coltra', 'coltra.envs', 'coltra.models'],
url='https://github.com/redtachyon/coltra-rl',
license='GNU GPLv3',
author='RedTachyon',
author_email='[email protected]',
description='Coltra-RL is a simple moddable RL algorithm implementation'
name="coltra-rl",
version="0.1.0",
packages=["coltra", "coltra.envs", "coltra.models"],
url="https://github.com/redtachyon/coltra-rl",
license="GNU GPLv3",
author="RedTachyon",
author_email="[email protected]",
description="Coltra-RL is a simple moddable RL algorithm implementation",
)
20 changes: 13 additions & 7 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_constant_agent():
obs = Observation(vector=np.random.randn(5, 81), buffer=np.random.randn(5, 10, 4))
agent = ConstantAgent(np.array([1., 1.], dtype=np.float32))
agent = ConstantAgent(np.array([1.0, 1.0], dtype=np.float32))

actions, _, _ = agent.act(obs_batch=obs)

Expand Down Expand Up @@ -79,8 +79,10 @@ def test_constant_agent():


def test_fancy_mlp_agent():
obs = Observation(vector=np.random.randn(5, 81).astype(np.float32),
buffer=np.random.randn(5, 10, 4).astype(np.float32))
obs = Observation(
vector=np.random.randn(5, 81).astype(np.float32),
buffer=np.random.randn(5, 10, 4).astype(np.float32),
)

model = FancyMLPModel({"input_size": 81, "hidden_sizes": [32, 32]})

Expand Down Expand Up @@ -126,10 +128,14 @@ def test_fancy_mlp_agent():


def test_discrete_fancy_mlp_agent():
obs = Observation(vector=np.random.randn(5, 81).astype(np.float32),
buffer=np.random.randn(5, 10, 4).astype(np.float32))

model = FancyMLPModel({"input_size": 81, "hidden_sizes": [32, 32], "discrete": True})
obs = Observation(
vector=np.random.randn(5, 81).astype(np.float32),
buffer=np.random.randn(5, 10, 4).astype(np.float32),
)

model = FancyMLPModel(
{"input_size": 81, "hidden_sizes": [32, 32], "discrete": True}
)

assert len(model.policy_network.hidden_layers) == 2
assert model.discrete
Expand Down
62 changes: 43 additions & 19 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@
import numpy as np
import torch

from coltra.buffers import Observation, Action, TensorArray, Reward, LogProb, Value, Done, MemoryRecord, MemoryBuffer, \
AgentMemoryBuffer
from coltra.buffers import (
Observation,
Action,
TensorArray,
Reward,
LogProb,
Value,
Done,
MemoryRecord,
MemoryBuffer,
AgentMemoryBuffer,
)


def test_observation_array():
Expand All @@ -27,7 +37,9 @@ def test_observation_misshaped():


def test_obs_to_tensor():
obs = Observation(vector=np.random.randn(5, 81), buffer=np.random.randn(5, 10, 4)).tensor()
obs = Observation(
vector=np.random.randn(5, 81), buffer=np.random.randn(5, 10, 4)
).tensor()
assert obs.batch_size == 5
assert obs.vector.shape == (5, 81)
assert obs.buffer.shape == (5, 10, 4)
Expand All @@ -44,8 +56,10 @@ def test_obs_get():


def test_obs_stack():
obs_list = [Observation(vector=np.random.randn(81), buffer=np.random.randn(10, 4))
for _ in range(7)]
obs_list = [
Observation(vector=np.random.randn(81), buffer=np.random.randn(10, 4))
for _ in range(7)
]

obs = Observation.stack_tensor(obs_list, dim=0)

Expand All @@ -55,8 +69,10 @@ def test_obs_stack():


def test_obs_cat():
obs_list = [Observation(vector=np.random.randn(5, 81), buffer=np.random.randn(5, 10, 4))
for _ in range(5)]
obs_list = [
Observation(vector=np.random.randn(5, 81), buffer=np.random.randn(5, 10, 4))
for _ in range(5)
]

obs = Observation.cat_tensor(obs_list, dim=0)

Expand Down Expand Up @@ -87,10 +103,10 @@ def test_action_misshaped():

def test_apply():
obs = Observation(vector=np.ones((5, 81)), buffer=np.ones((5, 10, 4)))
new_obs = obs.apply(lambda x: 2*x)
new_obs = obs.apply(lambda x: 2 * x)

assert np.allclose(new_obs.vector, 2*np.ones((5, 81)))
assert np.allclose(new_obs.buffer, 2*np.ones((5, 10, 4)))
assert np.allclose(new_obs.vector, 2 * np.ones((5, 81)))
assert np.allclose(new_obs.buffer, 2 * np.ones((5, 10, 4)))


def test_memory_buffer():
Expand All @@ -99,10 +115,20 @@ def test_memory_buffer():
batch_size = 100

for _ in range(batch_size):
obs = {agent_id: Observation(vector=np.random.randn(81).astype(np.float32),
buffer=np.random.randn(10, 4).astype(np.float32)) for agent_id in agents}
action = {agent_id: Action(continuous=np.random.randn(2).astype(np.float32)) for agent_id in agents}
reward = {agent_id: np.random.randn(1).astype(np.float32) for agent_id in agents}
obs = {
agent_id: Observation(
vector=np.random.randn(81).astype(np.float32),
buffer=np.random.randn(10, 4).astype(np.float32),
)
for agent_id in agents
}
action = {
agent_id: Action(continuous=np.random.randn(2).astype(np.float32))
for agent_id in agents
}
reward = {
agent_id: np.random.randn(1).astype(np.float32) for agent_id in agents
}
value = {agent_id: np.random.randn(1).astype(np.float32) for agent_id in agents}
done = {agent_id: False for agent_id in agents}

Expand All @@ -117,8 +143,6 @@ def test_memory_buffer():

crowd_data = memory.crowd_tensorify()
assert isinstance(crowd_data, MemoryRecord)
assert crowd_data.obs.vector.shape == (3*batch_size, 81)
assert crowd_data.obs.buffer.shape == (3*batch_size, 10, 4)
assert crowd_data.obs.batch_size == 3*batch_size


assert crowd_data.obs.vector.shape == (3 * batch_size, 81)
assert crowd_data.obs.buffer.shape == (3 * batch_size, 10, 4)
assert crowd_data.obs.batch_size == 3 * batch_size
7 changes: 6 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ def test_const_reward():

assert data.obs.vector.shape == (1000, 1)
assert torch.allclose(data.obs.vector, torch.ones(1000, 1))
assert torch.allclose(data.reward, torch.ones(1000, ))
assert torch.allclose(
data.reward,
torch.ones(
1000,
),
)

assert all(data.done)
assert env.render() == 0
Expand Down
27 changes: 16 additions & 11 deletions tests/test_discounting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import numpy as np
import torch
from coltra.discounting import discount_experience, _discount_bgae, convert_params, get_beta_vector
from coltra.discounting import (
discount_experience,
_discount_bgae,
convert_params,
get_beta_vector,
)


def test_convert():
assert np.allclose(convert_params(0.5, 0), (0.5, np.inf))
assert np.allclose(convert_params(0.9, 0.5), (9*2, 2))
assert np.allclose(convert_params(0.99, 0.5), (99*2, 2))
assert np.allclose(convert_params(0.9, 0.5), (9 * 2, 2))
assert np.allclose(convert_params(0.99, 0.5), (99 * 2, 2))
assert np.allclose(convert_params(0.9, 1), (9, 1))


Expand All @@ -15,16 +20,16 @@ def test_beta_vector():
Γ = get_beta_vector(T=100, α=0.9, β=np.inf)

assert Γ.shape == (100,)
assert np.allclose(Γ, np.array([0.9**t for t in range(100)]))
assert np.allclose(Γ, np.array([0.9 ** t for t in range(100)]))

# Hyperbolic
Γ = get_beta_vector(T=100, α=0.9, β=1)

assert Γ.shape == (100,)
assert np.allclose(Γ, np.array([1 / (1 + (1/0.9) * t) for t in range(100)]))
assert np.allclose(Γ, np.array([1 / (1 + (1 / 0.9) * t) for t in range(100)]))

# Some intermediate values
Γ = get_beta_vector(T=100, α=0.9, β=2.)
Γ = get_beta_vector(T=100, α=0.9, β=2.0)
assert Γ.shape == (100,)

Γ = get_beta_vector(T=100, α=0.99, β=0.5)
Expand All @@ -46,9 +51,9 @@ def test_discounting():

rewards = torch.cat([torch.zeros(10), torch.zeros(10) + 1, torch.zeros(10) + 2])
values = torch.cat([torch.zeros(10), torch.zeros(10) + 1, torch.zeros(10) + 2])
dones = torch.tensor([False if (t+1) % 5 else True for t in range(30)])
dones = torch.tensor([False if (t + 1) % 5 else True for t in range(30)])

returns, advantages = discount_experience(rewards, values, dones, 0.99, 0., 1.)
returns, advantages = discount_experience(rewards, values, dones, 0.99, 0.0, 1.0)

assert isinstance(returns, torch.Tensor)
assert isinstance(advantages, torch.Tensor)
Expand All @@ -58,7 +63,7 @@ def test_discounting():

rewards = torch.randn(1000)
values = torch.randn(1000)
dones = torch.tensor([False if (t+1) % 500 else True for t in range(1000)])
dones = torch.tensor([False if (t + 1) % 500 else True for t in range(1000)])

returns, advantages = discount_experience(rewards, values, dones, 0.99, 0.5, 0.95)

Expand All @@ -72,7 +77,7 @@ def test_discounting():

rewards = torch.ones(2000)
values = torch.zeros(2000)
dones = torch.tensor([False if (t+1) % 1000 else True for t in range(2000)])
dones = torch.tensor([False if (t + 1) % 1000 else True for t in range(2000)])

returns, advantages = discount_experience(rewards, values, dones, 0.99, 1.0, 0.95)

Expand All @@ -90,4 +95,4 @@ def test_discounting():
#
# def test_episode_rewards():
# rewards = torch.cat([torch.zeros(10), torch.zeros(10) + 1, torch.zeros(10) + 2])
# dones = torch.tensor([...])
# dones = torch.tensor([...])
4 changes: 1 addition & 3 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def test_multigym():
assert isinstance(obs[name], Observation)
assert isinstance(obs[name].vector, np.ndarray)

action = {key: Action(discrete=env.action_space.sample())
for key in obs}
action = {key: Action(discrete=env.action_space.sample()) for key in obs}

obs, reward, done, info = env.step(action)
assert isinstance(obs, dict)
Expand Down Expand Up @@ -147,4 +146,3 @@ def test_training():

trainer = PPOCrowdTrainer(agent, env, config)
trainer.train(2, disable_tqdm=False, save_path=None)

6 changes: 3 additions & 3 deletions tests/test_info_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
def test_flatten_info():
# I had some problems with this, so an explicit test
infos = [
{"foo": "asdf", "m_metric": np_float(1.)},
{"bar": 1, "m_metric": np_float(2.)},
{"foo": "asdf", "m_metric": np_float(1.0)},
{"bar": 1, "m_metric": np_float(2.0)},
{},
{"foo": "potato", "bar": "saf"}
{"foo": "potato", "bar": "saf"},
]

info = _flatten_info(infos)
Expand Down
Loading

0 comments on commit 72e4c0f

Please sign in to comment.