[xla:gpu] Support emitting memcpy thunks for command buffers tensorflow#6224

PiperOrigin-RevId: 589288499
anlunx authored and tensorflower-gardener committed Dec 9, 2023
1 parent 1cd4b5c commit 6f3a097
Showing 7 changed files with 172 additions and 43 deletions.
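For orientation before the per-file diffs: GetFusionEmitter now takes a std::variant of two small info structs instead of raw LMHLO arguments, and returns a StatusOr so buffer-assignment lookups can fail gracefully. A sketch of the two call shapes, assembled from the ir_emitter_unnested.cc hunks below; the surrounding variables (fusion_analysis, instr, fusion_op, buffer_assignment, allocations) are assumed from those call sites, and the two calls live in separate functions:

// HLO path: buffer slices come straight from buffer assignment, so no
// LMHLO op is required to emit copy thunks.
TF_ASSIGN_OR_RETURN(
    std::optional<std::unique_ptr<FusionInterface>> emitter,
    GetFusionEmitter(fusion_analysis,
                     HloFusionInfo(instr, &buffer_assignment)));

// LMHLO path: buffer slices are resolved from the op and its allocations.
TF_ASSIGN_OR_RETURN(
    std::optional<std::unique_ptr<FusionInterface>> emitter,
    GetFusionEmitter(fusion_analysis,
                     LmhloFusionInfo(fusion_op, allocations)));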
7 changes: 7 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -32,6 +32,7 @@ cc_library(
         "//xla:statusor",
         "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo:lhlo",
+        "//xla/service:buffer_assignment",
         "//xla/service:elemental_ir_emitter",
         "//xla/service/gpu:gpu_executable",
         "//xla/service/gpu:ir_emission_utils",
@@ -40,6 +41,7 @@ cc_library(
         "//xla/service/gpu:thunk",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
     ],
 )

@@ -87,14 +89,19 @@ cc_library(
         ":reduction",
         ":transpose",
         "//xla:shape_util",
+        "//xla:status",
+        "//xla:statusor",
         "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo:lhlo",
+        "//xla/service:buffer_assignment",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:ir_emission_utils",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/types:span",
         "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )

29 changes: 15 additions & 14 deletions third_party/xla/xla/service/gpu/fusions/copy.cc
@@ -34,23 +34,24 @@ namespace xla {
 namespace gpu {
 
 StatusOr<FusionEmissionResult> MemcpyFusion::Emit(
-    IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
+    IrEmitterContext& ir_emitter_context, ElementalIrEmitter&,
     mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
-    KernelReuseCache& kernel_cache, llvm::IRBuilder<>*) const {
+    KernelReuseCache&, llvm::IRBuilder<>*) const {
   FusionEmissionResult result;
-  for (auto [src, dst] : llvm::zip(srcs_, dsts_)) {
-    auto src_buffer =
-        *GetAllocationSlice(src, ir_emitter_context.allocations());
-    auto dst_buffer =
-        *GetAllocationSlice(dst, ir_emitter_context.allocations());
-    if (src_buffer != dst_buffer) {
+  for (int i = 0; i < src_buffers_.size(); ++i) {
+    if (src_buffers_[i] != dst_buffers_[i]) {
       result.thunks.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
-          Thunk::ThunkInfo::WithProfileAnnotation(fusion_op),
-          /*source_buffer=*/src_buffer,
-          /*destination_buffer=*/dst_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(src)),
-          /*source_value=*/src,
-          /*destination_value=*/dst));
+          ir_emitter_context.emit_ir_from_hlo()
+              ? Thunk::ThunkInfo::WithProfileAnnotation(&fusion)
+              : Thunk::ThunkInfo::WithProfileAnnotation(fusion_op),
+          /*source_buffer=*/src_buffers_[i],
+          /*destination_buffer=*/dst_buffers_[i],
+          /*mem_size=*/src_buffers_[i].size(),
+          /*source_value=*/ir_emitter_context.emit_ir_from_hlo() ? nullptr
+                                                                 : srcs_[i],
+          /*destination_value=*/ir_emitter_context.emit_ir_from_hlo()
+              ? nullptr
+              : dsts_[i]));
     }
   }
   return result;
16 changes: 14 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/copy.h
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <vector>
 
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 
@@ -27,8 +29,13 @@ namespace gpu {
 // implemented using `memcpy`s.
 class MemcpyFusion : public FusionInterface {
  public:
-  MemcpyFusion(std::vector<mlir::Value> srcs, std::vector<mlir::Value> dsts)
-      : srcs_(std::move(srcs)), dsts_(std::move(dsts)) {}
+  MemcpyFusion(std::vector<BufferAllocation::Slice> src_buffers,
+               std::vector<BufferAllocation::Slice> dst_buffers,
+               std::vector<mlir::Value> srcs, std::vector<mlir::Value> dsts)
+      : src_buffers_(std::move(src_buffers)),
+        dst_buffers_(std::move(dst_buffers)),
+        srcs_(std::move(srcs)),
+        dsts_(std::move(dsts)) {}
 
   StatusOr<FusionEmissionResult> Emit(IrEmitterContext& ir_emitter_context,
                                       ElementalIrEmitter& elemental_emitter,
@@ -38,6 +45,11 @@ class MemcpyFusion : public FusionInterface {
                                       llvm::IRBuilder<>*) const final;
 
  private:
+  std::vector<BufferAllocation::Slice> src_buffers_;
+  std::vector<BufferAllocation::Slice> dst_buffers_;
+
+  // These are only used by the LMHLO code path and are empty if emitting from
+  // HLO.
  std::vector<mlir::Value> srcs_;
  std::vector<mlir::Value> dsts_;
 };
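The two emission paths populate this class differently. A minimal sketch of the two alternative construction shapes (alternatives, not sequential code), inferred from the fusions.cc changes in this commit; the slice and value vectors are assumed to have been collected as shown there:

// HLO path: only buffer slices are known; the mlir::Value vectors stay empty.
auto hlo_fusion = std::make_unique<MemcpyFusion>(
    std::move(src_buffers), std::move(dst_buffers),
    /*srcs=*/std::vector<mlir::Value>(),
    /*dsts=*/std::vector<mlir::Value>());

// LMHLO path: slices plus the originating mlir::Values, which Emit attaches
// to each DeviceToDeviceCopyThunk as source/destination values.
auto lmhlo_fusion = std::make_unique<MemcpyFusion>(
    std::move(src_buffers), std::move(dst_buffers),
    std::move(srcs), std::move(dsts));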
107 changes: 92 additions & 15 deletions third_party/xla/xla/service/gpu/fusions/fusions.cc
@@ -17,12 +17,15 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/log/check.h"
 #include "absl/types/span.h"
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout_util.h"
 #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
@@ -37,6 +40,11 @@ limitations under the License.
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace gpu {
@@ -64,13 +72,11 @@ bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) {
 
 }  // namespace
 
-std::optional<std::unique_ptr<FusionInterface>> GetCopyFusion(
-    HloFusionAnalysis& analysis,
-    absl::Span<const BufferAllocation* const> allocations,
-    mlir::lmhlo::FusionOp fusion_op) {
-  if (!fusion_op) {
-    return std::nullopt;
-  }
+StatusOr<std::optional<std::unique_ptr<FusionInterface>>> GetCopyFusionImpl(
+    HloFusionAnalysis& analysis, LmhloFusionInfo fusion_info) {
+  mlir::lmhlo::FusionOp fusion_op = fusion_info.fusion_op;
+  absl::Span<const BufferAllocation* const> allocations =
+      fusion_info.allocations;
 
   auto params = GetHloOperands(fusion_op);
   auto outputs = GetHloOutputs(fusion_op);
@@ -91,31 +97,102 @@ std::optional<std::unique_ptr<FusionInterface>> GetCopyFusion(
     srcs.emplace_back(src);
   }
 
-  return std::make_unique<MemcpyFusion>(
-      std::move(srcs),
-      std::vector<mlir::Value>(outputs.begin(), outputs.end()));
+  auto dsts = std::vector<mlir::Value>(outputs.begin(), outputs.end());
+  DCHECK(srcs.size() == dsts.size());
+  std::vector<BufferAllocation::Slice> src_buffers;
+  std::vector<BufferAllocation::Slice> dst_buffers;
+  for (int i = 0; i < srcs.size(); ++i) {
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
+                        GetAllocationSlice(srcs[i], allocations));
+    src_buffers.push_back(src_buffer);
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer,
+                        GetAllocationSlice(dsts[i], allocations));
+    dst_buffers.push_back(dst_buffer);
+  }
+
+  return std::make_unique<MemcpyFusion>(std::move(src_buffers),
+                                        std::move(dst_buffers), std::move(srcs),
+                                        std::move(dsts));
 }
 
+StatusOr<std::optional<std::unique_ptr<FusionInterface>>> GetCopyFusionImpl(
+    HloFusionAnalysis& analysis, HloFusionInfo fusion_info) {
+  const HloFusionInstruction* fusion = fusion_info.instr;
+  const BufferAssignment* buffer_assignment = fusion_info.buffer_assignment;
+
+  std::vector<BufferAllocation::Slice> src_buffers;
+  for (auto* root : analysis.fusion_roots()) {
+    if (root->opcode() != HloOpcode::kCopy ||
+        root->operand(0)->opcode() != HloOpcode::kParameter ||
+        !LayoutUtil::Equal(root->operand(0)->shape().layout(),
+                           root->shape().layout())) {
+      return std::nullopt;
+    }
+
+    const HloInstruction* src_instr =
+        fusion->operands()[root->operand(0)->parameter_number()];
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                        buffer_assignment->GetUniqueSlice(src_instr, {}));
+    src_buffers.push_back(slice);
+  }
+
+  std::vector<BufferAllocation::Slice> dst_buffers;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      fusion->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+        if (!subshape.IsArray()) {
+          return OkStatus();
+        }
+        TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                            buffer_assignment->GetUniqueSlice(fusion, index));
+        dst_buffers.push_back(slice);
+        return OkStatus();
+      }));
+
+  DCHECK(src_buffers.size() == dst_buffers.size());
+  std::vector<mlir::Value> srcs;
+  std::vector<mlir::Value> dsts;
+  return std::make_unique<MemcpyFusion>(std::move(src_buffers),
+                                        std::move(dst_buffers),
+                                        /*srcs=*/std::vector<mlir::Value>(),
+                                        /*dsts=*/std::vector<mlir::Value>());
+}
+
+StatusOr<std::optional<std::unique_ptr<FusionInterface>>> GetCopyFusion(
+    HloFusionAnalysis& analysis,
+    std::variant<HloFusionInfo, LmhloFusionInfo> fusion_info) {
+  if (std::holds_alternative<HloFusionInfo>(fusion_info)) {
+    return GetCopyFusionImpl(analysis, std::get<HloFusionInfo>(fusion_info));
+  } else {
+    return GetCopyFusionImpl(analysis, std::get<LmhloFusionInfo>(fusion_info));
+  }
+}
+
 }  // namespace
 
-std::optional<std::unique_ptr<FusionInterface>> GetFusionEmitter(
+StatusOr<std::optional<std::unique_ptr<FusionInterface>>> GetFusionEmitter(
     HloFusionAnalysis& analysis,
-    absl::Span<const BufferAllocation* const> allocations,
-    mlir::lmhlo::FusionOp fusion_op) {
+    std::variant<HloFusionInfo, LmhloFusionInfo> fusion_info) {
   switch (analysis.GetEmitterFusionKind()) {
     case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
      return std::make_unique<InputSlicesFusion>(analysis);
     case HloFusionAnalysis::EmitterFusionKind::kLoop: {
      if (IsDynamicUpdateSliceFusion(analysis)) {
-        if (allocations.empty() || fusion_op == nullptr) {
+        if (!std::holds_alternative<LmhloFusionInfo>(fusion_info)) {
          return std::nullopt;
        }
+        auto lmhlo_fusion_info = std::get<LmhloFusionInfo>(fusion_info);
+        absl::Span<const BufferAllocation* const> allocations =
+            lmhlo_fusion_info.allocations;
+        mlir::lmhlo::FusionOp fusion_op = lmhlo_fusion_info.fusion_op;
        if (CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op,
                                                        allocations)) {
          return std::make_unique<InPlaceDynamicUpdateSliceEmitter>(analysis);
        }
      }
-      if (auto copy_fusion = GetCopyFusion(analysis, allocations, fusion_op)) {
+      TF_ASSIGN_OR_RETURN(
+          std::optional<std::unique_ptr<FusionInterface>> copy_fusion,
+          GetCopyFusion(analysis, fusion_info));
+      if (copy_fusion.has_value()) {
        return copy_fusion;
      }
      return std::make_unique<LoopFusion>(analysis);
30 changes: 24 additions & 6 deletions third_party/xla/xla/service/gpu/fusions/fusions.h
@@ -17,25 +17,43 @@ limitations under the License.
 
 #include <memory>
 #include <optional>
+#include <variant>
 
 #include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
+#include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/statusor.h"
 
 namespace xla {
 namespace gpu {
 
+struct LmhloFusionInfo {
+  mlir::lmhlo::FusionOp fusion_op;
+  absl::Span<const BufferAllocation* const> allocations;
+
+  explicit LmhloFusionInfo(
+      mlir::lmhlo::FusionOp fusion_op,
+      absl::Span<const BufferAllocation* const> allocations)
+      : fusion_op(fusion_op), allocations(allocations) {}
+};
+
+struct HloFusionInfo {
+  const HloFusionInstruction* instr;
+  const BufferAssignment* buffer_assignment;
+
+  explicit HloFusionInfo(const HloFusionInstruction* instr,
+                         const BufferAssignment* buffer_assignment)
+      : instr(instr), buffer_assignment(buffer_assignment) {}
+};
+
 // Returns the emitter for the given fusion. Returns nullopt if the fusion
 // type is not yet supported.
-// `allocations` may be empty and `fusion_op` may be nullptr if no LMHLO ops are
-// available. In this case, this function will return nullopt if it cannot
-// detect whether a loop fusion can be optimized.
-std::optional<std::unique_ptr<FusionInterface>> GetFusionEmitter(
+StatusOr<std::optional<std::unique_ptr<FusionInterface>>> GetFusionEmitter(
     HloFusionAnalysis& analysis,
-    absl::Span<const BufferAllocation* const> allocations,
-    mlir::lmhlo::FusionOp fusion_op);
+    std::variant<HloFusionInfo, LmhloFusionInfo> fusion_info);
 
 }  // namespace gpu
 }  // namespace xla
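GetCopyFusion and GetFusionEmitter in fusions.cc dispatch on this variant with std::holds_alternative/std::get. A self-contained toy illustration of that dispatch pattern (the types here are made up, not XLA code):

#include <iostream>
#include <variant>

struct HloInfo { const char* instr; };
struct LmhloInfo { const char* op; };

// Mirrors the holds_alternative/get dispatch used by GetCopyFusion.
void Dispatch(const std::variant<HloInfo, LmhloInfo>& info) {
  if (std::holds_alternative<HloInfo>(info)) {
    std::cout << "HLO path: " << std::get<HloInfo>(info).instr << "\n";
  } else {
    std::cout << "LMHLO path: " << std::get<LmhloInfo>(info).op << "\n";
  }
}

int main() {
  Dispatch(HloInfo{"fusion"});       // prints "HLO path: fusion"
  Dispatch(LmhloInfo{"fusion.op"});  // prints "LMHLO path: fusion.op"
}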
18 changes: 12 additions & 6 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -2183,9 +2183,13 @@ Status IrEmitterUnnested::EmitFusion(
     case HloFusionAnalysis::EmitterFusionKind::kLoop:
     case HloFusionAnalysis::EmitterFusionKind::kTranspose:
     case HloFusionAnalysis::EmitterFusionKind::kReduction: {
-      auto emitter = GetFusionEmitter(fusion_analysis, {}, nullptr);
-      // TODO(anlunx): Support MemcpyFusion and InPlaceDynamicUpdateSlice and
-      // remove this fallback.
+      TF_ASSIGN_OR_RETURN(
+          std::optional<std::unique_ptr<FusionInterface>> emitter,
+          GetFusionEmitter(
+              fusion_analysis,
+              HloFusionInfo(instr, &ir_emitter_context_->buffer_assignment())));
+      // TODO(anlunx): Support InPlaceDynamicUpdateSlice and remove this
+      // fallback.
       if (!emitter) {
         TF_RET_CHECK(op)
             << "Fusion should have been handled by GetFusionEmitter, fallback "
@@ -2268,9 +2272,11 @@
     case HloFusionAnalysis::EmitterFusionKind::kLoop:
     case HloFusionAnalysis::EmitterFusionKind::kReduction:
     case HloFusionAnalysis::EmitterFusionKind::kTranspose: {
-      std::optional<std::unique_ptr<FusionInterface>> emitter =
-          GetFusionEmitter(fusion_analysis, ir_emitter_context_->allocations(),
-                           fusion_op);
+      TF_ASSIGN_OR_RETURN(
+          std::optional<std::unique_ptr<FusionInterface>> emitter,
+          GetFusionEmitter(
+              fusion_analysis,
+              LmhloFusionInfo(fusion_op, ir_emitter_context_->allocations())));
       if (emitter == std::nullopt) {
         return FailedPrecondition(
             "Fusion should have been handled by GetFusionEmitter.");
8 changes: 8 additions & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
+#include "xla/service/gpu/copy_thunk.h"
 #include "xla/service/gpu/kernel_thunk.h"
 #include "xla/service/gpu/runtime3/command_buffer_cmd.h"
 #include "xla/service/gpu/thunk.h"
@@ -40,6 +41,13 @@ StatusOr<std::unique_ptr<CommandBufferCmd>> ConvertToCommand(
           kernel_thunk.launch_dimensions(), kernel_thunk.shmem_bytes());
       return kernel_cmd;
     }
+    case Thunk::Kind::kCopy: {
+      auto& copy_thunk = static_cast<const DeviceToDeviceCopyThunk&>(thunk);
+      auto copy_cmd = std::make_unique<MemcpyDeviceToDeviceCmd>(
+          copy_thunk.destination(), copy_thunk.source(),
+          copy_thunk.size_bytes());
+      return copy_cmd;
+    }
     default:
       return InternalError("Unsupported thunk kind");
   }
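With the new kCopy case, a thunk sequence that contains device-to-device copies can be lowered to command-buffer commands instead of failing with "Unsupported thunk kind". A hedged sketch of calling the converter; the accessors are the ones used in the hunk above, while the assumption that ConvertToCommand takes the thunk by const reference is inferred from the static_cast in its body:

// Assumed setup: `thunk` is a Thunk of kind kCopy, e.g. a
// DeviceToDeviceCopyThunk emitted by MemcpyFusion::Emit. The resulting
// MemcpyDeviceToDeviceCmd records (destination, source, size_bytes) for
// replay inside a command buffer.
TF_ASSIGN_OR_RETURN(std::unique_ptr<CommandBufferCmd> cmd,
                    ConvertToCommand(thunk));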
Expand Down
