Skip to content

Commit

Permalink
Add support for pinning tensors to device memory in XLA. When a tenso…
Browse files Browse the repository at this point in the history
…r is pinned to device memory it will not be prefetched to alternate memory (or assigned in alternate memory altogether which is possible when it is not pinned).

PiperOrigin-RevId: 712789403
  • Loading branch information
subhankarshah authored and tensorflower-gardener committed Jan 7, 2025
1 parent 9378e92 commit 58a3ed7
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ absl::StatusOr<absl::string_view> GetCustomCallTarget(
host_memory_offload_annotations::kMemoryTargetDeviceSram) {
return host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget;
}
if (external_annotation ==
host_memory_offload_annotations::kMemoryTargetPinnedDevice) {
return host_memory_offload_annotations::kPinToDeviceCustomCallTarget;
}
return absl::InvalidArgumentError(
absl::StrCat("Invalid external annotation: ", external_annotation));
}
Expand All @@ -68,7 +72,9 @@ ConvertCustomCallWithExternalAnnotationToInternalAnnotation(
host_memory_offload_annotations::kMemoryTargetUnpinnedHost);
const bool is_to_device_case =
(it->second == host_memory_offload_annotations::kMemoryTargetDevice ||
it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram);
it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram ||
it->second ==
host_memory_offload_annotations::kMemoryTargetPinnedDevice);
if (!is_to_host_case && !is_to_device_case) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,5 +540,35 @@ TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest,
EXPECT_EQ(pin_todevice_sramcount, 1);
}

// Runs ConvertMemoryPlacementToInternalAnnotations on a module containing an
// `annotate_device_placement` custom call whose `_xla_buffer_placement`
// frontend attribute is "pinned_device", and expects the module to change and
// to contain exactly one kPinToDeviceCustomCallTarget custom call afterwards.
TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest,
       ConvertPinToDeviceTest) {
  constexpr absl::string_view hlo_string = R"(
HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY main.8 {
Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"}
constant.2 = s32[] constant(2)
broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={}
multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707}
custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708}
custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="pinned_device"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708}
ROOT multiply.7 = s32[8,2]{1,0} multiply(custom-call.6, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=709}
} // main.8 )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Report a failing pass as a test failure instead of aborting the whole
  // test binary through an unchecked StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      ConvertMemoryPlacementToInternalAnnotations().Run(module.get()));
  EXPECT_TRUE(changed);
  XLA_VLOG_LINES(1, module->ToString());

  // Count the internal pin-to-device custom calls across every computation.
  int64_t pin_todevice_count = 0;
  for (auto* c : module->computations()) {
    for (auto* instr : c->instructions()) {
      if (instr->IsCustomCall(
              host_memory_offload_annotations::kPinToDeviceCustomCallTarget)) {
        ++pin_todevice_count;
      }
    }
  }
  EXPECT_EQ(pin_todevice_count, 1);
}

} // namespace
} // namespace xla
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/host_memory_offload_annotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host";
inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host";
inline const absl::string_view kMemoryTargetDevice = "device";
inline const absl::string_view kMemoryTargetDeviceSram = "device_sram";
inline const absl::string_view kMemoryTargetPinnedDevice = "pinned_device";

// Internal annotations: custom-call target names that the external memory
// target strings above are converted to (for example, "pinned_device" is
// converted to "PinToDevice").
inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost";
inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice";
inline const absl::string_view kPinToDeviceCustomCallTarget = "PinToDevice";
inline const absl::string_view kPinToDeviceSramCustomCallTarget =
"PinToDeviceSram";

Expand Down
10 changes: 6 additions & 4 deletions third_party/xla/xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ xla_cc_test(
":repacking",
":slice",
":testing_utils",
":utils",
"//xla:comparison_util",
"//xla:literal_util",
"//xla:shape_util",
Expand All @@ -98,6 +99,7 @@ xla_cc_test(
"//xla/hlo/analysis:hlo_alias_analysis",
"//xla/hlo/analysis:hlo_dataflow_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/hlo/utils:hlo_live_range",
"//xla/hlo/utils:hlo_matchers",
"//xla/service:hlo_buffer",
Expand All @@ -108,6 +110,10 @@ xla_cc_test(
"//xla/tests:test_utils",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"//xla/tsl/platform:status_matchers",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand All @@ -118,12 +124,8 @@ xla_cc_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -54,6 +53,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/layout_util.h"
Expand All @@ -75,18 +75,19 @@ limitations under the License.
#include "xla/service/memory_space_assignment/repacking.h"
#include "xla/service/memory_space_assignment/slice.h"
#include "xla/service/memory_space_assignment/testing_utils.h"
#include "xla/service/memory_space_assignment/utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/status_matchers.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/protobuf.h" // IWYU pragma: keep
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace memory_space_assignment {
Expand Down Expand Up @@ -237,6 +238,53 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) {
EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}

// Runs memory space assignment on a negate chain in which several
// instructions (and parameter p1) carry an S(2) layout annotation, i.e.
// kPinnedDefaultMemorySpace. The test expects those instructions to keep that
// memory space, and expects p1 to have no user other than `add`.
TEST_F(MemorySpaceAssignmentTest, PinnedDefaultMemorySpace) {
  absl::string_view hlo_string = R"(
HloModule NegateChain, is_scheduled=true, entry_computation_layout={(f32[2,3]{1,0}, f32[2,3]{1,0:S(2)})->f32[2,3]{1,0}}

ENTRY %NegateChain (p0: f32[2,3], p1: f32[2,3]) -> f32[2,3] {
%p0 = f32[2,3]{1,0} parameter(0)
%p1 = f32[2,3]{1,0:S(2)} parameter(1)
%negate = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %p0)
%negate.1 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate)
%negate.2 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.1)
%negate.3 = f32[2,3]{1,0} negate(f32[2,3]{1,0:S(2)} %negate.2)
%negate.4 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %negate.3)
%negate.5 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.4)
%negate.6 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.5)
ROOT %add = f32[2,3]{1,0} add(f32[2,3]{1,0:S(2)} %negate.6, f32[2,3]{1,0:S(2)} %p1)
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  XLA_VLOG_LINES(1, module->ToString());
  HloInstruction* p0 = FindInstruction(module.get(), "p0");
  HloInstruction* p1 = FindInstruction(module.get(), "p1");
  HloInstruction* negate = FindInstruction(module.get(), "negate");
  HloInstruction* negate_1 = FindInstruction(module.get(), "negate.1");
  HloInstruction* negate_2 = FindInstruction(module.get(), "negate.2");
  HloInstruction* negate_3 = FindInstruction(module.get(), "negate.3");
  HloInstruction* negate_4 = FindInstruction(module.get(), "negate.4");
  HloInstruction* negate_5 = FindInstruction(module.get(), "negate.5");
  HloInstruction* negate_6 = FindInstruction(module.get(), "negate.6");
  HloInstruction* add = FindInstruction(module.get(), "add");

  // Stop the test early if a lookup failed, so the checks below never
  // dereference a null instruction pointer.
  const std::vector<const HloInstruction*> all_instructions = {
      p0,       p1,       negate,   negate_1, negate_2,
      negate_3, negate_4, negate_5, negate_6, add};
  for (const HloInstruction* instruction : all_instructions) {
    ASSERT_NE(instruction, nullptr);
  }

  // Every instruction annotated with S(2) in the input must stay in the
  // pinned default memory space.
  std::vector<const HloInstruction*> pinned_hbm_instructions = {
      p1, negate, negate_1, negate_2, negate_4, negate_5, negate_6};
  for (const HloInstruction* instruction : pinned_hbm_instructions) {
    EXPECT_EQ(instruction->shape().layout().memory_space(),
              kPinnedDefaultMemorySpace);
  }
  // Check that p0 and add are in the default memory space.
  EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace);
  EXPECT_EQ(add->shape().layout().memory_space(), kDefaultMemorySpace);
  // negate.3 carries no memory space annotation in the input; check that it
  // is placed in the alternate memory space.
  EXPECT_EQ(negate_3->shape().layout().memory_space(), kAlternateMemorySpace);
  // Check that p1 is only used once, by the add instruction, i.e. there is no
  // copy/prefetch. ASSERT (not CHECK) so a mismatch is reported as a test
  // failure while still guarding the users()[0] access below.
  ASSERT_EQ(p1->users().size(), 1);
  EXPECT_EQ(p1->users()[0], add);
}

// A simple case where the synchronous copy is actually redundant, because its
// operand ends up getting prefetched and its output is only used once, so
// we remove the sync copy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase {
// and large) and alternate (fast and small) memory spaces.
const int64_t kDefaultMemorySpace = 0;
const int64_t kAlternateMemorySpace = 1;
const int64_t kPinnedDefaultMemorySpace = 2;

static HloCostAnalysis::Options DefaultHloCostAnalysisOptions() {
HloCostAnalysis::Options options;
Expand Down

0 comments on commit 58a3ed7

Please sign in to comment.