main.py
from pathlib import Path
import gym
import d4rl
import numpy as np
import torch
from tqdm import trange
from src.iql import ImplicitQLearning
from src.policy import GaussianPolicy, DeterministicPolicy
from src.value_functions import TwinQ, ValueFunction
from src.util import return_range, set_seed, Log, sample_batch, torchify, evaluate_policy


def get_env_and_dataset(log, env_name, max_episode_steps):
    env = gym.make(env_name)
    dataset = d4rl.qlearning_dataset(env)

    if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')):
        min_ret, max_ret = return_range(dataset, max_episode_steps)
        log(f'Dataset returns have range [{min_ret}, {max_ret}]')
        # Rescale locomotion rewards so that the range of dataset returns
        # spans roughly max_episode_steps.
        dataset['rewards'] /= (max_ret - min_ret)
        dataset['rewards'] *= max_episode_steps
    elif 'antmaze' in env_name:
        # Shift the sparse antmaze rewards from {0, 1} to {-1, 0}.
        dataset['rewards'] -= 1.

    # Convert all numpy arrays in the dataset to torch tensors.
    for k, v in dataset.items():
        dataset[k] = torchify(v)

    return env, dataset


def main(args):
    torch.set_num_threads(1)
    log = Log(Path(args.log_dir)/args.env_name, vars(args))
    log(f'Log dir: {log.dir}')

    env, dataset = get_env_and_dataset(log, args.env_name, args.max_episode_steps)
    obs_dim = dataset['observations'].shape[1]
    act_dim = dataset['actions'].shape[1]   # this assumes continuous actions
    set_seed(args.seed, env=env)

    if args.deterministic_policy:
        policy = DeterministicPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden)
    else:
        policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden)

    def eval_policy():
        eval_returns = np.array([evaluate_policy(env, policy, args.max_episode_steps)
                                 for _ in range(args.n_eval_episodes)])
        normalized_returns = d4rl.get_normalized_score(args.env_name, eval_returns) * 100.0
        log.row({
            'return mean': eval_returns.mean(),
            'return std': eval_returns.std(),
            'normalized return mean': normalized_returns.mean(),
            'normalized return std': normalized_returns.std(),
        })

    iql = ImplicitQLearning(
        qf=TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden),
        vf=ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden),
        policy=policy,
        optimizer_factory=lambda params: torch.optim.Adam(params, lr=args.learning_rate),
        max_steps=args.n_steps,
        tau=args.tau,
        beta=args.beta,
        alpha=args.alpha,
        discount=args.discount
    )

    for step in trange(args.n_steps):
        iql.update(**sample_batch(dataset, args.batch_size))
        if (step + 1) % args.eval_period == 0:
            eval_policy()

    torch.save(iql.state_dict(), log.dir/'final.pt')
    log.close()


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--env-name', required=True)
    parser.add_argument('--log-dir', required=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--hidden-dim', type=int, default=256)
    parser.add_argument('--n-hidden', type=int, default=2)
    parser.add_argument('--n-steps', type=int, default=10**6)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--learning-rate', type=float, default=3e-4)
    parser.add_argument('--alpha', type=float, default=0.005)
    parser.add_argument('--tau', type=float, default=0.7)
    parser.add_argument('--beta', type=float, default=3.0)
    parser.add_argument('--deterministic-policy', action='store_true')
    parser.add_argument('--eval-period', type=int, default=5000)
    parser.add_argument('--n-eval-episodes', type=int, default=10)
    parser.add_argument('--max-episode-steps', type=int, default=1000)
    main(parser.parse_args())
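
A minimal invocation sketch, assuming d4rl and its MuJoCo backend are installed; the dataset name hopper-medium-v2 and the log directory are illustrative choices, not defaults of this script:

    python main.py --env-name hopper-medium-v2 --log-dir ./runs --seed 0

Training logs evaluation rows through Log and saves the final weights to final.pt under the run directory created by src.util.Log. The snippet below is a sketch of reloading such a checkpoint for evaluation: it reuses the constructor calls shown above with the default hyperparameters, and it assumes ImplicitQLearning is a torch.nn.Module (which the iql.state_dict() call in main() suggests), so that load_state_dict also restores the policy weights. The checkpoint path is illustrative.

    import gym
    import d4rl
    import torch
    from src.iql import ImplicitQLearning
    from src.policy import GaussianPolicy
    from src.value_functions import TwinQ, ValueFunction
    from src.util import evaluate_policy

    env_name = 'hopper-medium-v2'   # illustrative dataset name
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Rebuild the agent exactly as in main(), using the default hyperparameters.
    policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=256, n_hidden=2)
    iql = ImplicitQLearning(
        qf=TwinQ(obs_dim, act_dim, hidden_dim=256, n_hidden=2),
        vf=ValueFunction(obs_dim, hidden_dim=256, n_hidden=2),
        policy=policy,
        optimizer_factory=lambda params: torch.optim.Adam(params, lr=3e-4),
        max_steps=10**6, tau=0.7, beta=3.0, alpha=0.005, discount=0.99,
    )
    iql.load_state_dict(torch.load('runs/hopper-medium-v2/final.pt'))  # illustrative path

    # Roll out the restored policy once and report a D4RL-normalized score.
    ret = evaluate_policy(env, policy, 1000)
    print('normalized score:', d4rl.get_normalized_score(env_name, ret) * 100.0)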