Add view_slice ttgir instruction (#427)
* Add view_slice op in ttgir

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Ognjen <[email protected]>
Co-authored-by: Lixun Zhang <[email protected]>
4 people authored Jan 2, 2024
1 parent 98589ac commit 6a52056
Showing 7 changed files with 372 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -105,6 +105,11 @@ unsigned getNumCTAs(Attribute layout);

bool isaDistributedLayout(Attribute layout);

bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
BlockedEncodingAttr blockedB);

bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB);

bool isSharedEncoding(Value value);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
9 changes: 9 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Traits.h
@@ -14,6 +14,7 @@ namespace OpTrait {
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
LogicalResult verifyOperandAndResultHaveSameEncoding(Operation *op);
} // namespace impl

template <typename ConcreteType>
@@ -25,6 +26,14 @@ class ResultsAreSharedEncoding
}
};

template <typename ConcreteType>
class OperandAndResultHaveSameEncoding
: public TraitBase<ConcreteType, OperandAndResultHaveSameEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOperandAndResultHaveSameEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir

62 changes: 62 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -14,6 +14,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
def OperandAndResultHaveSameEncoding: NativeOpTrait<"OperandAndResultHaveSameEncoding">;

class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -179,6 +180,67 @@ def TTG_InsertSliceOp : TTG_Op<"insert_slice",
}


def TTG_ViewSliceOp : TTG_Op<"view_slice",
[AttrSizedOperandSegments,
OperandAndResultHaveSameEncoding,
Pure,
OffsetSizeAndStrideOpInterface
]> {
let summary = "view slice operation";
let description = [{
Represents a view of a slice of a tensor held in registers. The syntax of the
operation is the same as that of the extract_slice op. However, unlike
'extract_slice', which slices a tensor in shared memory, 'view_slice' slices a
tensor within registers. The slice is required to have the same layout as the
original tensor. In effect, the semantics of 'view_slice' combine the
semantics of the 'extract_slice' and 'view' operations.
}];

let arguments = (ins
AnyRankedTensor:$source,
Variadic<I32>:$offsets,
Variadic<I32>:$sizes,
Variadic<I32>:$strides,
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);

let builders = [
// Build a ViewSliceOp with mixed static and dynamic entries and a custom
// result type. If the type passed is nullptr, it is inferred.
OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];

let extraClassDeclaration = [{
/// Return the number of leading operands before the `offsets`, `sizes`,
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }

/// Returns the type of the base tensor operand.
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}

std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned rank = getSourceType().getRank();
return {rank, rank, rank};
}
}];

let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $static_offsets)
custom<DynamicIndexList>($sizes, $static_sizes)
custom<DynamicIndexList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
}
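Given the assembly format above, the op's textual IR should look roughly like the following sketch (the SSA names, tensor shapes, and #blocked encoding parameters are assumptions for illustration; the result keeps the source's encoding, as the OperandAndResultHaveSameEncoding trait requires):

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
%1 = triton_gpu.view_slice %0 [32, 32] [64, 64] [1, 1]
    : tensor<128x128xf32, #blocked> to tensor<64x64xf32, #blocked>

Here the offsets and sizes are multiples of the 32x32 tile covered per CTA, which is what the lowering below requires.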


def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
130 changes: 130 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -849,6 +849,135 @@ struct ExtractSliceOpConversion
}
};

// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# W0 # W1 # | #
# # # | #
# # # # # | #
# W2 # W3 # .... | #
# # # | SkipElems #
# # # # # | #
# | #
# Slice | #
# . / \ | #
# . / \ | #
# . / \| #
# # # # # # #
# # W0 # W1 # #
# # # # #
# # # # # # tensorStride #
# # W2 # W3 # --------------------------------#
# # # # #
# # # # # # #
# tensorStride # W0 # W1 # #
# ---------------------------------- # # # #
# # # # # # #
# # W2 # W3 # #
# # # # #
# # # # # # ---> lastIdx #
# . #
# . #
# . #
# #
# #
# #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
// clang-format on
struct ViewSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp> {
using OpAdaptor = typename triton::gpu::ViewSliceOp::Adaptor;
explicit ViewSliceOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
benefit) {}

LogicalResult
processBlockedLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(
srcLayout &&
"Currently only blocked layout is supported in view_slice instruction");
auto srcShape = srcTy.getShape();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSource(), rewriter, srcTy);

auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = srcLayout.getSizePerThread();
auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
auto order = srcLayout.getOrder();
auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);

auto offsets = op.getStaticOffsets();
auto sizes = op.getStaticSizes();

// ViewSlice only supports slicing where offsets and sizes are multiples of
// shapePerCTA. This condition ensures that the slice has the same layout as
// the original tensor.
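// Illustrative example (values assumed): with shapePerCTA = [32, 32],
// offsets = [32, 0] and sizes = [64, 32] pass these checks, whereas
// offsets = [16, 0] would be rejected.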
assert(offsets[0] % shapePerCTA[0] == 0);
assert(offsets[1] % shapePerCTA[1] == 0);
assert(sizes[0] % shapePerCTA[0] == 0);
assert(sizes[1] % shapePerCTA[1] == 0);
assert(op.hasUnitStride() &&
"Only unit stride supported by ViewSliceOpConversion");

// Calculate offsets and sizes in terms of CTA units.
std::vector<long int> CTAOffsets{offsets[0] / shapePerCTA[0],
offsets[1] / shapePerCTA[1]};
std::vector<long int> CTASizes{sizes[0] / shapePerCTA[0],
sizes[1] / shapePerCTA[1]};
std::vector<long int> CTAPerShape{srcShape[0] / shapePerCTA[0],
srcShape[1] / shapePerCTA[1]};

SmallVector<Value> resultVals;
// The diagram above gives a graphical representation of the skipElems,
// tensorStride, and lastIdx variables computed below.
auto skipElems = CTAOffsets[order[1]] *
(elemsPerThread[order[0]] * sizePerThread[order[1]]) +
CTAOffsets[order[0]] * totalSizePerThread;
auto tensorStride =
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
auto lastIdx =
(CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
elemsPerThread[order[0]] * sizePerThread[order[1]] +
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;
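// Worked example (values assumed for illustration): for a 128x128 source
// with shapePerCTA = [32, 32], sizePerThread = [4, 4] and order = [1, 0],
// slicing offsets = [32, 32] with sizes = [64, 64] gives CTAOffsets = [1, 1],
// CTASizes = [2, 2], CTAPerShape = [4, 4], totalSizePerThread = 16 and
// elemsPerThread = [16, 16]. Then skipElems = 1 * (16 * 4) + 1 * 16 = 80,
// tensorStride = (4 - 2) * 16 = 32 and lastIdx = 2 * 16 * 4 + 3 * 16 = 176,
// so the loop below copies two runs of 32 values (64 in total) out of the
// 256 values each thread holds.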

assert(lastIdx <= vals.size());
for (int i = skipElems; i < lastIdx; i += tensorStride) {
for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
assert(i < lastIdx);
resultVals.push_back(vals[i]);
}
}

Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}

LogicalResult
matchAndRewrite(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
if (srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>()) {
return processBlockedLayout(op, adaptor, rewriter);
} else {
assert(false && "unsupported layout in viewSlice");
return failure();
}
}
};

struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -954,6 +1083,7 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<ViewSliceOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
53 changes: 53 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -691,6 +691,59 @@ bool isaDistributedLayout(Attribute layout) {
layout.isa<MfmaEncodingAttr>() || layout.isa<SliceEncodingAttr>();
}

bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
BlockedEncodingAttr blockedB) {
auto sizePerThreadA = blockedA.getSizePerThread();
auto threadsPerWarpA = blockedA.getThreadsPerWarp();
auto warpsPerCTAA = blockedA.getWarpsPerCTA();
auto orderA = blockedA.getOrder();
size_t rankA = orderA.size();

auto sizePerThreadB = blockedB.getSizePerThread();
auto threadsPerWarpB = blockedB.getThreadsPerWarp();
auto warpsPerCTAB = blockedB.getWarpsPerCTA();
auto orderB = blockedB.getOrder();
size_t rankB = orderB.size();

if (rankA != rankB) {
return false;
}
for (size_t i = 0; i < rankA; ++i) {
if (sizePerThreadA[i] != sizePerThreadB[i] ||
threadsPerWarpA[i] != threadsPerWarpB[i] ||
warpsPerCTAA[i] != warpsPerCTAB[i] || orderA[i] != orderB[i]) {
return false;
}
}
return true;
}

bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
auto nonKDimA = mfmaA.getNonKDim();
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
auto isTransposedA = mfmaA.getIsTransposed();

auto nonKDimB = mfmaB.getNonKDim();
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
auto isTransposedB = mfmaB.getIsTransposed();

if (nonKDimA != nonKDimB || isTransposedA != isTransposedB) {
return false;
}

if (warpsPerCTAA.size() != warpsPerCTAB.size()) {
return false;
}

auto rank = warpsPerCTAA.size();
for (size_t i = 0; i < rank; ++i) {
if (warpsPerCTAA[i] != warpsPerCTAB[i]) {
return false;
}
}
return true;
}

bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
48 changes: 48 additions & 0 deletions lib/Dialect/TritonGPU/IR/Traits.cpp
@@ -12,3 +12,51 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {

return success();
};

mlir::LogicalResult
mlir::OpTrait::impl::verifyOperandAndResultHaveSameEncoding(Operation *op) {
if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
return failure();
}

auto operandType = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
auto resultType = op->getResult(0).getType().dyn_cast<RankedTensorType>();

if (!operandType || !resultType) {
return failure();
}
auto operandLayout = operandType.getEncoding();
auto resultLayout = resultType.getEncoding();

if (auto blockedLayoutSrc =
dyn_cast<triton::gpu::BlockedEncodingAttr>(operandLayout)) {
auto blockedLayoutRes =
dyn_cast<triton::gpu::BlockedEncodingAttr>(resultLayout);
if (!blockedLayoutRes) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}

if (!triton::gpu::sameBlockedEncodings(blockedLayoutSrc,
blockedLayoutRes)) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
} else if (auto mfmaLayoutSrc =
dyn_cast<triton::gpu::MfmaEncodingAttr>(operandLayout)) {
auto mfmaLayoutRes = dyn_cast<triton::gpu::MfmaEncodingAttr>(resultLayout);
if (!mfmaLayoutRes) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
if (!triton::gpu::sameMfmaEncodings(mfmaLayoutSrc, mfmaLayoutRes)) {
return op->emitOpError()
<< "requires operand and result to have same layout";
}
} else {
assert(false &&
"Unexpected layout in verifyOperandAndResultHaveSameEncoding");
}

return success();
};
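If the operand and result encodings differ, the verifier rejects the op. A hypothetical snippet (SSA names and encoding names are placeholders):

// error: 'triton_gpu.view_slice' op requires operand and result to have same layout
%1 = triton_gpu.view_slice %0 [0, 0] [32, 32] [1, 1]
    : tensor<64x64xf32, #blocked0> to tensor<32x32xf32, #blocked1>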