-
Notifications
You must be signed in to change notification settings - Fork 135
/
iql.py
771 lines (659 loc) · 26.6 KB
/
iql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
# source: https://github.com/gwthomas/IQL-PyTorch
# https://arxiv.org/pdf/2110.06169.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
# A sampled replay batch: [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]

# Upper clip for exp(beta * advantage) in the advantage-weighted actor loss.
EXP_ADV_MAX = 100.0
# Clamp range applied to the Gaussian policy's log standard deviation.
LOG_STD_MIN, LOG_STD_MAX = -20.0, 2.0
# Environment-name prefixes (matched with str.startswith) whose tasks report goal success.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")
@dataclass
class TrainConfig:
    """Hyperparameters and bookkeeping for one IQL offline-to-online run.

    ``__post_init__`` makes ``name`` unique by appending the env name and an
    8-character UUID prefix. When ``checkpoints_path`` is given, it nests that
    path under the unique run name.
    """

    # ---- Experiment ----
    device: str = "cuda"
    env: str = "antmaze-umaze-v2"  # OpenAI gym environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 0  # Eval environment seed
    eval_freq: int = 50_000  # How often (time steps) we evaluate
    n_episodes: int = 100  # How many episodes run during evaluation
    offline_iterations: int = 1_000_000  # Number of offline updates
    online_iterations: int = 1_000_000  # Number of online updates
    checkpoints_path: Optional[str] = None  # Save path
    load_model: str = ""  # Model load file name, "" doesn't load
    # ---- IQL ----
    actor_dropout: float = 0.0  # Dropout in actor network
    buffer_size: int = 2_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
    beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
    iql_tau: float = 0.7  # Coefficient for asymmetric loss
    expl_noise: float = 0.03  # Std of Gaussian exploration noise
    noise_clip: float = 0.5  # Range to clip noise
    iql_deterministic: bool = False  # Use deterministic actor
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    vf_lr: float = 3e-4  # V function learning rate
    qf_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate
    # ---- Wandb logging ----
    project: str = "CORL"
    group: str = "IQL-D4RL"
    name: str = "IQL"

    def __post_init__(self):
        unique_suffix = str(uuid.uuid4())[:8]
        self.name = f"{self.name}-{self.env}-{unique_suffix}"
        if self.checkpoints_path is None:
            return
        # Give every run its own sub-directory under the requested root.
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` parameters into ``target`` in place.

    Each target parameter is overwritten with
    ``tau * source + (1 - tau) * target``. Parameters are paired
    positionally via ``parameters()``, so both modules must share the same
    architecture.
    """
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = tau * src.data + (1 - tau) * tgt.data
        tgt.data.copy_(blended)
def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean and (population) std of ``states`` along axis 0.

    ``eps`` is added to the std so it can safely be used as a divisor.
    """
    per_feature_mean = states.mean(axis=0)
    per_feature_std = states.std(axis=0) + eps
    return per_feature_mean, per_feature_std
def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std`` (broadcasting)."""
    centered = states - mean
    return centered / std
def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and rewards optionally scaled.

    Observations become ``(obs - state_mean) / state_std``. A
    ``TransformReward`` wrapper that multiplies rewards by ``reward_scale`` is
    added only when ``reward_scale != 1.0``.
    """

    # Named functions instead of lambdas (PEP 8: E731).
    def _standardize(observation):
        # No epsilon is added here; state_std is expected to already include one
        # (compute_mean_std adds eps) - callers passing raw stds must do so too.
        return (observation - state_mean) / state_std

    def _rescale(reward):
        # Please be careful: the reward is MULTIPLIED by reward_scale, not divided.
        return reward_scale * reward

    wrapped = gym.wrappers.TransformObservation(env, _standardize)
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, _rescale)
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored as float32 tensors.

    Storage for states, actions, rewards, next states and done flags is
    pre-allocated on ``device``. The buffer is bulk-filled once from a
    D4RL-style dict (``load_d4rl_dataset``) and can then be extended one
    transition at a time (``add_transition``). Once full, new transitions
    overwrite the oldest slots in insertion order.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0  # slot the next add_transition writes to
        self._size = 0  # number of valid transitions currently stored

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a new float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Bulk-load a D4RL-format dict into this (empty) buffer.

        Reads the ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals`` entries into the first N slots.
        A trailing unit axis is added to ``rewards`` and ``terminals`` to match
        the ``(buffer_size, 1)`` storage. This presumes they are 1-D of length N.

        Raises:
            ValueError: if the buffer already holds data, or if the dataset has
                more transitions than ``buffer_size``.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        # Keep the write pointer inside [0, buffer_size). When the dataset fills
        # the buffer exactly, the next add_transition must wrap to slot 0 rather
        # than index one past the end of the storage.
        self._pointer = 0 if n_transitions == self._buffer_size else n_transitions

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Indices come from NumPy's global RNG. Returns
        ``[states, actions, rewards, next_states, dones]``, each with leading
        dimension ``batch_size``.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self._size == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, self._size, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Write one transition at the pointer and advance it (ring-buffer style).

        Used to add new data during online fine-tuning. When the buffer is full,
        this overwrites the oldest stored transition.
        """
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)
def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed a Gym environment and its action space with ``seed``.

    Calls the legacy ``env.seed`` API and ``env.action_space.seed``. As the
    ``Optional`` annotation allows, passing ``None`` is a no-op rather than an
    AttributeError.
    """
    if env is None:
        return
    env.seed(seed)
    env.action_space.seed(seed)
def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every RNG used by training, plus ``env`` when one is given.

    Seeds NumPy's global RNG, Python's ``random`` and PyTorch (in that order),
    and exports ``PYTHONHASHSEED``. Setting that variable at runtime does not
    change hashing in the already-running interpreter; it only affects child
    processes. Also toggles ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)
def wandb_init(config: dict) -> None:
    """Start a Weights & Biases run described by ``config``.

    ``config`` must contain ``project``, ``group`` and ``name`` keys; the whole
    dict is also uploaded as the run config. Each call gets a fresh random run id.
    """
    run_identity = {key: config[key] for key in ("project", "group", "name")}
    wandb.init(config=config, id=str(uuid.uuid4()), **run_identity)
    wandb.run.save()
def is_goal_reached(reward: float, info: Dict) -> bool:
    """Report whether the step reached the goal.

    Prefers the environment's own ``info["goal_achieved"]`` flag. Without it,
    the fallback treats any positive reward as success.
    """
    if "goal_achieved" not in info:
        # Assuming that reaching target is a positive reward.
        return reward > 0
    return info["goal_achieved"]
@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, float]:
    """Roll out ``actor`` for ``n_episodes`` complete episodes on ``env``.

    The env is reseeded with ``seed`` once, before the first episode. The actor
    is switched to eval mode for the rollouts and restored to train mode
    afterwards, even if a rollout raises.

    Returns:
        A tuple ``(episode_returns, success_rate)``:

        - ``episode_returns`` holds the undiscounted return of each episode.
        - ``success_rate`` is the fraction of episodes in which
          ``is_goal_reached`` fired at least once. It is only meaningful for
          goal-based envs.
    """
    env.seed(seed)
    actor.eval()
    episode_rewards = []
    successes = []
    try:
        for _ in range(n_episodes):
            state, done = env.reset(), False
            episode_reward = 0.0
            goal_achieved = False
            while not done:
                action = actor.act(state, device)
                state, reward, done, env_infos = env.step(action)
                episode_reward += reward
                # Latch success: once reached, later steps cannot clear it.
                if not goal_achieved:
                    goal_achieved = is_goal_reached(reward, env_infos)
            # Valid only for environments with goal
            successes.append(float(goal_achieved))
            episode_rewards.append(episode_reward)
    finally:
        actor.train()

    return np.asarray(episode_rewards), np.mean(successes)
def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return ``(min, max)`` undiscounted episode return found in ``dataset``.

    Episodes are split on a truthy ``terminals`` entry or after
    ``max_episode_steps`` steps. A trailing incomplete episode counts toward the
    step total but not toward the returns. Raises ``ValueError`` (from
    ``min``/``max``) if no episode is completed.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_length = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if not (terminal or running_length == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_length)
        running_return, running_length = 0.0, 0
    episode_lengths.append(running_length)  # but still keep track of number of steps
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)
def modify_reward(dataset: Dict, env_name: str, max_episode_steps: int = 1000) -> Dict:
    """Rescale ``dataset["rewards"]`` in place according to the task family.

    - Locomotion (halfcheetah / hopper / walker2d): divide by the spread of
      episode returns, then multiply by ``max_episode_steps``. Returns the
      statistics ``modify_reward_online`` needs to apply the same transform.
    - AntMaze: shift every reward by -1. Returns ``{}``.
    - Anything else: untouched. Returns ``{}``.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        return {
            "max_ret": max_ret,
            "min_ret": min_ret,
            "max_episode_steps": max_episode_steps,
        }
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0
    return {}
def modify_reward_online(reward: float, env_name: str, **kwargs) -> float:
    """Apply ``modify_reward``'s transform to a single online reward.

    For locomotion tasks, ``kwargs`` must carry the ``max_ret``, ``min_ret``
    and ``max_episode_steps`` values returned by ``modify_reward``.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        reward /= kwargs["max_ret"] - kwargs["min_ret"]
        reward *= kwargs["max_episode_steps"]
        return reward
    if "antmaze" in env_name:
        reward -= 1.0
    return reward
def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Expectile (asymmetric squared) loss averaged over all elements of ``u``.

    Non-negative residuals are weighted by ``tau`` and negative ones by
    ``1 - tau``, i.e. weight ``|tau - 1[u < 0]|``.
    """
    is_negative = (u < 0).float()
    weight = torch.abs(tau - is_negative)
    return (weight * u.pow(2)).mean()
class Squeeze(nn.Module):
    """Module form of ``Tensor.squeeze`` for one fixed dimension (default: last)."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.squeeze(x, self.dim)
class MLP(nn.Module):
    """Fully connected network defined by a list of layer widths.

    ``dims = [in, h1, ..., out]``. Each hidden ``Linear`` is followed by
    ``activation_fn()`` and, when ``dropout > 0``, by ``nn.Dropout``. The final
    ``Linear`` may be followed by ``output_activation_fn()``. With
    ``squeeze_output`` the trailing unit dimension is removed, which requires
    ``out == 1``.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Callable[[], nn.Module] = None,
        squeeze_output: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules += [nn.Linear(fan_in, fan_out), activation_fn()]
            if dropout > 0.0:
                modules.append(nn.Dropout(dropout))

        modules.append(nn.Linear(dims[-2], dims[-1]))
        if output_activation_fn is not None:
            modules.append(output_activation_fn())
        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy with a tanh-bounded mean and state-independent std.

    ``forward`` returns a ``Normal`` over the unscaled action. ``act`` scales a
    sample (train mode) or the mean (eval mode) by ``max_action`` and clips the
    result to ``[-max_action, max_action]``.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        layer_dims = [state_dim] + [hidden_dim] * n_hidden + [act_dim]
        self.net = MLP(layer_dims, output_activation_fn=nn.Tanh, dropout=dropout)
        # Learned log-std, shared across all states.
        self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32))
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> Normal:
        mean = self.net(obs)
        bounded_log_std = torch.clamp(self.log_std, LOG_STD_MIN, LOG_STD_MAX)
        return Normal(mean, bounded_log_std.exp())

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        obs = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        dist = self(obs)
        # Explore by sampling in train mode; act greedily (mean) in eval mode.
        raw_action = dist.sample() if self.training else dist.mean
        scaled = torch.clamp(self.max_action * raw_action, -self.max_action, self.max_action)
        return scaled.cpu().data.numpy().flatten()
class DeterministicPolicy(nn.Module):
    """Deterministic policy: a tanh-bounded MLP mapping states to unscaled actions.

    ``act`` scales the output by ``max_action`` and clips it to
    ``[-max_action, max_action]``.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        layer_dims = [state_dim] + [hidden_dim] * n_hidden + [act_dim]
        self.net = MLP(layer_dims, output_activation_fn=nn.Tanh, dropout=dropout)
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        obs = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        scaled = self(obs) * self.max_action
        bounded = torch.clamp(scaled, -self.max_action, self.max_action)
        return bounded.cpu().data.numpy().flatten()
class TwinQ(nn.Module):
    """Two independent Q(s, a) heads; ``forward`` returns their element-wise minimum."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2
    ):
        super().__init__()
        layer_dims = [state_dim + action_dim] + [hidden_dim] * n_hidden + [1]
        self.q1 = MLP(layer_dims, squeeze_output=True)
        self.q2 = MLP(layer_dims, squeeze_output=True)

    def both(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both heads on the concatenated (state, action) input."""
        state_action = torch.cat([state, action], dim=1)
        first = self.q1(state_action)
        second = self.q2(state_action)
        return first, second

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        first, second = self.both(state, action)
        return torch.min(first, second)
class ValueFunction(nn.Module):
    """State-value network V(s) whose output has its trailing unit dim squeezed."""

    def __init__(self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2):
        super().__init__()
        layer_dims = [state_dim] + [hidden_dim] * n_hidden + [1]
        self.v = MLP(layer_dims, squeeze_output=True)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.v(state)
class ImplicitQLearning:
    """Implicit Q-Learning trainer (arXiv:2110.06169).

    Each ``train`` call performs, on one batch:

    1. An expectile-regression step for V(s) toward the target network's Q(s, a).
    2. A TD step for both Q heads toward
       ``r + discount * (1 - done) * V(s')``, followed by a Polyak update of
       the target network.
    3. An advantage-weighted behavioral-cloning step for the actor, followed by
       one cosine LR-scheduler step.

    The target Q network is a gradient-free copy of ``q_network`` on ``device``.
    """

    def __init__(
        self,
        max_action: float,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        q_network: nn.Module,
        q_optimizer: torch.optim.Optimizer,
        v_network: nn.Module,
        v_optimizer: torch.optim.Optimizer,
        iql_tau: float = 0.7,
        beta: float = 3.0,
        max_steps: int = 1000000,
        discount: float = 0.99,
        tau: float = 0.005,
        device: str = "cpu",
    ):
        self.max_action = max_action
        self.qf = q_network
        self.q_target = self._frozen_copy(self.qf, device)
        self.vf = v_network
        self.actor = actor
        self.v_optimizer = v_optimizer
        self.q_optimizer = q_optimizer
        self.actor_optimizer = actor_optimizer
        # T_max = max_steps scheduler steps; one step is taken per actor update.
        self.actor_lr_schedule = CosineAnnealingLR(self.actor_optimizer, max_steps)
        self.iql_tau = iql_tau
        self.beta = beta
        self.discount = discount
        self.tau = tau

        self.total_it = 0
        self.device = device

    @staticmethod
    def _frozen_copy(module: nn.Module, device: str) -> nn.Module:
        """Deep-copy ``module``, disable its gradients and move it to ``device``."""
        return copy.deepcopy(module).requires_grad_(False).to(device)

    def _update_v(self, observations, actions, log_dict) -> torch.Tensor:
        """Take one expectile-regression step on V; return advantages Q_target - V."""
        with torch.no_grad():
            target_q = self.q_target(observations, actions)

        v = self.vf(observations)
        adv = target_q - v
        v_loss = asymmetric_l2_loss(adv, self.iql_tau)
        log_dict["value_loss"] = v_loss.item()
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()
        return adv

    def _update_q(
        self,
        next_v: torch.Tensor,
        observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
        log_dict: Dict,
    ):
        """Regress both Q heads onto the one-step TD target built from V(s')."""
        targets = rewards + (1.0 - terminals.float()) * self.discount * next_v.detach()
        qs = self.qf.both(observations, actions)
        # Average the per-head MSE so the loss scale does not depend on head count.
        q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs)
        log_dict["q_loss"] = q_loss.item()
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # Update target Q network
        soft_update(self.q_target, self.qf, self.tau)

    def _update_policy(
        self,
        adv: torch.Tensor,
        observations: torch.Tensor,
        actions: torch.Tensor,
        log_dict: Dict,
    ):
        """Advantage-weighted BC step: weights exp(beta * adv), capped at EXP_ADV_MAX."""
        exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX)
        policy_out = self.actor(observations)
        if isinstance(policy_out, torch.distributions.Distribution):
            # Stochastic actor: negative log-likelihood summed over action dims.
            bc_losses = -policy_out.log_prob(actions).sum(-1, keepdim=False)
        elif torch.is_tensor(policy_out):
            # Deterministic actor: squared error summed over action dims.
            if policy_out.shape != actions.shape:
                raise RuntimeError("Actions shape missmatch")
            bc_losses = torch.sum((policy_out - actions) ** 2, dim=1)
        else:
            raise NotImplementedError
        policy_loss = torch.mean(exp_adv * bc_losses)
        log_dict["actor_loss"] = policy_loss.item()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
        self.actor_lr_schedule.step()

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one IQL update on ``batch`` and return the logged losses."""
        self.total_it += 1
        (
            observations,
            actions,
            rewards,
            next_observations,
            dones,
        ) = batch
        log_dict = {}

        # V(s') is computed before V is updated in this step.
        with torch.no_grad():
            next_v = self.vf(next_observations)
        # Update value function
        adv = self._update_v(observations, actions, log_dict)
        # Rewards/dones are stored as (batch, 1); Q heads output (batch,).
        rewards = rewards.squeeze(dim=-1)
        dones = dones.squeeze(dim=-1)
        # Update Q function
        self._update_q(next_v, observations, actions, rewards, dones, log_dict)
        # Update actor
        self._update_policy(adv, observations, actions, log_dict)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Serialize networks, optimizers, scheduler, target network and step counter."""
        return {
            "qf": self.qf.state_dict(),
            "q_target": self.q_target.state_dict(),
            "q_optimizer": self.q_optimizer.state_dict(),
            "vf": self.vf.state_dict(),
            "v_optimizer": self.v_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "actor_lr_schedule": self.actor_lr_schedule.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from ``state_dict()`` output.

        Checkpoints that lack a ``q_target`` entry are still accepted; for them
        the target network is re-initialised as a copy of the loaded Q network.
        """
        self.qf.load_state_dict(state_dict["qf"])
        self.q_optimizer.load_state_dict(state_dict["q_optimizer"])
        # Keep the target frozen and on the trainer's device, as in __init__.
        self.q_target = self._frozen_copy(self.qf, self.device)
        if "q_target" in state_dict:
            self.q_target.load_state_dict(state_dict["q_target"])

        self.vf.load_state_dict(state_dict["vf"])
        self.v_optimizer.load_state_dict(state_dict["v_optimizer"])

        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.actor_lr_schedule.load_state_dict(state_dict["actor_lr_schedule"])

        self.total_it = state_dict["total_it"]
def _explore_action(
    actor: nn.Module, state: np.ndarray, config: TrainConfig, max_action: float
) -> np.ndarray:
    """Pick an exploratory action for one online environment step.

    A stochastic actor's distribution is sampled. A deterministic actor's output
    is perturbed with Gaussian noise (std ``expl_noise``, clipped to
    ``±noise_clip``). The result is scaled by ``max_action`` and clipped to
    ``[-max_action, max_action]``. Runs without autograd, since the action is
    only fed to the environment.
    """
    with torch.no_grad():
        obs = torch.tensor(
            state.reshape(1, -1), device=config.device, dtype=torch.float32
        )
        action = actor(obs)
        if not config.iql_deterministic:
            action = action.sample()
        else:
            noise = (torch.randn_like(action) * config.expl_noise).clamp(
                -config.noise_clip, config.noise_clip
            )
            action += noise
        action = torch.clamp(max_action * action, -max_action, max_action)
    return action.cpu().data.numpy().flatten()


@pyrallis.wrap()
def train(config: TrainConfig):
    """Run offline IQL pretraining followed by online fine-tuning.

    The first ``config.offline_iterations`` updates sample only from the D4RL
    dataset. During the next ``config.online_iterations`` steps, the actor also
    interacts with ``env`` and appends transitions to the same replay buffer.
    Metrics go to Weights & Biases. Every ``config.eval_freq`` steps, the actor
    is evaluated on a separate env seeded with ``config.eval_seed``, and a
    checkpoint is written when ``config.checkpoints_path`` is set.
    """
    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)
    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
        if config.iql_deterministic
        else GaussianPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # IQL
        "beta": config.beta,
        "iql_tau": config.iql_tau,
        "max_steps": config.offline_iterations,
    }

    print("---------------------------------------")
    print(f"Training IQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ImplicitQLearning(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        trainer.load_state_dict(torch.load(policy_file))
        actor = trainer.actor

    wandb_init(asdict(config))

    evaluations = []

    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    print("Offline pretraining")
    for t in range(int(config.offline_iterations) + int(config.online_iterations)):
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            action = _explore_action(actor, state, config, max_action)

            next_state, reward, done, env_infos = env.step(action)
            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)
            episode_return += reward

            # A time-limit truncation is not a true terminal: only episodes that
            # end before max_steps are stored as done.
            real_done = False
            if done and episode_step < max_steps:
                real_done = True

            if config.normalize_reward:
                reward = modify_reward_online(reward, config.env, **reward_mod_dict)

            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state
            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = eval_env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        wandb.log(log_dict, step=trainer.total_it)

        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            # Evaluate with the dedicated eval seed (as used to seed eval_env above),
            # not the training seed.
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.eval_seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(eval_score)
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                # NOTE(review): "eval/regret" is computed from *training* episode
                # successes (online regret); eval_successes is collected but not
                # used here - confirm this is the intended metric.
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate
            normalized_eval_score = normalized * 100.0
            evaluations.append(normalized_eval_score)
            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")
            if config.checkpoints_path is not None:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )
            wandb.log(eval_log, step=trainer.total_it)
if __name__ == "__main__":
    # The pyrallis.wrap() decorator on train builds the TrainConfig from the command line.
    train()