Skip to content

Commit

Permalink
Fix #811 (#817)
Browse files: browse the repository at this point in the history
  • Loading branch information
sunkafei authored Mar 4, 2023
1 parent c8be85b commit bc222e8
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ exclude =
dist
*.egg-info
max-line-length = 87
ignore = B305,W504,B006,B008,B024,W503
ignore = B305,W504,B006,B008,B024,W503,B028

[yapf]
based_on_style = pep8
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@ def get_install_requires() -> str:
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"protobuf~=3.19.0", # breaking change, sphinx fail
"packaging",
]


def get_extras_require() -> str:
req = {
"dev": [
"sphinx<4",
"sphinx",
"sphinx_rtd_theme",
"jinja2<3.1", # temporary fix
"jinja2",
"sphinxcontrib-bibtex",
"flake8",
"flake8-bugbear",
Expand Down
32 changes: 32 additions & 0 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,38 @@ def compute_reward_fn(ag, g):
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)

# Another test case for cycled indices
env_size = 99
bufsize = 15
env = MyGoalEnv(env_size, array_state=False)
buf = HERReplayBuffer(
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
)
buf.future_p = 1
for x, ep_len in enumerate([10, 20]):
obs, _ = env.reset()
for i in range(ep_len):
act = 1
obs_next, rew, terminated, truncated, info = env.step(act)
batch = Batch(
obs=obs,
act=[act],
rew=rew,
terminated=(i == ep_len - 1),
truncated=(i == ep_len - 1),
obs_next=obs_next,
info=info
)
if x == 1 and obs["observation"] < 10:
obs = obs_next
continue
buf.add(batch)
obs = obs_next
buf._restore_cache()
sample_indices = np.array([10]) # Suppose the sampled indices are [10]
buf.rewrite_transitions(sample_indices)
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]


def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
Expand Down
7 changes: 4 additions & 3 deletions tianshou/data/buffer/her.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
# Calculate future timestep to use
current = indices[0]
terminal = indices[-1]
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
future_offset = future_offset.astype(int)
future_t = (current + future_offset)
episodes_len = (terminal - current + self.maxsize) % self.maxsize
future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
future_offset = np.round(future_offset).astype(int)
future_t = (current + future_offset) % self.maxsize

# Compute indices
# open indices are used to find longest, unique trajectories among
Expand Down

0 comments on commit bc222e8

Please sign in to comment.