From 58a3ed757b64fda209e7fc41041bbe96bfc7cbaa Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Mon, 6 Jan 2025 23:41:46 -0800 Subject: [PATCH] Add support for pinning tensors to device memory in XLA. When a tensor is pinned to device memory it will not be prefetched to alternate memory (or assigned in alternate memory altogether which is possible when it is not pinned). PiperOrigin-RevId: 712789403 --- ...emory_placement_to_internal_annotations.cc | 8 ++- ..._placement_to_internal_annotations_test.cc | 30 ++++++++++ .../service/host_memory_offload_annotations.h | 2 + .../xla/service/memory_space_assignment/BUILD | 10 ++-- .../memory_space_assignment_test.cc | 58 +++++++++++++++++-- .../memory_space_assignment_test_base.h | 1 + 6 files changed, 99 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc index 3c7c89a54cabcb..570afa9e3d501b 100644 --- a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc @@ -48,6 +48,10 @@ absl::StatusOr GetCustomCallTarget( host_memory_offload_annotations::kMemoryTargetDeviceSram) { return host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget; } + if (external_annotation == + host_memory_offload_annotations::kMemoryTargetPinnedDevice) { + return host_memory_offload_annotations::kPinToDeviceCustomCallTarget; + } return absl::InvalidArgumentError( absl::StrCat("Invalid external annotation: ", external_annotation)); } @@ -68,7 +72,9 @@ ConvertCustomCallWithExternalAnnotationToInternalAnnotation( host_memory_offload_annotations::kMemoryTargetUnpinnedHost); const bool is_to_device_case = (it->second == host_memory_offload_annotations::kMemoryTargetDevice || - it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram); + it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram || + it->second == + host_memory_offload_annotations::kMemoryTargetPinnedDevice); if (!is_to_host_case && !is_to_device_case) { return false; } diff --git a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc index db122ae9db5ed1..dab4d055d8f252 100644 --- a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc @@ -540,5 +540,35 @@ TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, EXPECT_EQ(pin_todevice_sramcount, 1); } +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertPinToDeviceTest) { + constexpr absl::string_view hlo_string = R"( + HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true} + + ENTRY main.8 { + Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"} + constant.2 = s32[] constant(2) + broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={} + multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707} + custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="pinned_device"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + ROOT multiply.7 = s32[8,2]{1,0} multiply(custom-call.6, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=709} + } // main.8 )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t pin_todevice_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + pin_todevice_count += instr->IsCustomCall( + host_memory_offload_annotations::kPinToDeviceCustomCallTarget); + } + } + EXPECT_EQ(pin_todevice_count, 1); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/host_memory_offload_annotations.h b/third_party/xla/xla/service/host_memory_offload_annotations.h index 42cde9221f5aac..e230fdc8b60764 100644 --- a/third_party/xla/xla/service/host_memory_offload_annotations.h +++ b/third_party/xla/xla/service/host_memory_offload_annotations.h @@ -27,10 +27,12 @@ inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host"; inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host"; inline const absl::string_view kMemoryTargetDevice = "device"; inline const absl::string_view kMemoryTargetDeviceSram = "device_sram"; +inline const absl::string_view kMemoryTargetPinnedDevice = "pinned_device"; // Internal annotations: inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost"; inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice"; +inline const absl::string_view kPinToDeviceCustomCallTarget = "PinToDevice"; inline const absl::string_view kPinToDeviceSramCustomCallTarget = "PinToDeviceSram"; diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index db4da1675ad86e..b3a9e47389df87 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -90,6 +90,7 @@ xla_cc_test( ":repacking", ":slice", ":testing_utils", + ":utils", "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", @@ -98,6 +99,7 @@ xla_cc_test( "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_buffer", @@ -108,6 +110,10 @@ xla_cc_test( "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -118,12 +124,8 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index b631b162a7d555..badefa8a0a3951 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -54,6 +53,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" @@ -75,18 +75,19 @@ limitations under the License. #include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/memory_space_assignment/slice.h" #include "xla/service/memory_space_assignment/testing_utils.h" +#include "xla/service/memory_space_assignment/utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace memory_space_assignment { @@ -237,6 +238,53 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) { EXPECT_THAT(sequence.instructions()[10], op::CopyDone()); } +TEST_F(MemorySpaceAssignmentTest, PinnedDefaultMemorySpace) { + absl::string_view hlo_string = R"( + HloModule NegateChain, is_scheduled=true, entry_computation_layout={(f32[2,3]{1,0}, f32[2,3]{1,0:S(2)})->f32[2,3]{1,0}} + + ENTRY %NegateChain (p0: f32[2,3], p1: f32[2,3]) -> f32[2,3] { + %p0 = f32[2,3]{1,0} parameter(0) + %p1 = f32[2,3]{1,0:S(2)} parameter(1) + %negate = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %p0) + %negate.1 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate) + %negate.2 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.1) + %negate.3 = f32[2,3]{1,0} negate(f32[2,3]{1,0:S(2)} %negate.2) + %negate.4 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %negate.3) + %negate.5 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.4) + %negate.6 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.5) + ROOT %add = f32[2,3]{1,0} add(f32[2,3]{1,0:S(2)} %negate.6, f32[2,3]{1,0:S(2)} %p1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + HloInstruction* p1 = FindInstruction(module.get(), "p1"); + HloInstruction* negate = FindInstruction(module.get(), "negate"); + HloInstruction* negate_1 = FindInstruction(module.get(), "negate.1"); + HloInstruction* negate_2 = FindInstruction(module.get(), "negate.2"); + HloInstruction* negate_3 = FindInstruction(module.get(), "negate.3"); + HloInstruction* negate_4 = FindInstruction(module.get(), "negate.4"); + HloInstruction* negate_5 = FindInstruction(module.get(), "negate.5"); + HloInstruction* negate_6 = FindInstruction(module.get(), "negate.6"); + HloInstruction* add = FindInstruction(module.get(), "add"); + std::vector pinned_hbm_instructions = { + p1, negate, negate_1, negate_2, negate_4, negate_5, negate_6}; + for (const HloInstruction* instruction : pinned_hbm_instructions) { + EXPECT_EQ(instruction->shape().layout().memory_space(), + kPinnedDefaultMemorySpace); + } + // Check p0 and add are in the default memory space. + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + EXPECT_EQ(add->shape().layout().memory_space(), kDefaultMemorySpace); + // Check negate_3 is in pinned to alternate memory space. + EXPECT_EQ(negate_3->shape().layout().memory_space(), kAlternateMemorySpace); + // Check that p1 is only used once at the add instruction. ie, the there is no + // copy/prefetch. + CHECK_EQ(p1->users().size(), 1); + EXPECT_EQ(p1->users()[0], add); +} + // A simple case where the synchronous copy is actually redundant, because its // operand ends up getting prefetched and the its output is only used once, so // we remove the sync copy. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h index c798572b2d9109..c81035e25dc954 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h @@ -89,6 +89,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { // and large) and alternate (fast and small) memory spaces. const int64_t kDefaultMemorySpace = 0; const int64_t kAlternateMemorySpace = 1; + const int64_t kPinnedDefaultMemorySpace = 2; static HloCostAnalysis::Options DefaultHloCostAnalysisOptions() { HloCostAnalysis::Options options;