invest_test.py
"""Minimal RLlib smoke test: trains PPO on a one-step investment env."""
import gym
import numpy as np
import ray
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog

from negotiate_model import InvestorModel
class InvestTestEnv(gym.Env):
    """One-step environment: the agent chooses an investment fraction in
    [0, 1] and receives that fraction back as its reward, so the optimal
    policy is to always output 1."""

    observation_space = gym.spaces.Box(np.array([0.0]), np.array([1.0]), dtype=np.float32)
    action_space = gym.spaces.Box(np.array([0.0]), np.array([1.0]), dtype=np.float32)

    def __init__(self, config):
        super().__init__()

    def reset(self):
        return np.array([0.0], dtype=np.float32)

    def render(self, mode="human"):
        pass

    def step(self, action: np.ndarray):
        # Occasionally print the scaled action to monitor learning progress.
        if np.random.rand() < 0.01:
            print(action[0] * 15)
        # Every episode ends after a single step; the reward is the action itself.
        return np.array([0.0], dtype=np.float32), float(action[0]), True, {}
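
# A hypothetical sanity check (not part of the original script): exercises
# the env's one-step episode contract directly, without RLlib in the loop.
# Call it manually when debugging.
def _sanity_check_env():
    env = InvestTestEnv({})
    obs = env.reset()
    assert obs.shape == (1,)
    obs, reward, done, info = env.step(np.array([0.5], dtype=np.float32))
    assert done and abs(reward - 0.5) < 1e-6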
if __name__ == "__main__":
    # local_mode=True runs all of Ray in a single process, which keeps this
    # test easy to debug with prints and breakpoints.
    ray.init(local_mode=True)
    ModelCatalog.register_custom_model("invest_model", InvestorModel)
    trainer = ppo.PPOTrainer(env=InvestTestEnv, config={
        "env_config": {},  # config dict passed to the env's constructor
        "framework": "torch",
        "num_workers": 0,  # sample in the trainer process; no rollout workers
        "model": {
            "custom_model": "invest_model",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {},
        },
        "num_sgd_iter": 30,  # SGD passes per training batch
        "lr": 5e-5,  # stepsize of SGD
        "gamma": 1,  # no discounting: every episode is a single step anyway
        "monitor": True,
    })
    # Train forever, printing each iteration's result dict.
    while True:
        print(trainer.train())
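
# Usage note (assumption: negotiate_model.InvestorModel is an RLlib
# custom-model class compatible with the torch framework):
#   python invest_test.py
# Each trainer.train() call prints a result dict; as training converges, the
# occasional "action * 15" printout should approach 15, since the optimal
# action is 1. Stop with Ctrl+C.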