-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
92 lines (74 loc) · 2.6 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import os
import sys
import random
import yaml
import numpy as np
import torch as th
import torch.nn.functional as F
from crafter.env import Env
from crafter.recorder import VideoRecorder
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from achievement_distillation.algorithm import *
from achievement_distillation.model import *
from achievement_distillation.wrapper import VecPyTorch
def _load_config(exp_name):
    """Read and parse ``configs/<exp_name>.yaml``.

    Args:
        exp_name: experiment name; selects the YAML file under ``configs/``.

    Returns:
        The object produced by ``yaml.load`` for that file.
    """
    # Context manager guarantees the handle is closed (the original leaked it).
    with open(f"configs/{exp_name}.yaml", "r") as config_file:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs; only point this at trusted, local config files.
        return yaml.load(config_file, Loader=yaml.FullLoader)


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    th.cuda.manual_seed_all(seed)


def main(args):
    """Roll out one evaluation episode of a trained agent, recording a video.

    Loads the experiment config, rebuilds the model class named in it, loads
    the ``agent-e250.pt`` checkpoint for the run, and steps a single
    video-recorded environment until the first episode ends.

    Args:
        args: parsed CLI namespace providing ``exp_name``, ``timestamp``,
            ``train_seed`` (used to locate the run directory) and
            ``eval_seed`` (seeds RNGs and the environment).
    """
    config = _load_config(args.exp_name)
    _seed_everything(args.eval_seed)

    # CUDA setting
    th.set_num_threads(1)
    device = th.device("cuda:0" if th.cuda.is_available() else "cpu")

    # Checkpoint directory of the training run being evaluated.
    run_name = f"{args.exp_name}-{args.timestamp}-s{args.train_seed:02}"
    ckpt_dir = os.path.join("./models", run_name)

    # Single environment wrapped for video recording, then vectorised and
    # converted to torch tensors on ``device``.
    env = Env(seed=args.eval_seed)
    env = VideoRecorder(env, directory=f"./videos/{run_name}")
    venv = DummyVecEnv([lambda: env])
    venv = VecPyTorch(venv, device=device)

    # Resolve the model class by name among this module's (star-)imported names.
    model_cls = getattr(sys.modules[__name__], config["model_cls"])
    model: BaseModel = model_cls(
        observation_space=venv.observation_space,
        action_space=venv.action_space,
        **config["model_kwargs"],
    )
    model.to(device)
    print(model)

    # Load checkpoint. map_location lets a checkpoint saved from a GPU run be
    # loaded on a CPU-only machine (the device fallback chosen above).
    ckpt_path = os.path.join(ckpt_dir, "agent-e250.pt")
    state_dict = th.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)

    # Eval
    model.eval()
    obs = venv.reset()
    states = th.zeros(1, config["model_kwargs"]["hidsize"], device=device)

    # Pure inference: disable autograd for the whole rollout. Otherwise
    # ``states`` (derived from ``latents``) carries a grad history that is fed
    # back into ``model.act`` and the graph keeps growing across steps.
    with th.no_grad():
        while True:
            outputs = model.act(obs, states=states)
            latents = outputs["latents"]
            actions = outputs["actions"]
            obs, rewards, dones, _ = venv.step(actions)

            # Stop after the first finished episode.
            if dones.any():
                break

            # NOTE(review): the 0.1 threshold presumably separates achievement
            # unlocks from small shaping rewards — confirm against the env.
            if (rewards > 0.1).any():
                # New conditioning state: unit-normalised change in latent
                # encoding across the rewarding step.
                next_latents = model.encode(obs)
                states = F.normalize(next_latents - latents, dim=-1)
if __name__ == "__main__":
    # Script entry point: declare the CLI options, parse them, run evaluation.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--exp_name", str, {"required": True}),
        ("--timestamp", str, {"required": True}),
        ("--train_seed", int, {"default": 0}),
        ("--eval_seed", int, {"default": 123}),
    )
    for flag, value_type, extra in option_specs:
        cli.add_argument(flag, type=value_type, **extra)
    main(cli.parse_args())