Skip to content

Commit

Permalink
Replace one last use of ScopedDeviceMemory in buffer_comparator_test …
Browse files Browse the repository at this point in the history
…with DeviceHandle.

PiperOrigin-RevId: 630128405
  • Loading branch information
klucke authored and tensorflower-gardener committed May 2, 2024
1 parent 00652e6 commit 5bd8a01
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions third_party/xla/xla/service/gpu/buffer_comparator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,20 +351,20 @@ TEST_F(BufferComparatorTest, BF16) {

auto stream = stream_exec_->CreateStream().value();

se::ScopedDeviceMemory<Eigen::bfloat16> lhs(
se::DeviceMemoryHandle lhs(
stream_exec_,
stream_exec_->AllocateArray<Eigen::bfloat16>(element_count));
InitializeBuffer(stream.get(), BF16, &rng_state, *lhs.ptr());
InitializeBuffer(stream.get(), BF16, &rng_state, lhs.memory());

se::ScopedDeviceMemory<Eigen::bfloat16> rhs(
se::DeviceMemoryHandle rhs(
stream_exec_,
stream_exec_->AllocateArray<Eigen::bfloat16>(element_count));
InitializeBuffer(stream.get(), BF16, &rng_state, *rhs.ptr());
InitializeBuffer(stream.get(), BF16, &rng_state, rhs.memory());

BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}),
HloModuleConfig());
EXPECT_FALSE(
comparator.CompareEqual(stream.get(), *lhs.ptr(), *rhs.ptr()).value());
EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs.memory(), rhs.memory())
.value());
}

} // namespace
Expand Down

0 comments on commit 5bd8a01

Please sign in to comment.