Skip to content

Commit

Permalink
[XLA:GPU] Remove PriorityFusion inheritance from InstructionFusion.
Browse files Browse the repository at this point in the history
PriorityFusion had been diverging from InstructionFusion for some time already and the last connecting piece was recently removed. There is not reason for PriorityFusion to inherit from InstructionFusion now.

The build dependancy is still there, because we use `FusionDecision` and `InstructionFusion::ShouldFuseInPlaceOp`. But those should be in a shared library or something.

PiperOrigin-RevId: 668441610
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Aug 28, 2024
1 parent 8ba3f3e commit 1813776
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 65 deletions.
2 changes: 0 additions & 2 deletions third_party/xla/xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2075,7 +2075,6 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:dump",
"//xla/service:fusion_queue",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_graph_dumper",
"//xla/service:hlo_pass",
Expand All @@ -2093,7 +2092,6 @@ cc_library(
"//xla/service/gpu/model:gpu_indexing_performance_model",
"//xla/service/gpu/model:gpu_performance_model",
"//xla/service/gpu/model:gpu_performance_model_base",
"//xla/service/gpu/model:symbolic_tile_analysis",
"//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
"//xla/service/gpu/model:triton_emitter_constraints",
"//xla/stream_executor:device_description",
Expand Down
49 changes: 4 additions & 45 deletions third_party/xla/xla/service/gpu/transforms/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <string>
Expand All @@ -44,7 +43,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/dump.h"
#include "xla/service/fusion_queue.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusion_deduplication_cache.h"
#include "xla/service/gpu/fusion_process_dump.pb.h"
Expand All @@ -58,7 +56,6 @@ limitations under the License.
#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/service/gpu/model/triton_emitter_constraints.h"
#include "xla/service/hlo_graph_dumper.h"
Expand All @@ -78,10 +75,6 @@ namespace xla {
namespace gpu {

namespace {
bool ElementIsF32OrF16(const Shape& shape) {
PrimitiveType type = shape.element_type();
return type == F32 || type == F16;
}

bool IsFusible(const HloInstruction& instr) {
// Side-effecting operations are not fusible.
Expand Down Expand Up @@ -886,26 +879,6 @@ class PriorityFusionQueue {

} // namespace

/*static*/ bool PriorityFusion::IsExpensive(const HloInstruction& instruction) {
// Some floating-point math ops are cheap on the GPU.
switch (instruction.opcode()) {
case HloOpcode::kDivide:
case HloOpcode::kSqrt:
case HloOpcode::kRsqrt:
case HloOpcode::kExp:
if (ElementIsF32OrF16(instruction.shape())) {
return false;
}
break;
// Loop fusions are cheap.
case HloOpcode::kFusion:
return false;
default:
break;
}
return InstructionFusion::IsExpensive(instruction);
}

// Return true, if instr is a small constant.
//
// There is not single definition for what is a small constant in XLA.
Expand Down Expand Up @@ -1003,7 +976,7 @@ absl::StatusOr<bool> PriorityFusion::Run(
int64_t consumer_operand_index = consumer->operand_index(producer);

fusion_queue->PreFusion(producer, consumer);
auto fusion_instruction = Fuse(producer, consumer, computation);
auto fusion_instruction = Fuse(producer, consumer);
fusion_deduplication_cache.UpdateFusedInstructionId(
*fusion_instruction, *producer, *consumer, consumer_operand_index);
fusion_queue->OnFusingInstruction(fusion_instruction, producer,
Expand Down Expand Up @@ -1051,7 +1024,7 @@ absl::StatusOr<bool> PriorityFusion::Run(
auto users = constant->users();
for (auto* user : users) {
if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
Fuse(constant, user, computation);
Fuse(constant, user);
changed = true;
}
}
Expand All @@ -1072,15 +1045,6 @@ absl::StatusOr<bool> PriorityFusion::Run(
return changed;
}

FusionDecision PriorityFusion::ShouldFuse(HloInstruction* consumer,
int64_t operand_index) {
// This method is called in `InstructionFusion::Run` right before fusion, but
// it will always return true. Fusion decision are fully controlled by the
// PriorityQueue. If the queue returns a producer that shouldn't be fused,
// it's a bug and should be fixed in the queue logic.
return {};
}

HloInstruction::FusionKind PriorityFusion::ChooseKind(
const HloInstruction* producer, const HloInstruction* consumer) {
// Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't
Expand All @@ -1104,11 +1068,11 @@ HloInstruction::FusionKind PriorityFusion::ChooseKind(
}

HloInstruction* PriorityFusion::Fuse(HloInstruction* producer,
HloInstruction* consumer,
HloComputation* computation) {
HloInstruction* consumer) {
VLOG(2) << "Fusing " << producer->ToString() << " into "
<< consumer->ToString();

HloComputation* computation = consumer->parent();
auto kind = ChooseKind(producer, consumer);
HloInstruction* fusion_instruction = consumer;

Expand Down Expand Up @@ -1137,10 +1101,5 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer,
return fusion_instruction;
}

std::unique_ptr<FusionQueue> PriorityFusion::GetFusionQueue(
HloComputation* computation) {
return nullptr;
}

} // namespace gpu
} // namespace xla
23 changes: 5 additions & 18 deletions third_party/xla/xla/service/gpu/transforms/priority_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
#define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_

#include <stdint.h>

#include <memory>
#include <utility>
Expand All @@ -28,52 +27,40 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/fusion_queue.h"
#include "xla/service/gpu/fusion_process_dump.pb.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/threadpool.h"

namespace xla {
namespace gpu {

class PriorityFusion : public InstructionFusion {
class PriorityFusion : public HloModulePass {
public:
PriorityFusion(tsl::thread::ThreadPool* thread_pool,
const se::DeviceDescription& device,
GpuHloCostAnalysis::Options cost_analysis_options)
: InstructionFusion(PriorityFusion::IsExpensive),
thread_pool_(thread_pool),
: thread_pool_(thread_pool),
device_info_(device),
cost_analysis_options_(std::move(cost_analysis_options)),
fusion_analysis_cache_(device_info_) {}

absl::string_view name() const override { return "priority-fusion"; }

static bool IsExpensive(const HloInstruction& instruction);

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

protected:
std::unique_ptr<FusionQueue> GetFusionQueue(
HloComputation* computation) override;

FusionDecision ShouldFuse(HloInstruction* consumer,
int64_t operand_index) override;

HloInstruction::FusionKind ChooseKind(
const HloInstruction* producer, const HloInstruction* consumer) override;
HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
const HloInstruction* consumer);

HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer,
HloComputation* computation) override;
HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);

private:
// Consumes a unit of compiler fuel and returns true if we should
Expand Down

0 comments on commit 1813776

Please sign in to comment.