-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_dmil.py
62 lines (56 loc) · 1.97 KB
/
demo_dmil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import sys
sys.path.append('../../')
import torch
import logging
from d2c.trainers import Trainer
from d2c.models import make_agent
from d2c.envs import benchmark_env, LeaEnv
from d2c.data import Data
from d2c.evaluators import bm_eval
from example.benchmark.config import make_config
logging.basicConfig(level=logging.INFO)
def main():
    """Run the DMIL demo on the D4RL ``halfcheetah_expert-v2`` dataset.

    Collects the configuration overrides (benchmark, model/training and
    wandb settings), builds a config from them with ``make_config``, then
    creates the data, the two environments, the agent, the evaluator and
    the trainer, and finally calls ``trainer.train()``.
    """
    # Use the GPU when torch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    split_ratio = 0.05

    # Benchmark settings; each key is registered under the
    # 'env.external.' namespace below.
    prefix = 'env.external.'
    external_settings = {
        'benchmark_name': 'd4rl',
        'data_source': 'mujoco',
        'env_name': 'HalfCheetah-v2',
        'data_name': 'halfcheetah_expert-v2',
        'state_normalize': True,
        'score_normalize': True,
    }

    # Settings handed over under the 'train.wandb' key.
    wandb_settings = {
        'entity': 'd2c',
        'project': 'dmil_test_1201',
        'name': f'halfcheetah_expert-ratio_{split_ratio}',
        'reinit': False,
        'mode': 'online'
    }

    # A single literal keeps the same key insertion order as a step-by-step
    # build: benchmark keys, then model/train keys, then the wandb entry.
    command_args = {
        **{prefix + key: value for key, value in external_settings.items()},
        'model.model_name': 'dmil',
        'train.data_loader_name': None,
        'train.data_split_ratio': split_ratio,
        'train.device': device,
        'train.seed': 0,
        'train.total_train_steps': 1000000,
        'train.batch_size': 256,
        'train.agent_ckpt_name': '1130',
        'model.dmil.hyper_params.rollout_size': None,
        'train.wandb': wandb_settings,
    }
    config = make_config(command_args)

    bm_data = Data(config)
    # Pair the values of `state_shift_scale` with the keyword names that
    # `benchmark_env` receives.
    shift_scale_kwargs = dict(
        zip(('obs_shift', 'obs_scale'), bm_data.state_shift_scale)
    )
    train_data = bm_data.data

    # The env of the benchmark to be used for policy evaluation.
    eval_env = benchmark_env(config=config, **shift_scale_kwargs)
    # Contains dynamics model to be trained.
    learned_env = LeaEnv(config)

    agent = make_agent(config=config, env=learned_env, data=train_data)
    evaluator = bm_eval(agent=agent, env=eval_env, config=config)
    trainer = Trainer(
        agent=agent,
        train_data=train_data,
        config=config,
        env=learned_env,
        evaluator=evaluator,
    )
    trainer.train()
# Start the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()