From 6a7ad649c72b20d6aca5dfd089cdca9033cc120a Mon Sep 17 00:00:00 2001
From: Vincent Moens <vmoens@meta.com>
Date: Thu, 19 Dec 2024 15:53:46 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/data/replay_buffers/samplers.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index bbdf2387683..2ad0550ed06 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1485,13 +1485,13 @@ def _get_index(
                 truncated[seq_length.cumsum(0) - 1] = 1
             index = index.to(torch.long).unbind(-1)
             st_index = storage[index]
-            try:
-                done = st_index[done_key] | truncated
-            except KeyError:
+            done = st_index.get(done_key, default=None)
+            if done is None:
                 done = truncated.clone()
-            try:
-                terminated = st_index[terminated_key]
-            except KeyError:
+            else:
+                done = done | truncated
+            terminated = st_index.get(terminated_key, default=None)
+            if terminated is None:
                 terminated = torch.zeros_like(truncated)
             return index, {
                 truncated_key: truncated,