Skip to content

Commit

Permalink
[DT][NFC] Remove duplicated code for existing layout checks (#20117)
Browse files Browse the repository at this point in the history
I removed the duplicated code in the implementation of
`CPUDeviceEncodingLayoutResolverAttrInterface`,
`VMVXDeviceEncodingLayoutResolverAttrInterface`, and
`GPUDeviceEncodingLayoutResolverAttrInterface`. I provide a base class
`WrappedExternalModel` to provide `getConfiguration` and
`getEncodingInfo`, so the derived class only needs to implement their
own `getEncodingInfoImpl` method.

Fixing #20050

---------

Signed-off-by: Jinjie Liu <[email protected]>
  • Loading branch information
sgjzfzzf authored Mar 4, 2025
1 parent 6ba9cce commit 3692165
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -590,22 +590,16 @@ enumerateCPUMatmulTiles(IREE::Encoding::EncodingAttr encoding,
}

struct CPUDeviceEncodingLayoutResolverAttrInterface
: public Codegen::LayoutAttrInterface::ExternalModel<
: public DeviceEncodingLayoutResolverExternalModelBase<
CPUDeviceEncodingLayoutResolverAttrInterface, CPUEncodingLayoutAttr> {
MaterializeEncodingInfo getEncodingInfo(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<CPUEncodingLayoutAttr>(attr);

// If the layout is already resolved, use it directly.
if (auto config = layoutAttr.getConfiguration()) {
if (auto namedAttr = config.getNamed(kEncodingInfoAttrName)) {
std::optional<MaterializeEncodingInfo> info =
Codegen::deserializeEncodingInfo(
cast<DictionaryAttr>(namedAttr->getValue()));
assert(info && "encoding_info is invalid");
return info.value();
}
}
DictionaryAttr getConfiguration(Attribute attr) const {
return cast<CPUEncodingLayoutAttr>(attr).getConfiguration();
}

MaterializeEncodingInfo getEncodingInfoImpl(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<CPUEncodingLayoutAttr>(attr);

auto encoding = llvm::dyn_cast_or_null<IREE::Encoding::EncodingAttr>(
type.getEncoding());
Expand Down Expand Up @@ -726,23 +720,17 @@ enumerateVMVXMatmulTiles(linalg::ContractionDimensions cDims,
}

struct VMVXDeviceEncodingLayoutResolverAttrInterface final
: Codegen::LayoutAttrInterface::ExternalModel<
: DeviceEncodingLayoutResolverExternalModelBase<
VMVXDeviceEncodingLayoutResolverAttrInterface,
VMVXEncodingLayoutAttr> {
MaterializeEncodingInfo getEncodingInfo(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<VMVXEncodingLayoutAttr>(attr);

// If the layout is already resolved, use it directly.
if (auto config = layoutAttr.getConfiguration()) {
if (auto namedAttr = config.getNamed(kEncodingInfoAttrName)) {
std::optional<MaterializeEncodingInfo> info =
Codegen::deserializeEncodingInfo(
cast<DictionaryAttr>(namedAttr->getValue()));
assert(info && "encoding_info is invalid");
return info.value();
}
}
DictionaryAttr getConfiguration(Attribute attr) const {
return cast<VMVXEncodingLayoutAttr>(attr).getConfiguration();
}

MaterializeEncodingInfo getEncodingInfoImpl(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<VMVXEncodingLayoutAttr>(attr);

auto encoding = llvm::dyn_cast_or_null<IREE::Encoding::EncodingAttr>(
type.getEncoding());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,14 @@ static Operation *lowerContractionOpToMultiMmaOp(OpBuilder &builder,
}

struct GPUDeviceEncodingLayoutResolverAttrInterface
: public Codegen::LayoutAttrInterface::ExternalModel<
: public DeviceEncodingLayoutResolverExternalModelBase<
GPUDeviceEncodingLayoutResolverAttrInterface, GPUEncodingLayoutAttr> {
MaterializeEncodingInfo getEncodingInfo(Attribute attr,
RankedTensorType type) const {
DictionaryAttr getConfiguration(Attribute attr) const {
return cast<GPUEncodingLayoutAttr>(attr).getConfiguration();
}

MaterializeEncodingInfo getEncodingInfoImpl(Attribute attr,
RankedTensorType type) const {
auto layoutAttr = cast<GPUEncodingLayoutAttr>(attr);
DictionaryAttr config = layoutAttr.getConfiguration();

Expand All @@ -315,18 +319,6 @@ struct GPUDeviceEncodingLayoutResolverAttrInterface
return info;
}

// If the layout is already resolved, use it directly.
if (config) {
if (std::optional<NamedAttribute> namedAttr =
config.getNamed(kEncodingInfoAttrName)) {
std::optional<MaterializeEncodingInfo> preresolvedInfo =
Codegen::deserializeEncodingInfo(
cast<DictionaryAttr>(namedAttr->getValue()));
assert(preresolvedInfo && "encoding_info is invalid");
return preresolvedInfo.value();
}
}

IREE::GPU::TargetAttr gpuAttr = getGPUTargetAttr(config);
if (!gpuAttr) {
return info;
Expand Down
31 changes: 31 additions & 0 deletions compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#ifndef IREE_COMPILER_CODEGEN_EXTERNALINTERFACES_UTILS_H_
#define IREE_COMPILER_CODEGEN_EXTERNALINTERFACES_UTILS_H_

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand All @@ -16,6 +18,35 @@ namespace mlir::iree_compiler::IREE {

static const char kEncodingInfoAttrName[] = "encoding_info";

// This class is the base class for the external model of different encoding
// resolver attributes. It provides a public method, `getEncodingInfo` to reduce
// the duplicated implementations before. To inherit it, it requires the derived
// class to implement the `getConfiguration` method and the
// `getEncodingInfoImpl` method.
template <typename DeviceEncodingLayoutResolverAttrInterface,
typename EncodingLayoutAttr>
struct DeviceEncodingLayoutResolverExternalModelBase
: public Codegen::LayoutAttrInterface::ExternalModel<
DeviceEncodingLayoutResolverAttrInterface, EncodingLayoutAttr> {
public:
Codegen::MaterializeEncodingInfo
getEncodingInfo(Attribute attr, RankedTensorType type) const {
const DeviceEncodingLayoutResolverAttrInterface *impl =
static_cast<const DeviceEncodingLayoutResolverAttrInterface *>(this);
// If the layout is already resolved, use it directly.
if (auto config = impl->getConfiguration(attr)) {
if (auto namedAttr = config.getNamed(kEncodingInfoAttrName)) {
std::optional<Codegen::MaterializeEncodingInfo> info =
Codegen::deserializeEncodingInfo(
cast<DictionaryAttr>(namedAttr->getValue()));
assert(info && "encoding_info is invalid");
return info.value();
}
}
return impl->getEncodingInfoImpl(attr, type);
}
};

/// Calculates the storage size in bytes for the given `type` with a layout
/// encoding `attr`.
/// Requirement: `attr` must implement IREE::Codegen::LayoutAttrInterface.
Expand Down

0 comments on commit 3692165

Please sign in to comment.