Use LLs for register-to-register convert-layout ops. (#4125)
jlebar authored Jun 13, 2024
1 parent 7378f3a commit 2329531
Showing 9 changed files with 515 additions and 211 deletions.
4 changes: 4 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -188,10 +188,14 @@ bool supportMMA(Value value, int version);

bool isSingleValue(Value value);

+bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
+
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

+// TODO(jlebar): Remove this function; it's subsumed by the linear-layout case
+// in cvtNeedsSharedMemory.
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
27 changes: 22 additions & 5 deletions include/triton/Tools/LinearLayout.h
@@ -493,6 +493,8 @@ class LinearLayout {
//
// This only works across the first (i.e. the most-minor) dimension of in/out.
// If you want it to work across more dimensions, flatten the layout.
+  //
+  // TODO(jlebar): Replace with divideLeft.
int32_t getNumConsecutiveInOut() const;

// Reorders the in/out dimensions of the layout. This is mostly cosmetic
@@ -574,10 +576,13 @@
return *this;
}

-  // TODO(jlebar): Implement the inverse of operator*, namely
-  //   std::optional<LinearLayout> divideLeft(const LinearLayout&);
-  //   std::optional<LinearLayout> divideRight(const LinearLayout&);
-  // In particular, these might subsume getNumConsecutiveInOut.
+  // divideLeft and divideRight are the inverses of operator*.
+  //
+  // If c = a * b, then a = c.divideRight(b) and b = c.divideLeft(a).
+  //
+  // TODO(jlebar): Implement divideLeft.
+  // std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
+  std::optional<LinearLayout> divideRight(const LinearLayout &divisor);

// Computes and returns L(x, y, z).
//
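[Editor's note: to make the divide algebra above concrete, here is a small hypothetical C++ usage sketch — an illustration, not part of this diff. It assumes LinearLayout's equality operator and the identity1D factory behave as they are used elsewhere in this header.]

#include "triton/Tools/LinearLayout.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::triton;

void divideRightExample(MLIRContext *ctx) {
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kWarp = StringAttr::get(ctx, "warp");

  // Build c = a * b from two identity layouts.
  LinearLayout a = LinearLayout::identity1D(4, kLane, kLane);
  LinearLayout b = LinearLayout::identity1D(8, kWarp, kWarp);
  LinearLayout c = a * b;

  // Dividing c on the right by b recovers a.
  std::optional<LinearLayout> quotient = c.divideRight(b);
  assert(quotient.has_value() && *quotient == a);

  // Dividing by a layout that is not a right factor of c fails gracefully
  // (returns std::nullopt) rather than asserting.
  LinearLayout notAFactor = LinearLayout::identity1D(16, kWarp, kWarp);
  assert(!c.divideRight(notAFactor).has_value());
}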
@@ -648,7 +653,19 @@
}

private:
-  void checkInvariants(bool requireSurjective);
+  // Factory function that gracefully fails rather than asserts if the layout is
+  // not well-formed.
+  static std::optional<LinearLayout>
+  tryCreate(BasesT bases, ArrayRef<std::pair<StringAttr, int32_t>> outDims,
+            bool requireSurjective);
+
+  // Constructor that does not check invariants. Used by tryCreate.
+  struct NoCheckInvariants {};
+  LinearLayout(BasesT bases, ArrayRef<std::pair<StringAttr, int32_t>> outDims,
+               NoCheckInvariants);
+
+  [[nodiscard]] std::optional<std::string>
+  checkInvariants(bool requireSurjective);
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
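[Editor's note: the tryCreate/NoCheckInvariants additions follow a common C++ construction pattern — a public factory reports failure through std::optional, and a private tag-dispatched constructor skips the checks so the factory can run them itself. Below is a generic sketch of the pattern, illustrative only and not the Triton implementation.]

#include <optional>
#include <string>

class Checked {
public:
  // Factory that fails gracefully: returns std::nullopt instead of
  // asserting when the invariants don't hold.
  static std::optional<Checked> tryCreate(int value) {
    Checked c(value, NoCheckInvariants{});
    if (c.checkInvariants().has_value()) // engaged optional = error message
      return std::nullopt;
    return c;
  }

private:
  // Tag type selecting the constructor that skips invariant checks.
  struct NoCheckInvariants {};
  Checked(int value, NoCheckInvariants) : value(value) {}

  // Returns an error message, or std::nullopt if the invariants hold.
  [[nodiscard]] std::optional<std::string> checkInvariants() const {
    if (value < 0)
      return "value must be non-negative";
    return std::nullopt;
  }

  int value;
};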
21 changes: 4 additions & 17 deletions lib/Analysis/Allocation.cpp
@@ -63,28 +63,15 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

+  if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
+    return {};
+  }
+
if (shouldUseDistSmem(srcLayout, dstLayout)) {
// TODO: padding to avoid bank conflicts
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}

-  if (isMfmaToDotShortcut(srcTy, dstTy))
-    return {};
-
-  // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem
-  if (auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
-    if (mlir::isa<DotOperandEncodingAttr>(dstLayout)) {
-      if (isMmaToDotShortcut(srcTy, dstTy)) {
-        return {};
-      }
-    } else if (auto dstMmaLayout =
-                   mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout)) {
-      if (isMmaToMmaShortcut(srcTy, dstTy)) {
-        return {};
-      }
-    }
-  }

assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");

auto srcShapePerCTA = getShapePerCTA(srcTy);
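[Editor's note: the early-out added above establishes a simple contract — an empty rep shape means the convert-layout op needs no shared-memory scratch buffer. Here is a hedged sketch of a caller relying on that contract; the helper below is hypothetical and assumes getRepShapeForCvtLayout is visible to the caller.]

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper, not part of this diff.
bool needsScratchBuffer(triton::gpu::ConvertLayoutOp cvtOp) {
  // Empty rep shape: the conversion stays in registers (or uses another
  // shortcut), so no shared memory is allocated for it.
  SmallVector<unsigned> repShape = getRepShapeForCvtLayout(cvtOp);
  return !repShape.empty();
}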
69 changes: 52 additions & 17 deletions lib/Analysis/Utility.cpp
@@ -13,7 +13,9 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir {
@@ -583,10 +585,8 @@ bool supportMMA(Value value, int version) {
}

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
-  auto srcLayout = srcTy.getEncoding();
-  auto dstLayout = dstTy.getEncoding();
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcLayout);
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstLayout);
+  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
return false;
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
@@ -618,28 +618,63 @@ bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
-  auto srcLayout = srcTy.getEncoding();
-  auto dstLayout = dstTy.getEncoding();
-  auto mmaLayout = cast<NvidiaMmaEncodingAttr>(srcLayout);
-  auto dotOperandLayout = cast<DotOperandEncodingAttr>(dstLayout);
+  auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (!mmaLayout || !dotOperandLayout) {
+    return false;
+  }
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
-  auto ans = mmaLayout.getVersionMajor() == 3 &&
-             dotOperandLayout.getOpIdx() == 0 &&
-             isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
-             (elementTypeSize == 16 || elementTypeSize == 8);
+  auto ans =
+      mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
+      isMmaToMmaShortcut(dotOperandLayout.getParent(), srcTy.getEncoding()) &&
+      (elementTypeSize == 16 || elementTypeSize == 8);
return ans;
}

+bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
+  MLIRContext *ctx = srcTy.getContext();
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  std::optional<LinearLayout> dstLayout =
+      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  if (srcLayout.has_value() && dstLayout.has_value()) {
+    // comp describes the layout function for converting from src to dst.
+    LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
+    StringAttr kLane = StringAttr::get(ctx, "lane");
+    StringAttr kWarp = StringAttr::get(ctx, "warp");
+    StringAttr kBlock = StringAttr::get(ctx, "block");
+    // In principle, there's no need for shared memory if there's no
+    // communication between warps. However, right now we only have implemented
+    // the shortcut case where there's no communication between *threads*.
+    //
+    // TODO(jlebar): Remove the kLane layout once we add support for
+    // shuffle-based layout conversions in ConvertLayoutToLLVM.
+    if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
+                                                  kLane, kLane) *
+                         LinearLayout::identity1D(comp.getInDimSize(kWarp),
+                                                  kWarp, kWarp) *
+                         LinearLayout::identity1D(comp.getInDimSize(kBlock),
+                                                  kBlock, kBlock))
+            .has_value()) {
+      return false;
+    }
+  }
+
+  // TODO(jlebar): Remove these special cases once they're fully subsumed by the
+  // linear-layout check above.
+  return !isMmaToMmaShortcut(srcTy, dstTy) &&
+         !isMmaToDotShortcut(srcTy, dstTy) &&
+         !isMfmaToDotShortcut(srcTy, dstTy);
+}
+
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
-  auto srcLayout = srcTy.getEncoding();
-  auto dstLayout = dstTy.getEncoding();
-  auto mmaLayout = mlir::cast<NvidiaMmaEncodingAttr>(srcLayout);
-  auto dotOperandLayout = mlir::cast<DotOperandEncodingAttr>(dstLayout);
-  return mmaLayout.getVersionMajor() == 2 &&
+  auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
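[Editor's note: to unpack the divideRight check in cvtNeedsSharedMemory — comp maps each (register, lane, warp, block) position under the source layout to its position under the destination layout, and dividing out identity layouts on lane, warp, and block succeeds exactly when comp leaves those dimensions alone, i.e. when every element stays in its own thread. Below is a minimal hypothetical sketch of the successful case, simplified to a single lane dimension; it is an illustration, not part of this diff.]

#include "triton/Tools/LinearLayout.h"
#include <cassert>

using namespace mlir;
using namespace mlir::triton;

void registerOnlyConversion(MLIRContext *ctx) {
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");

  // A conversion that only rearranges data within each thread: comp acts as
  // the identity on the lane dimension.
  LinearLayout comp = LinearLayout::identity1D(8, kReg, kReg) *
                      LinearLayout::identity1D(32, kLane, kLane);

  // Dividing out the lane identity succeeds, so cvtNeedsSharedMemory would
  // conclude that no shared memory is required for this conversion.
  assert(comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
                                                   kLane, kLane))
             .has_value());
}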
(Diffs for the remaining 5 changed files are not shown.)
