Skip to content

Commit

Permalink
add full buffer return to CircularBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
jtigue-bdai committed Nov 19, 2024
1 parent a9b338d commit f7ede99
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor:
"""
return torch.minimum(self._num_pushes, self._max_len)

@property
def buffer(self) -> torch.Tensor:
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning.
Returns:
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
"""
buf = self._buffer.clone()
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
return torch.transpose(buf, dim0=0, dim1=1)

"""
Operations.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,32 @@ def test_key_greater_than_pushes(self):
retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)]
self.assertTrue(torch.equal(retrieved_data, data1))

def test_return_buffer_prop(self):
"""Test retrieving the whole buffer for correct size and contents.
Returning the whole buffer should have the shape [batch_size,max_len,data.shape[1:]]
"""
num_overflow = 2
for i in range(self.buffer.max_length + num_overflow):
data = torch.tensor([[i]], device=self.device).repeat(3, 2)
self.buffer.append(data)

retrieved_buffer = self.buffer.buffer
# check shape
self.assertTrue(retrieved_buffer.shape == torch.Size([self.buffer.batch_size, self.buffer.max_length, 2]))
# check that batch is first dimension
torch.testing.assert_close(retrieved_buffer[0], retrieved_buffer[1])
# check oldest
torch.testing.assert_close(
retrieved_buffer[:, 0], torch.tensor([[num_overflow]], device=self.device).repeat(3, 2)
)
# check most recent
torch.testing.assert_close(
retrieved_buffer[:, -1],
torch.tensor([[self.buffer.max_length + num_overflow - 1]], device=self.device).repeat(3, 2),
)
# check that it is returned oldest first
for idx in range(self.buffer.max_length - 1):
self.assertTrue(torch.all(torch.le(retrieved_buffer[:, idx], retrieved_buffer[:, idx + 1])))

if __name__ == "__main__":
run_tests()

0 comments on commit f7ede99

Please sign in to comment.