From 69f26cf6a03e9e7fc4c45c138dc3917cc3aa36f8 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 29 Nov 2023 12:40:57 -0800
Subject: [PATCH] [XLA:CPU] Add a direct implementation of ReduceScatter,
 instead of lowering ReduceScatter to AllReduce+DynamicSlice.

PiperOrigin-RevId: 586424242
---
 xla/service/cpu/BUILD                     |   3 +-
 xla/service/cpu/collectives_interface.h   |   6 +
 xla/service/cpu/cpu_compiler.cc           |   2 -
 xla/service/cpu/cpu_layout_assignment.cc  |   7 +
 xla/service/cpu/cpu_runtime.cc            |  77 +++++-
 xla/service/cpu/cpu_runtime.h             |   8 +
 xla/service/cpu/in_process_collectives.cc | 280 ++++++++++++++++------
 xla/service/cpu/in_process_collectives.h  |   6 +
 xla/service/cpu/ir_emitter.cc             |  98 ++++++--
 xla/service/cpu/simple_orc_jit.cc         |   1 +
 10 files changed, 381 insertions(+), 107 deletions(-)

diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD
index 5f5ad2c218641..f3a1573d6ae49 100644
--- a/xla/service/cpu/BUILD
+++ b/xla/service/cpu/BUILD
@@ -300,7 +300,6 @@ cc_library(
         "//xla/service:optimization_barrier_expander",
         "//xla/service:qr_expander",
         "//xla/service:reduce_decomposer",
-        "//xla/service:reduce_scatter_decomposer",
         "//xla/service:reshape_decomposer",
         "//xla/service:reshape_mover",
         "//xla/service:result_caster",
@@ -877,6 +876,7 @@ cc_library(
         "//xla:shape_util",
         "//xla:statusor",
         "//xla:types",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/service:collective_ops_utils",
         "//xla/service:computation_placer",
@@ -886,6 +886,7 @@ cc_library(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h
index 4191df1d831fa..e2c8190c81b98 100644
--- a/xla/service/cpu/collectives_interface.h
+++ b/xla/service/cpu/collectives_interface.h
@@ -67,6 +67,12 @@ class CollectivesCommunicator {
   virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
                                  const void* input_buffer, void* output_buffer,
                                  absl::Duration timeout) = 0;
+
+  // Performs a reduce-scatter.
+  virtual absl::Status ReduceScatter(
+      const RendezvousKey& key, ReductionKind reduction_kind,
+      PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
+      void* output_buffer, absl::Duration timeout) = 0;
 };
 
 class CollectivesInterface {
diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc
index e2723c77c18a7..598281a271924 100644
--- a/xla/service/cpu/cpu_compiler.cc
+++ b/xla/service/cpu/cpu_compiler.cc
@@ -191,7 +191,6 @@ limitations under the License.
 #include "xla/service/optimization_barrier_expander.h"
 #include "xla/service/qr_expander.h"
 #include "xla/service/reduce_decomposer.h"
-#include "xla/service/reduce_scatter_decomposer.h"
 #include "xla/service/reshape_decomposer.h"
 #include "xla/service/reshape_mover.h"
 #include "xla/service/result_caster.h"
@@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<...>();
   pipeline.AddPass<...>();
   pipeline.AddPass<...>();
-  pipeline.AddPass<ReduceScatterDecomposer>();
   pipeline.AddPass<...>();
 
   // Inline computations with a single call site.
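Note: reduce-scatter element-wise reduces the participants' inputs and leaves
each participant holding only its own contiguous shard of the result. Removing
the ReduceScatterDecomposer (which rewrote the op as AllReduce followed by
DynamicSlice) therefore avoids both the full-size intermediate buffer and the
redundant reduction work on elements each rank immediately discards. A minimal
single-process sketch of the semantics with a SUM reduction — hypothetical
names, not code from this patch:

    #include <cstddef>
    #include <vector>

    // Participant `rank` receives chunk `rank` of the element-wise sum of all
    // inputs. Every input must hold at least (rank + 1) * chunk_elems elements.
    std::vector<float> ReduceScatterSum(
        const std::vector<std::vector<float>>& inputs, size_t rank,
        size_t chunk_elems) {
      std::vector<float> out(chunk_elems, 0.0f);
      size_t offset = rank * chunk_elems;  // each rank owns one contiguous chunk
      for (const std::vector<float>& input : inputs) {
        for (size_t i = 0; i < chunk_elems; ++i) {
          out[i] += input[offset + i];
        }
      }
      return out;
    }

For example, with two participants each contributing eight floats and
chunk_elems = 4, rank 0 receives the sums of elements 0-3 and rank 1 the sums
of elements 4-7; the old AllReduce+DynamicSlice lowering computed the full
eight-element sum on every rank first.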
diff --git a/xla/service/cpu/cpu_layout_assignment.cc b/xla/service/cpu/cpu_layout_assignment.cc
index 9b2b8331e2d3f..71d7073dd5304 100644
--- a/xla/service/cpu/cpu_layout_assignment.cc
+++ b/xla/service/cpu/cpu_layout_assignment.cc
@@ -129,6 +129,13 @@ Status CpuLayoutAssignment::AddBackendConstraints(
       const HloInstruction* op = instruction->operand(*op_idx);
       TF_RETURN_IF_ERROR(
           SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx));
+    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
+      // XLA:CPU can only support reduce-scatter where the scatter dimension
+      // is the most major dimension in the layout.
+      auto ars = Cast<HloReduceScatterInstruction>(instruction);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(
+          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
+          ars));
     } else if (instruction->opcode() == HloOpcode::kAllGather) {
       // XLA:CPU can only support all-gathers where the gather dimension is the
       // most major dimension in the layout.
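Note: pinning the scatter dimension to the most major position makes each
rank's shard a single contiguous byte range, so the runtime can address shard
r with one offset instead of a strided copy. Illustrative arithmetic only; the
helper name is hypothetical:

    #include <cstdint>

    // Start of rank r's shard in a buffer whose major-most dimension is the
    // scatter dimension: shards are laid out back to back.
    int64_t ShardByteOffset(int64_t rank, int64_t chunk_elems,
                            int64_t bytes_per_elem) {
      return rank * chunk_elems * bytes_per_elem;
    }

This matches the `chunk_offset = me.local_rank * me.chunk_elems *
bytes_per_elem` computation in the in-process rendezvous later in this patch.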
diff --git a/xla/service/cpu/cpu_runtime.cc b/xla/service/cpu/cpu_runtime.cc
index 81bd76652a3ac..185ea948d2b19 100644
--- a/xla/service/cpu/cpu_runtime.cc
+++ b/xla/service/cpu/cpu_runtime.cc
@@ -29,6 +29,9 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/base/attributes.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
@@ -46,6 +49,7 @@ limitations under the License.
 #include "xla/statusor.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/status.h"
@@ -143,6 +147,8 @@ extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
 extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather";
+extern const char* const kReduceScatterSymbolName =
+    "__xla_cpu_runtime_ReduceScatter";
 extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
 extern const char* const kCollectivePermuteSymbolName =
     "__xla_cpu_runtime_CollectivePermute";
@@ -315,6 +321,19 @@ CollectivesInterface* GetInProcessCollectivesImpl() {
 
 absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); }
 
+absl::StatusOr<int> RankInGlobalDevices(
+    absl::Span<GlobalDeviceId const> devices, GlobalDeviceId device) {
+  auto it = absl::c_find(devices, device);
+  if (it == devices.end()) {
+    return InvalidArgument(
+        "Device %d not present in global devices %s.", device.value(),
+        absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) {
+          absl::StrAppend(out, id.value());
+        }));
+  }
+  return std::distance(devices.begin(), it);
+}
+
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
 void AllToAllImpl(const ExecutableRunOptions* run_options,
                   int32_t channel_id_present, int64_t op_id,
@@ -331,9 +350,7 @@ void AllToAllImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, group, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);
 
-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
 
   CollectivesInterface* collectives = GetInProcessCollectivesImpl();
 
@@ -361,9 +378,7 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, group, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);
 
-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
 
   CollectivesInterface* collectives = GetInProcessCollectivesImpl();
 
@@ -374,6 +389,35 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
                                      DefaultCollectiveTimeout()));
 }
 
+ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
+void ReduceScatterImpl(const ExecutableRunOptions* run_options,
+                       const void* replica_groups_str,
+                       int32_t replica_groups_str_size,
+                       int32_t channel_id_present, int64_t op_id,
+                       int32_t reduction_kind, int32_t element_type,
+                       int64_t chunk_elems, void* input_buffer,
+                       void* output_buffer) {
+  GlobalDeviceId device(GetDeviceOrdinal(run_options));
+  std::string_view replica_groups_serialized(
+      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+  std::vector<ReplicaGroup> group =
+      ParseReplicaGroupsOnly(replica_groups_serialized).value();
+  RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, device, group, channel_id_present,
+                       /*use_global_device_ids=*/std::nullopt, op_id);
+
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
+
+  CollectivesInterface* collectives = GetInProcessCollectivesImpl();
+
+  auto communicator =
+      collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
+  TF_CHECK_OK(communicator->ReduceScatter(
+      rendezvous_key, static_cast<ReductionKind>(reduction_kind),
+      static_cast<PrimitiveType>(element_type), chunk_elems, input_buffer,
+      output_buffer, DefaultCollectiveTimeout()));
+}
+
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
 void AllReduceImpl(const ExecutableRunOptions* run_options,
                    const void* replica_groups_str,
@@ -399,9 +443,7 @@ void AllReduceImpl(const ExecutableRunOptions* run_options,
   CHECK((num_buffers > 1 && shape.IsTuple()) ||
         (num_buffers == 1 && LayoutUtil::IsDenseArray(shape)));
 
-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
 
   CollectivesInterface* collectives = GetInProcessCollectivesImpl();
 
@@ -450,9 +492,7 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, {}, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);
 
-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
 
   CollectivesInterface* collectives = GetInProcessCollectivesImpl();
 
@@ -542,6 +582,19 @@ void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options,
       run_options, channel_id_present, op_id, replica_groups_str,
       replica_groups_str_size, buffer_size, source_buffer, destination_buffer);
 }
+
+void __xla_cpu_runtime_ReduceScatter(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, int32_t replica_groups_str_size,
+    int32_t channel_id_present, int64_t op_id, int32_t reduction_kind,
+    int32_t element_type, int64_t chunk_elems, void* input_buffer,
+    void* output_buffer) {
+  return xla::cpu::runtime::ReduceScatterImpl(
+      run_options, replica_groups_str, replica_groups_str_size,
+      channel_id_present, op_id, reduction_kind, element_type, chunk_elems,
+      input_buffer, output_buffer);
+}
+
 void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options,
                                  const void* replica_groups_str,
                                  int32_t replica_groups_str_size,
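Note: RankInGlobalDevices factors out the find-and-distance idiom previously
duplicated in each collective: a participant's rank is simply its position
within the ordered device group. A standalone analogue using plain integers
(std::optional in place of absl::StatusOr; not the patch's API):

    #include <algorithm>
    #include <cstdint>
    #include <optional>
    #include <vector>

    std::optional<int> RankInGroup(const std::vector<int64_t>& devices,
                                   int64_t device) {
      auto it = std::find(devices.begin(), devices.end(), device);
      if (it == devices.end()) return std::nullopt;  // device not in the group
      return static_cast<int>(std::distance(devices.begin(), it));
    }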
diff --git a/xla/service/cpu/cpu_runtime.h b/xla/service/cpu/cpu_runtime.h
index 9429242d5f1b8..dd00571cb2e8d 100644
--- a/xla/service/cpu/cpu_runtime.h
+++ b/xla/service/cpu/cpu_runtime.h
@@ -85,6 +85,7 @@ extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
 extern const char* const kAllToAllSymbolName;
 extern const char* const kAllGatherSymbolName;
+extern const char* const kReduceScatterSymbolName;
 extern const char* const kOneDnnMatMulSymbolName;
 
 // All symbol names for XLA CPU runtime functions need to start with this
@@ -202,6 +203,13 @@ extern void __xla_cpu_runtime_AllGather(
     int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer,
     void* destination_buffer);
 
+void __xla_cpu_runtime_ReduceScatter(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, int32_t replica_groups_str_size,
+    int32_t channel_id_present, int64_t op_id, int32_t reduction_kind,
+    int32_t element_type, int64_t chunk_elems, void* input_buffer,
+    void* output_buffer);
+
 // Write the partition ID into the output buffer.
 extern void __xla_cpu_runtime_PartitionId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);
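Note: this declaration is the contract between the IR emitter and the runtime:
generated code resolves kReduceScatterSymbolName and calls through a pointer
of the matching type. A sketch of that pointer type — the alias name is an
assumption, not part of the patch:

    #include <cstdint>
    #include "xla/executable_run_options.h"

    using ReduceScatterFn = void (*)(
        const xla::ExecutableRunOptions* run_options,
        const void* replica_groups_str, int32_t replica_groups_str_size,
        int32_t channel_id_present, int64_t op_id, int32_t reduction_kind,
        int32_t element_type, int64_t chunk_elems, void* input_buffer,
        void* output_buffer);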
diff --git a/xla/service/cpu/in_process_collectives.cc b/xla/service/cpu/in_process_collectives.cc
index 78eee1aa74f73..ed30082be82e8 100644
--- a/xla/service/cpu/in_process_collectives.cc
+++ b/xla/service/cpu/in_process_collectives.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "xla/service/global_device_id.h"
 #include "xla/status_macros.h"
 #include "xla/util.h"
+#include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
 
 namespace xla {
@@ -93,7 +94,7 @@ template <typename T>
 constexpr bool always_false_v = false;
 
 template <ReductionKind reduction_kind, typename T>
-void Reduce(absl::Span<T> acc, absl::Span<absl::Span<T const> const> inputs) {
+void ReduceHelper(absl::Span<T> acc, absl::Span<T const* const> inputs) {
   // TODO(penporn): make sure this gets vectorized.
   if constexpr (reduction_kind == ReductionKind::SUM) {
     for (size_t j = 0; j < inputs.size(); ++j) {
@@ -124,6 +125,49 @@ void Reduce(absl::Span<T> acc, absl::Span<absl::Span<T const> const> inputs) {
   }
 }
 
+template <PrimitiveType PT>
+absl::Status ReduceScatter(ReductionKind reduction_kind,
+                           absl::Span<const void* const> inputs, void* output,
+                           int64_t num_elems) {
+  using T = typename primitive_util::PrimitiveTypeToNative<PT>::type;
+  T initial_value = GetInitialValue<T>(reduction_kind);
+
+  absl::Span<T> out_chunk =
+      absl::MakeSpan(reinterpret_cast<T*>(output), num_elems);
+  for (int64_t i = 0; i < num_elems; ++i) {
+    out_chunk[i] = initial_value;
+  }
+
+  absl::Span<T const* const> input_chunks(
+      reinterpret_cast<T const* const*>(inputs.data()), inputs.size());
+  switch (reduction_kind) {
+    case ReductionKind::SUM:
+      ReduceHelper<ReductionKind::SUM, T>(out_chunk, input_chunks);
+      break;
+    case ReductionKind::PRODUCT:
+      ReduceHelper<ReductionKind::PRODUCT, T>(out_chunk, input_chunks);
+      break;
+    case ReductionKind::MIN:
+      if constexpr (!is_complex_v<T>) {
+        ReduceHelper<ReductionKind::MIN, T>(out_chunk, input_chunks);
+      } else {
+        return absl::InvalidArgumentError(
+            "Min reductions not supported for complex types");
+      }
+      break;
+    case ReductionKind::MAX:
+      if constexpr (!is_complex_v<T>) {
+        ReduceHelper<ReductionKind::MAX, T>(out_chunk, input_chunks);
+      } else {
+        return absl::InvalidArgumentError(
+            "Max reductions not supported for complex types");
+      }
+      break;
+  }
+
+  return absl::OkStatus();
+}
+
 class CpuAllReduceRendezvous
     : public Rendezvous<AllReduceParticipantData, std::nullptr_t> {
  public:
@@ -146,110 +190,86 @@ class CpuAllReduceRendezvous
       return nullptr;
     }
 
+    auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type);
+    int64_t chunk_offset = start_elem * bytes_per_elem;
+    int64_t chunk_bytes = chunk_elems * bytes_per_elem;
+    void* reduce_output =
+        reinterpret_cast<char*>(me.destination_data) + chunk_offset;
+
+    std::vector<const void*> inputs;
+    inputs.reserve(world_size);
+    for (const auto& p : participants_) {
+      inputs.push_back(reinterpret_cast<const char*>(p->source_data) +
+                       chunk_offset);
+    }
+
     switch (me.primitive_type) {
       case S8:
-        TF_RETURN_IF_ERROR(DoAllReduce<S8>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<S8>(me.reduction_kind, inputs,
+                                             reduce_output, chunk_elems));
         break;
       case PRED:
       case U8:
-        TF_RETURN_IF_ERROR(DoAllReduce<U8>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<U8>(me.reduction_kind, inputs,
+                                             reduce_output, chunk_elems));
         break;
       case S16:
-        TF_RETURN_IF_ERROR(DoAllReduce<S16>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<S16>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case U16:
-        TF_RETURN_IF_ERROR(DoAllReduce<U16>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<U16>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case S32:
-        TF_RETURN_IF_ERROR(DoAllReduce<S32>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<S32>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case U32:
-        TF_RETURN_IF_ERROR(DoAllReduce<U32>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<U32>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case S64:
-        TF_RETURN_IF_ERROR(DoAllReduce<S64>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<S64>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case U64:
-        TF_RETURN_IF_ERROR(DoAllReduce<U64>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<U64>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case F16:
-        TF_RETURN_IF_ERROR(DoAllReduce<F16>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<F16>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
        break;
       case F32:
-        TF_RETURN_IF_ERROR(DoAllReduce<F32>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<F32>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case F64:
-        TF_RETURN_IF_ERROR(DoAllReduce<F64>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<F64>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case C64:
-        TF_RETURN_IF_ERROR(DoAllReduce<C64>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<C64>(me.reduction_kind, inputs,
+                                              reduce_output, chunk_elems));
         break;
       case C128:
-        TF_RETURN_IF_ERROR(DoAllReduce<C128>(me, start_elem, chunk_elems));
+        TF_RETURN_IF_ERROR(ReduceScatter<C128>(me.reduction_kind, inputs,
+                                               reduce_output, chunk_elems));
         break;
       default:
         return absl::UnimplementedError("Unexpected datatype");
     }
 
-    auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type);
-    int64_t chunk_offset = start_elem * bytes_per_elem;
-    int64_t chunk_bytes = chunk_elems * bytes_per_elem;
+    // All-gather the reduced chunks.
     for (const auto& p : participants_) {
       if (p->local_rank != me.local_rank) {
-        std::memcpy(
-            reinterpret_cast<char*>(p->destination_data) + chunk_offset,
-            reinterpret_cast<char*>(me.destination_data) + chunk_offset,
-            chunk_bytes);
+        std::memcpy(reinterpret_cast<char*>(p->destination_data) + chunk_offset,
+                    reduce_output, chunk_bytes);
       }
     }
 
     return nullptr;
   }
-
-  template <PrimitiveType PT>
-  absl::Status DoAllReduce(const AllReduceParticipantData& me,
-                           int64_t start_elem, int64_t num_elems) {
-    using T = typename primitive_util::PrimitiveTypeToNative<PT>::type;
-    T initial_value = GetInitialValue<T>(me.reduction_kind);
-    T* acc = reinterpret_cast<T*>(me.destination_data);
-    for (int64_t i = start_elem; i < start_elem + num_elems; ++i) {
-      acc[i] = initial_value;
-    }
-
-    absl::Span<T> out_chunk = absl::MakeSpan(
-        reinterpret_cast<T*>(me.destination_data) + start_elem, num_elems);
-    std::vector<absl::Span<T const>> inputs;
-    inputs.reserve(participants_.size());
-    for (const auto& p : participants_) {
-      inputs.push_back(absl::Span<T const>(
-          reinterpret_cast<T*>(p->source_data) + start_elem, num_elems));
-    }
-    switch (me.reduction_kind) {
-      case ReductionKind::SUM:
-        Reduce<ReductionKind::SUM, T>(out_chunk, inputs);
-        break;
-      case ReductionKind::PRODUCT:
-        Reduce<ReductionKind::PRODUCT, T>(out_chunk, inputs);
-        break;
-      case ReductionKind::MIN:
-        if constexpr (!is_complex_v<T>) {
-          Reduce<ReductionKind::MIN, T>(out_chunk, inputs);
-        } else {
-          return absl::InvalidArgumentError(
-              "Min reductions not supported for complex types");
-        }
-        break;
-      case ReductionKind::MAX:
-        if constexpr (!is_complex_v<T>) {
-          Reduce<ReductionKind::MAX, T>(out_chunk, inputs);
-        } else {
-          return absl::InvalidArgumentError(
-              "Max reductions not supported for complex types");
-        }
-        break;
-    }
-
-    return absl::OkStatus();
-  }
 };
 
 struct CollectivePermuteParticipantData : ParticipantData {
@@ -378,6 +398,109 @@ class CpuAllGatherRendezvous
   }
 };
 
+struct ReduceScatterParticipantData : ParticipantData {
+  ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+
+  ReductionKind reduction_kind;
+  PrimitiveType element_type;
+  const void* source_buffer;
+  void* destination_buffer;
+  size_t chunk_elems;
+
+  std::string ToString() const override {
+    return absl::StrFormat(
+        "ReduceScatterParticipantData{rank=%d, "
+        "devices=[%s], source_buffer=%p, "
+        "destination_buffer=%p, chunk_elems=%d}",
+        local_rank,
+        absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId),
+        source_buffer, destination_buffer, chunk_elems);
+  }
+};
+
+class CpuReduceScatterRendezvous
+    : public Rendezvous<ReduceScatterParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuReduceScatterRendezvous(const RendezvousKey& k)
+      : Rendezvous<ReduceScatterParticipantData, std::nullptr_t>(k) {}
+
+ protected:
+  CollectivesInterface* collectives_;
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const ReduceScatterParticipantData& me) override {
+    auto bytes_per_elem = primitive_util::ByteWidth(me.element_type);
+    int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem;
+
+    std::vector<const void*> inputs;
+    inputs.reserve(participants_.size());
+    for (const auto& p : participants_) {
+      inputs.push_back(reinterpret_cast<const char*>(p->source_buffer) +
+                       chunk_offset);
+    }
+
+    switch (me.element_type) {
+      case S8:
+        TF_RETURN_IF_ERROR(ReduceScatter<S8>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case PRED:
+      case U8:
+        TF_RETURN_IF_ERROR(ReduceScatter<U8>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case S16:
+        TF_RETURN_IF_ERROR(ReduceScatter<S16>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case U16:
+        TF_RETURN_IF_ERROR(ReduceScatter<U16>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case S32:
+        TF_RETURN_IF_ERROR(ReduceScatter<S32>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case U32:
+        TF_RETURN_IF_ERROR(ReduceScatter<U32>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case S64:
+        TF_RETURN_IF_ERROR(ReduceScatter<S64>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case U64:
+        TF_RETURN_IF_ERROR(ReduceScatter<U64>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case F16:
+        TF_RETURN_IF_ERROR(ReduceScatter<F16>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case F32:
+        TF_RETURN_IF_ERROR(ReduceScatter<F32>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case F64:
+        TF_RETURN_IF_ERROR(ReduceScatter<F64>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case C64:
+        TF_RETURN_IF_ERROR(ReduceScatter<C64>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      case C128:
+        TF_RETURN_IF_ERROR(ReduceScatter<C128>(
+            me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems));
+        break;
+      default:
+        return absl::UnimplementedError("Unexpected datatype");
+    }
+
+    return nullptr;
+  }
+};
+
 }  // namespace
 
 struct InProcessCollectivesState {
@@ -389,6 +512,8 @@ struct InProcessCollectivesState {
       all_to_all_rendezvous_map;
   RefcountingHashMap<RendezvousKey, CpuAllGatherRendezvous>
       all_gather_rendezvous_map;
+  RefcountingHashMap<RendezvousKey, CpuReduceScatterRendezvous>
+      reduce_scatter_rendezvous_map;
 };
 
 InProcessCollectivesCommunicator::InProcessCollectivesCommunicator(
@@ -488,6 +613,27 @@ absl::Status InProcessCollectivesCommunicator::AllGather(
       .status();
 }
 
+absl::Status InProcessCollectivesCommunicator::ReduceScatter(
+    const RendezvousKey& key, ReductionKind reduction_kind,
+    PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
+    void* output_buffer, absl::Duration timeout) {
+  ReduceScatterParticipantData participant(key, rank_);
+  participant.element_type = element_type;
+  participant.reduction_kind = reduction_kind;
+  participant.chunk_elems = chunk_elems;
+  participant.source_buffer = input_buffer;
+  participant.destination_buffer = output_buffer;
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuReduceScatterRendezvous>(k);
+  };
+  return CpuReduceScatterRendezvous::SubmitParticipant(
+             [&] {
+               return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent(
+                   key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
 InProcessCollectives::InProcessCollectives()
     : state_(std::make_unique<InProcessCollectivesState>()) {}
 
 InProcessCollectives::~InProcessCollectives() = default;
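Note: the rendezvous pattern above has every participating thread publish a
pointer to its input, synchronize, and then reduce only its own chunk — the
reduction work is split across ranks rather than repeated by each. A
simplified standalone analogue for float/SUM (thread-per-rank, pre-sized
outputs; a sketch, not the patch's rendezvous machinery):

    #include <barrier>
    #include <cstddef>
    #include <thread>
    #include <vector>

    // Each rank publishes its input pointer, waits until all ranks have done
    // so, then reduces its own chunk from every published input. `outputs`
    // must contain `inputs.size()` vectors of at least chunk_elems elements.
    void InProcessReduceScatterSum(
        const std::vector<std::vector<float>>& inputs,
        std::vector<std::vector<float>>& outputs, std::size_t chunk_elems) {
      const std::size_t world = inputs.size();
      std::vector<const float*> published(world, nullptr);
      std::barrier sync(static_cast<std::ptrdiff_t>(world));
      std::vector<std::thread> threads;
      for (std::size_t rank = 0; rank < world; ++rank) {
        threads.emplace_back([&, rank] {
          published[rank] = inputs[rank].data();  // "submit participant"
          sync.arrive_and_wait();                 // rendezvous point
          const std::size_t offset = rank * chunk_elems;
          for (std::size_t i = 0; i < chunk_elems; ++i) {
            float acc = 0.0f;
            for (std::size_t p = 0; p < world; ++p) {
              acc += published[p][offset + i];
            }
            outputs[rank][i] = acc;
          }
        });
      }
      for (std::thread& t : threads) t.join();
    }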
diff --git a/xla/service/cpu/in_process_collectives.h b/xla/service/cpu/in_process_collectives.h
index aaedc474fa39b..f80baf38c4ebd 100644
--- a/xla/service/cpu/in_process_collectives.h
+++ b/xla/service/cpu/in_process_collectives.h
@@ -59,6 +59,12 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator {
                          const void* input_buffer, void* output_buffer,
                          absl::Duration timeout) override;
 
+  absl::Status ReduceScatter(const RendezvousKey& key,
+                             ReductionKind reduction_kind,
+                             PrimitiveType element_type, size_t chunk_elems,
+                             const void* input_buffer, void* output_buffer,
+                             absl::Duration timeout) override;
+
  private:
   InProcessCollectivesState* state_;
   int rank_;
diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc
index f5d7a4c2c40fa..46ae3978aaa2e 100644
--- a/xla/service/cpu/ir_emitter.cc
+++ b/xla/service/cpu/ir_emitter.cc
@@ -1169,35 +1169,36 @@ Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
   return OkStatus();
 }
 
+// Data types supported by ReduceScatter and AllReduce.
+static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) {
+  // TODO(cheshire): Fix duplication wrt. cpu_runtime
+  switch (datatype) {
+    case PRED:
+    case S8:
+    case U8:
+    case S16:
+    case U16:
+    case S32:
+    case U32:
+    case S64:
+    case U64:
+    case F16:
+    case F32:
+    case F64:
+    case C64:
+    case C128:
+      return true;
+    default:
+      return false;
+  }
+}
+
 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
   CHECK_GE(crs->operand_count(), 1);
   PrimitiveType datatype = crs->operand(0)->shape().element_type();
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
 
-  bool is_datatype_supported = [&] {
-    // TODO(cheshire): Fix duplication wrt. cpu_runtime
-    switch (datatype) {
-      case PRED:
-      case S8:
-      case U8:
-      case S16:
-      case U16:
-      case S32:
-      case U32:
-      case S64:
-      case U64:
-      case F16:
-      case F32:
-      case F64:
-      case C64:
-      case C128:
-        return true;
-      default:
-        return false;
-    }
-  }();
-
-  if (!is_datatype_supported) {
+  if (!DataTypeIsSupportedByReduceScatter(datatype)) {
     return Unimplemented("AllReduce for datatype '%s' is not supported",
                          primitive_util::LowercasePrimitiveTypeName(datatype));
   }
@@ -1285,7 +1286,54 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
 }
 
 Status IrEmitter::HandleReduceScatter(HloInstruction* rs) {
-  return Unimplemented("ReduceScatter is not implemented on CPU.");
+  CHECK_EQ(rs->operand_count(), 1);
+  PrimitiveType datatype = rs->operand(0)->shape().element_type();
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs));
+
+  if (!DataTypeIsSupportedByReduceScatter(datatype)) {
+    return Unimplemented("ReduceScatter for datatype '%s' is not supported",
+                         primitive_util::LowercasePrimitiveTypeName(datatype));
+  }
+
+  if (!MatchReductionComputation(rs->to_apply()).has_value()) {
+    return Unimplemented("ReduceScatter for computation '%s' is not supported",
+                         rs->to_apply()->ToString());
+  }
+
+  std::string replica_groups = ReplicaGroupsToString(rs->replica_groups());
+  int32_t replica_groups_size = replica_groups.size();
+  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
+
+  Shape shape = rs->operand(0)->shape();
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+                      assignment_.GetUniqueSlice(rs->operand(0), {}));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                      assignment_.GetUniqueSlice(rs, {}));
+  llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
+  llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+
+  EmitCallToFunc(
+      runtime::kReduceScatterSymbolName,
+      {/*run_options=*/GetExecutableRunOptionsArgument(),
+       /*replica_groups_str=*/replica_groups_v,
+       /*replica_groups_str_size=*/b_.getInt32(replica_groups_size),
+
+       /*channel_id_present=*/
+       b_.getInt32(static_cast<int32_t>(rs->channel_id().has_value())),
+       /*op_id=*/
+       b_.getInt64(rs->channel_id().has_value() ? *rs->channel_id()
+                                                : rs->GetModule()->unique_id()),
+       /*reduction_kind=*/
+       b_.getInt32(
+           static_cast<int32_t>(*MatchReductionComputation(rs->to_apply()))),
+       /*element_type=*/
+       b_.getInt32(static_cast<int32_t>(datatype)),
+       /*chunk_elems=*/b_.getInt64(ShapeUtil::ElementsIn(rs->shape())),
+       /*input_buffer=*/input_buffer,
+       /*output_buffer=*/output_buffer},
+      b_.getVoidTy());
+
+  return OkStatus();
 }
 
 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
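Note: the value passed as chunk_elems is ElementsIn(rs->shape()) — the element
count of the reduce-scatter's *output*, which is already the per-rank shard —
not the operand's element count. A standalone check of that bookkeeping for a
hypothetical f32[8] reduce-scatter across two replicas:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t operand_elems = 8;  // operand shape f32[8]
      const int64_t num_participants = 2;
      // Output shape is f32[4]: the operand split along the scatter dimension.
      const int64_t chunk_elems = operand_elems / num_participants;
      assert(chunk_elems == 4);
      (void)chunk_elems;
      return 0;
    }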
diff --git a/xla/service/cpu/simple_orc_jit.cc b/xla/service/cpu/simple_orc_jit.cc
index 2e27a7c810869..0cc07a27246f7 100644
--- a/xla/service/cpu/simple_orc_jit.cc
+++ b/xla/service/cpu/simple_orc_jit.cc
@@ -486,6 +486,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
   REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
   REGISTER_CPU_RUNTIME_SYMBOL(AllGather);
+  REGISTER_CPU_RUNTIME_SYMBOL(ReduceScatter);
   REGISTER_CPU_RUNTIME_SYMBOL(PartitionId);
   REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32);