Commit a7c789f

Improve data loading from D4RL and convert RL Unplugged to D4RL format
nuance1979 authored May 3, 2022
1 parent dd16818 commit a7c789f
Showing 11 changed files with 180 additions and 116 deletions.
26 changes: 14 additions & 12 deletions examples/offline/atari_bcq.py
@@ -12,7 +12,8 @@

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from examples.offline.utils import load_buffer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -118,18 +119,19 @@ def test_discrete_bcq(args=get_args()):
     policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
     print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        if args.buffer_from_rl_unplugged:
-            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
-        else:
-            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
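The `load_buffer` helper imported above lives in examples/offline/utils.py, one of the 11 changed files that is not expanded on this page; the same helper is reused by atari_cql.py, atari_crr.py and atari_il.py below. A minimal sketch of what such a loader could look like, assuming the D4RL-style keys written by convert_rl_unplugged_atari.py further down and mirroring the fill loop removed from d4rl_bcq.py in this commit; the committed utils.py may be implemented differently:

    # Hypothetical sketch, not the committed utils.py: load a D4RL-format HDF5
    # shard into a Tianshou ReplayBuffer.
    import h5py

    from tianshou.data import Batch, ReplayBuffer


    def load_buffer(buffer_path: str) -> ReplayBuffer:
        with h5py.File(buffer_path, "r") as dataset:
            size = dataset["rewards"].shape[0]
            buffer = ReplayBuffer(size)
            # Copy one transition at a time using the D4RL key names.
            for i in range(size):
                buffer.add(
                    Batch(
                        obs=dataset["observations"][i],
                        act=dataset["actions"][i],
                        rew=dataset["rewards"][i],
                        done=dataset["terminals"][i],
                        obs_next=dataset["next_observations"][i],
                    )
                )
        return buffer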
24 changes: 16 additions & 8 deletions examples/offline/atari_cql.py
@@ -12,6 +12,7 @@

 from examples.atari.atari_network import QRDQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
@@ -57,6 +58,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -100,15 +104,19 @@ def test_discrete_cql(args=get_args()):
     policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
     print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
24 changes: 16 additions & 8 deletions examples/offline/atari_crr.py
@@ -12,6 +12,7 @@

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
@@ -59,6 +60,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,15 +124,19 @@ def test_discrete_crr(args=get_args()):
     policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
     print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
24 changes: 16 additions & 8 deletions examples/offline/atari_il.py
@@ -12,6 +12,7 @@

 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -50,6 +51,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -85,15 +89,19 @@ def test_il(args=get_args()):
     policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
     print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
     print("Replay buffer size:", len(buffer), flush=True)

     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
61 changes: 38 additions & 23 deletions examples/offline/convert_rl_unplugged_atari.py
@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 #
-"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+"""Convert Atari RL Unplugged datasets to HDF5 format.
 Examples in the dataset represent SARSA transitions stored during a
 DQN training run as described in https://arxiv.org/pdf/1907.04543.
@@ -30,11 +30,13 @@
 import os
 from argparse import ArgumentParser

+import h5py
 import numpy as np
 import requests
 import tensorflow as tf
 from tqdm import tqdm

-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch

 tf.config.set_visible_devices([], 'GPU')

@@ -108,7 +110,7 @@ def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
       pngs: String Tensor of size (4,) containing PNG encoded images.

     Returns:
-      4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+      4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
     """
     # Statically unroll png decoding
     frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
@@ -195,17 +197,30 @@ def download(url: str, fname: str, chunk_size=1024):

 def process_shard(url: str, fname: str, ofname: str) -> None:
     download(url, fname)
+    maxsize = 500000
+    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    act = np.ndarray((maxsize, ), dtype="int64")
+    rew = np.ndarray((maxsize, ), dtype="float32")
+    done = np.ndarray((maxsize, ), dtype="bool")
+    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    i = 0
     file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
-    buffer = ReplayBuffer(500000)
-    cnt = 0
     for example in file_ds:
         batch = _tf_example_to_tianshou_batch(example)
-        buffer.add(batch)
-        cnt += 1
-        if cnt % 1000 == 0:
-            print(f"...{cnt}", end="", flush=True)
-    print("\nReplayBuffer size:", len(buffer))
-    buffer.save_hdf5(ofname, compression="gzip")
+        obs[i], act[i], rew[i], done[i], obs_next[i] = (
+            batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
+        )
+        i += 1
+        if i % 1000 == 0:
+            print(f"...{i}", end="", flush=True)
+    print("\nDataset size:", i)
+    # Following D4RL dataset naming conventions
+    with h5py.File(ofname, "w") as f:
+        f.create_dataset("observations", data=obs, compression="gzip")
+        f.create_dataset("actions", data=act, compression="gzip")
+        f.create_dataset("rewards", data=rew, compression="gzip")
+        f.create_dataset("terminals", data=done, compression="gzip")
+        f.create_dataset("next_observations", data=obs_next, compression="gzip")


 def process_dataset(
@@ -227,19 +242,19 @@ def main(args):
     if args.task not in ALL_GAMES:
         raise KeyError(f"`{args.task}` is not in the list of games.")
     fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
-    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
-    if os.path.exists(buffer_path):
-        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(dataset_path):
+        raise IOError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
+    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
     args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
-    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
-    dataset_path = os.path.join(args.dataset_dir, args.task)
-    os.makedirs(dataset_path, exist_ok=True)
-    dst_path = os.path.join(args.buffer_dir, args.task)
+    cache_path = os.path.join(args.cache_dir, args.task)
+    os.makedirs(cache_path, exist_ok=True)
+    dst_path = os.path.join(args.dataset_dir, args.task)
     os.makedirs(dst_path, exist_ok=True)
     process_dataset(
         args.task,
+        args.cache_dir,
         args.dataset_dir,
-        args.buffer_dir,
         run_id=args.run_id,
         shard_id=args.shard_id,
         total_num_shards=args.total_num_shards
@@ -267,12 +282,12 @@ def main(args):
     parser.add_argument(
         "--dataset-dir",
        default=os.path.expanduser("~/.rl_unplugged/datasets"),
-        help="Directory for downloaded original datasets.",
+        help="Directory for converted hdf5 files.",
     )
     parser.add_argument(
-        "--buffer-dir",
-        default=os.path.expanduser("~/.rl_unplugged/buffers"),
-        help="Directory for converted replay buffers.",
+        "--cache-dir",
+        default=os.path.expanduser("~/.rl_unplugged/cache"),
+        help="Directory for downloaded original datasets.",
     )
     args = parser.parse_args()
     main(args)
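A quick way to check a converted shard, not part of the commit: open the HDF5 file that process_shard writes and inspect the D4RL-style keys and shapes. The path below is only an illustrative placeholder; the real file name comes from _filename and the --dataset-dir layout.

    # Hypothetical sanity check of a converted shard; the path is made up.
    import os

    import h5py

    path = os.path.expanduser("~/.rl_unplugged/datasets/Pong/some_shard.hdf5")
    with h5py.File(path, "r") as f:
        print(sorted(f.keys()))
        # ['actions', 'next_observations', 'observations', 'rewards', 'terminals']
        print(f["observations"].shape)  # (500000, 4, 84, 84) uint8 frame stacks
        print(f["terminals"].dtype)     # bool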
22 changes: 3 additions & 19 deletions examples/offline/d4rl_bcq.py
@@ -5,13 +5,13 @@
 import os
 import pprint

-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import BCQPolicy
 from tianshou.trainer import offline_trainer
@@ -211,23 +211,7 @@ def watch():
         collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
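The inline loop removed here is what load_buffer_d4rl now encapsulates in examples/offline/utils.py (not expanded on this page). A minimal sketch under that assumption, reusing exactly the calls deleted above; the committed helper may be implemented differently:

    # Hypothetical sketch of load_buffer_d4rl; it simply packages the loop
    # removed from d4rl_bcq.py and d4rl_cql.py in this commit.
    import d4rl
    import gym

    from tianshou.data import Batch, ReplayBuffer


    def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
        dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
        dataset_size = dataset["rewards"].size
        buffer = ReplayBuffer(dataset_size)
        for i in range(dataset_size):
            buffer.add(
                Batch(
                    obs=dataset["observations"][i],
                    act=dataset["actions"][i],
                    rew=dataset["rewards"][i],
                    done=dataset["terminals"][i],
                    obs_next=dataset["next_observations"][i],
                )
            )
        return buffer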
22 changes: 3 additions & 19 deletions examples/offline/d4rl_cql.py
@@ -5,13 +5,13 @@
 import os
 import pprint

-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import CQLPolicy
 from tianshou.trainer import offline_trainer
@@ -206,23 +206,7 @@ def watch():
         collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,