Skip to content

Commit

Permalink
check point
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Dec 20, 2024
1 parent 9acb4ca commit 9159d90
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 27 deletions.
24 changes: 14 additions & 10 deletions compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand Down Expand Up @@ -65,9 +66,8 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType,
/// std::nullopt.
static std::optional<Type> getLegalizedType(Type t) {
if (auto shapedType = llvm::dyn_cast<RankedTensorType>(t)) {
Type elementType = shapedType.getElementType();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(elementType);
legalizeTensorStorageElementType(shapedType);
if (!legalizedElementType)
return std::nullopt;
return RankedTensorType::get(shapedType.getShape(),
Expand Down Expand Up @@ -114,7 +114,7 @@ struct ConstantOpTypeConversion
constantOp, "expected attribute type to be shaped type");
}
std::optional<Type> legalizedElementType =
legalizeStorageElementType(attrType.getElementType());
legalizeTensorStorageElementType(attrType);
if (!legalizedElementType) {
return rewriter.notifyMatchFailure(constantOp,
"cannot legalize elementType");
Expand Down Expand Up @@ -220,8 +220,10 @@ struct GenericOpTypePropagation
signatureConverter.addInputs(index, argType);
continue;
}
auto inputOperandType =
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
std::optional<Type> legalizedArgType =
legalizeStorageElementType(argType);
legalizeTensorStorageElementType(inputOperandType);
if (!legalizedArgType) {
return genericOp.emitOpError("failed to get legalized type for arg ")
<< index;
Expand Down Expand Up @@ -251,8 +253,8 @@ struct GenericOpTypePropagation
modifyYield = true;
OpOperand *yieldOperand =
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
std::optional<Type> legalizedType =
legalizeStorageElementType(yieldOperand->get().getType());
std::optional<Type> legalizedType = legalizeTensorStorageElementType(
modifiedOpOperand->get().getType());
if (!legalizedType) {
return genericOp.emitOpError(
"failed to get legalized type for yield value");
Expand Down Expand Up @@ -282,7 +284,7 @@ struct LinalgFillTypePropagation
ConversionPatternRewriter &rewriter) const final {
Value value = adaptor.getInputs().front();
std::optional<Type> legalizedElementType =
legalizeStorageElementType(value.getType());
legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType());
if (!legalizedElementType) {
return fillOp.emitOpError("failed to get legalized type for value");
}
Expand Down Expand Up @@ -348,8 +350,8 @@ struct IREELinalgExtScatterTypePropagation
// type.
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
Type argType = modifiedOpRegion.getArguments()[0].getType();
std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
std::optional<Type> legalizedArgType =
legalizeTensorStorageElementType(inputType);
if (!legalizedArgType) {
return scatterOp.emitOpError("failed to get legalized type for argument");
}
Expand Down Expand Up @@ -411,8 +413,10 @@ struct IREELinalgExtSortTypePropagation
TypeConverter::SignatureConversion signatureConverter(
modifiedOpRegion.getNumArguments());
for (auto [index, arg] : llvm::enumerate(modifiedOpRegion.getArguments())) {
// Refer to input types of the original operation to determine the
// corresponding legal arg type.
std::optional<Type> legalizedArgType =
legalizeStorageElementType(arg.getType());
legalizeTensorStorageElementType(sortOp->getOperandTypes()[index]);
if (!legalizedArgType) {
return sortOp.emitOpError("failed to get legalized type for argument");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,12 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

bool hasPackedStorageAttr(RankedTensorType type) {
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
bool hasPackedStorageAttr(Type type) {
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
return dyn_cast_or_null<PackedStorageAttr>(tensorType.getEncoding()) !=
nullptr;
}
return false;
}

FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace mlir::iree_compiler::IREE::Encoding {
EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns true if the type contains packed_storage attribute.
bool hasPackedStorageAttr(RankedTensorType type);
bool hasPackedStorageAttr(Type type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
// Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width.
static RankedTensorType alignTensorType(RankedTensorType originalType) {
Type elementType = originalType.getElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType = legalizeTensorStorageElementType(originalType);
if (alignedType == elementType)
return originalType;
return RankedTensorType::get(originalType.getShape(), alignedType,
Expand Down Expand Up @@ -620,7 +620,8 @@ struct EncodeHostTensorsPass
static IREE::Flow::DispatchTensorType
alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) {
Type elementType = originalType.getBoundElementType();
Type alignedType = legalizeStorageElementType(elementType);
Type alignedType =
legalizeTensorStorageElementType(originalType.asRankedTensorType());
if (alignedType == elementType)
return originalType;
return IREE::Flow::DispatchTensorType::get(
Expand Down
20 changes: 8 additions & 12 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
intType.getSignedness());
}

Type legalizeStorageElementType(Type elementType) {
// Consider packed storage for i1 tensors if cl opt is set.
return legalizeStorageElementTypeImpl(elementType,
/*isPackedStorage=*/false);
Type legalizeTensorStorageElementType(Type type) {
auto tensorType = llvm::dyn_cast<RankedTensorType>(type);
return legalizeStorageElementTypeImpl(
type, tensorType && IREE::Encoding::hasPackedStorageAttr(type));
}

Value calculateStorageElementCountInBytes(Location loc,
Expand All @@ -77,12 +77,10 @@ Value calculateStorageElementCountInBytes(Location loc,
loc, builder, shapedType, dynamicDims);
}

// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
Type alignedElementType = legalizeStorageElementTypeImpl(
shapedType.getElementType(), isPackedStorage);
Type alignedElementType = legalizeTensorStorageElementType(shapedType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Calculate all static dims first, if any.
Expand Down Expand Up @@ -123,12 +121,10 @@ Value calculateStorageElementOffsetInBytes(Location loc,
RankedTensorType originalType,
Value linearizedIndex,
OpBuilder &builder) {
// TODO: remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(originalType);
Type alignedElementType = legalizeStorageElementTypeImpl(
originalType.getElementType(), isPackedStorage);
Type alignedElementType = legalizeTensorStorageElementType(originalType);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(originalType);
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;

// Sub-byte packing requires putting multiple elements in the same byte.
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ bool needToPackSubByteElements(RankedTensorType shapedType);
/// cases.
Type legalizeStorageElementType(Type elementType);

Type legalizeTensorStorageElementType(Type tensorType);

/// Emits IR with the given |builder| to calculate the total number of bytes
/// required for the given |shapedType| in storage. Returns the value for the
/// final count on success; returns nullptr on failure. Dynamic dimensions in
Expand Down

0 comments on commit 9159d90

Please sign in to comment.