
Commit

cleaning
fschlimb committed Nov 15, 2024
1 parent d9032f8 commit 84ee293
Showing 3 changed files with 148 additions and 267 deletions.
include/imex/Dialect/NDArray/IR/NDArrayOps.td (2 changes: 1 addition & 1 deletion)

@@ -671,7 +671,7 @@ def ReshapeOp : NDArray_Op<"reshape", []> {
 def CastElemTypeOp: NDArray_Op<"cast_elemtype", [Pure]> {
   let summary = "Cast array from one element type to another";
 
-  let arguments = (ins AnyType:$input, OptionalAttr<I1Attr>:$copy);
+  let arguments = (ins AnyRankedTensor:$input, OptionalAttr<I1Attr>:$copy);
   let results = (outs AnyRankedTensor:$output);
 
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `to` qualified(type($output))";
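
With $input tightened from AnyType to AnyRankedTensor, both operands of cast_elemtype are now ranked tensors. A minimal sketch of the textual form implied by the assemblyFormat above; the dialect prefix and the tensor types are illustrative, not taken from this commit:

    %0 = ndarray.cast_elemtype %arg0 : tensor<4xi32> to tensor<4xf32>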
lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp (180 changes: 16 additions & 164 deletions)

@@ -47,30 +47,6 @@ namespace imex {

 namespace imex {
 
-// /// @return type without a sign
-// static mlir::Type makeSignlessType(mlir::Type type) {
-//   if (auto intType = mlir::dyn_cast<mlir::IntegerType>(type)) {
-//     if (!intType.isSignless())
-//       return mlir::IntegerType::get(intType.getContext(),
-//                                     intType.getWidth());
-//   }
-//   return type;
-// }
-
-// /// @return operand cast to signless type if needed, val if not
-// static mlir::Value doSignCast(mlir::OpBuilder &builder, mlir::Location &loc,
-//                               mlir::Value val) {
-//   auto origType = val.getType();
-//   auto signlessType = makeSignlessType(origType);
-//   if (signlessType != origType) {
-//     val = builder
-//               .create<::mlir::UnrealizedConversionCastOp>(loc, signlessType,
-//                                                           val)
-//               .getResult(0);
-//   }
-//   return val;
-// }
-
 /// Create a linalg generic op from given output, input and body
 template <typename V, typename B>
 auto createParFor(mlir::Location &loc, mlir::OpBuilder &builder, uint64_t rank,
@@ -100,13 +76,8 @@ struct CopyLowering : public ::mlir::OpRewritePattern<::imex::ndarray::CopyOp> {
   matchAndRewrite(::imex::ndarray::CopyOp op,
                   ::mlir::PatternRewriter &rewriter) const override {
     // check output type and get operands
-    auto srcArTyp =
-        mlir::dyn_cast<::mlir::RankedTensorType>(op.getSource().getType());
-    auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType());
-    if (!(srcArTyp && retArTyp)) {
-      return ::mlir::failure();
-    }
-
+    auto srcArTyp = op.getSource().getType();
+    auto retArTyp = op.getType();
     auto loc = op.getLoc();
     auto src = op.getSource();
     auto rank = srcArTyp.getRank();
@@ -157,25 +128,18 @@ struct ReshapeLowering
   matchAndRewrite(::imex::ndarray::ReshapeOp op,
                   ::mlir::PatternRewriter &rewriter) const override {
     // check output type and get operands
-    auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType());
-    auto srcArTyp =
-        mlir::dyn_cast<::mlir::RankedTensorType>(op.getSource().getType());
-    if (!(retArTyp && srcArTyp)) {
-      return ::mlir::failure();
-    }
-
     auto loc = op.getLoc();
     auto src = op.getSource();
     auto shape = op.getShape();
 
     if (op.getCopy().value_or(false)) {
-      src = rewriter.create<::imex::ndarray::CopyOp>(loc, srcArTyp,
-                                                     op.getSource());
+      src = rewriter.create<::imex::ndarray::CopyOp>(
+          loc, op.getSource().getType(), op.getSource());
     }
 
     auto shapeT = rewriter.create<::mlir::tensor::FromElementsOp>(loc, shape);
-    rewriter.replaceOpWithNewOp<::mlir::tensor::ReshapeOp>(op, retArTyp, src,
-                                                           shapeT);
+    rewriter.replaceOpWithNewOp<::mlir::tensor::ReshapeOp>(op, op.getType(),
+                                                           src, shapeT);
 
     return ::mlir::success();
   }
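
The rewrite above packs the requested shape into a 1-D tensor with tensor.from_elements and then replaces the op with tensor.reshape. A minimal sketch of the resulting IR; the SSA names, constants, and tensor types are illustrative, not taken from this commit:

    %shape = tensor.from_elements %c2, %c8 : tensor<2xindex>
    %r = tensor.reshape %src(%shape)
        : (tensor<16xf32>, tensor<2xindex>) -> tensor<2x8xf32>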
@@ -195,12 +159,7 @@ struct SubviewLowering
     auto loc = op->getLoc();
 
     // convert src array to memref
-    auto srcArType =
-        mlir::dyn_cast_or_null<::mlir::ShapedType>(op.getSource().getType());
-    auto resType = mlir::dyn_cast_or_null<::mlir::ShapedType>(op.getType());
-    if (!resType || !srcArType)
-      return mlir::failure();
-
+    auto srcArType = srcTnsr.getType();
     auto srcMRType = imex::getMemRefType(op.getContext(), srcArType.getShape(),
                                          srcArType.getElementType());
     auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType);
@@ -214,7 +173,7 @@
 
     auto resMRType = mlir::cast<::mlir::MemRefType>(
         ::mlir::memref::SubViewOp::inferRankReducedResultType(
-            resType.getShape(), srcMRType, offsets, sizes, strides));
+            op.getType().getShape(), srcMRType, offsets, sizes, strides));
 
     auto sw = rewriter.create<::mlir::memref::SubViewOp>(
         loc, resMRType, srcMR, offsets, sizes, strides);
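
SubviewLowering materializes the source tensor as a memref and takes a strided, possibly rank-reduced view of it via memref.subview. A sketch of the memref-level result; the shape, offsets, sizes, and strides are made up for illustration and do not appear in the commit:

    %sv = memref.subview %srcMR[0, 2] [4, 4] [1, 1]
        : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 2>>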
@@ -269,11 +228,8 @@ struct InsertSliceLowering
     // get operators
     auto src = op.getSource();
     auto dst = op.getDestination();
-    auto srcTyp = mlir::dyn_cast<::mlir::ShapedType>(src.getType());
-    auto dstTyp = mlir::dyn_cast<::mlir::ShapedType>(dst.getType());
-    if (!dstTyp || !srcTyp)
-      return ::mlir::failure();
-
+    auto srcTyp = src.getType();
+    auto dstTyp = dst.getType();
     auto srcMRTyp = getMemRefType(op.getContext(), srcTyp.getShape(),
                                   srcTyp.getElementType());
     auto dstMRTyp = getMemRefType(op.getContext(), dstTyp.getShape(),
Expand Down Expand Up @@ -362,7 +318,7 @@ struct LinSpaceLowering
     auto stop = op.getStop();
     auto count = op.getNum();
     bool endpoint = op.getEndpoint();
-    auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType());
+    auto retArTyp = op.getType();
     auto rank = retArTyp.getRank();
     auto elTyp = retArTyp.getElementType();

Expand Down Expand Up @@ -421,9 +377,7 @@ struct CreateLowering
     auto loc = op.getLoc();
 
     // check output type and get operands
-    auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType());
-    if (!retArTyp)
-      return ::mlir::failure();
+    auto retArTyp = op.getType();
     auto value = op.getValue();
 
     // init tensor
@@ -453,16 +407,9 @@ struct DeleteLowering
   ::mlir::LogicalResult
   matchAndRewrite(::imex::ndarray::DeleteOp op,
                   ::mlir::PatternRewriter &rewriter) const override {
-    // check output type and get operands
-    auto inpArType =
-        mlir::dyn_cast<::mlir::RankedTensorType>(op.getInput().getType());
-    if (!inpArType) {
-      return ::mlir::failure();
-    }
-
     auto inp = op.getInput();
-    auto inpMR =
-        createToMemRef(op.getLoc(), rewriter, inp, getMemRefType(inpArType));
+    auto inpMR = createToMemRef(op.getLoc(), rewriter, inp,
+                                getMemRefType(op.getInput().getType()));
     auto newOp =
         rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR);
     newOp->setAttrs(op->getAttrs());
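
DeleteLowering thus reduces an ndarray delete (the textual op name is assumed here) to a plain deallocation of the array's backing buffer: the input tensor is materialized as a memref via createToMemRef and freed. A minimal sketch of the effect, with an illustrative type and assuming %m is the memref form of the deleted array:

    memref.dealloc %m : memref<16xf32>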
@@ -481,12 +428,8 @@ struct CastElemTypeLowering
                   ::mlir::PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto src = op.getInput();
-    auto srcArType =
-        mlir::dyn_cast<::mlir::RankedTensorType>(op.getInput().getType());
-    auto dstArType = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType());
-    if (!(srcArType && dstArType)) {
-      return ::mlir::failure();
-    }
+    auto srcArType = op.getInput().getType();
+    auto dstArType = op.getType();
 
     // verify identical shape
     assert(dstArType.getRank() == srcArType.getRank());
Expand Down Expand Up @@ -542,71 +485,6 @@ struct ConvertNDArrayToLinalgPass

   void runOnOperation() override {
     auto &ctxt = getContext();
-    // ::mlir::TypeConverter typeConverter;
-    // // Convert unknown types to itself
-    // auto convT2T = [](::mlir::Type type) { return type; };
-    // // Convert NDArrayType to (tensorType)
-    // auto convNDArray2RankedTensor =
-    //     [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> {
-    //   return type.getTensorType();
-    // };
-
-    // typeConverter.addConversion(convT2T);
-    // typeConverter.addConversion(convNDArray2RankedTensor);
-
-    // auto materializeCast =
-    //     [](::mlir::OpBuilder &builder, ::mlir::Type type,
-    //        ::mlir::ValueRange inputs,
-    //        ::mlir::Location loc) -> ::mlir::Value {
-    //   if (inputs.size() == 1) {
-    //     auto input = inputs[0];
-    //     auto itype = input.getType();
-    //     if (mlir::isa<::mlir::TensorType>(type) and
-    //         mlir::isa<::mlir::TensorType>(itype)) {
-    //       return builder.create<::mlir::tensor::CastOp>(loc, type, inputs)
-    //           .getResult();
-    //     }
-    //     auto ttype = mlir::dyn_cast<::mlir::RankedTensorType>(itype);
-    //     if (ttype && mlir::isa<::mlir::MemRefType>(type)) {
-    //       return createToMemRef(loc, builder, input, type);
-    //     }
-    //     auto mrtype = mlir::dyn_cast<::mlir::MemRefType>(itype);
-    //     if (mrtype && mlir::isa<::mlir::RankedTensorType>(type)) {
-    //       return builder
-    //           .create<::mlir::bufferization::ToTensorOp>(loc, type, input,
-    //                                                      /*restrict=*/true)
-    //           .getResult();
-    //     }
-    //   }
-    //   return builder
-    //       .create<::mlir::UnrealizedConversionCastOp>(loc, type, inputs)
-    //       .getResult(0);
-    // };
-    // typeConverter.addSourceMaterialization(materializeCast);
-    // typeConverter.addTargetMaterialization(materializeCast);
-
-    // // At function boundaries we have actual memref semantics.
-    // // We need to explicitly convert in/out arguments to memrefs.
-    // // If we use tensors, downstream passes will auto-convert to non-strided
-    // // memrefs which will imply a copy (converting from strided to
-    // // non-strided requires a copy).
-    // // We simply use a separate type-converter and materializations.
-
-    // ::mlir::TypeConverter typeConverter4Func;
-    // // Convert NDArrayType to MemRefType
-    // auto convNDArray2MemRef =
-    //     [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> {
-    //   return type.getMemRefType();
-    // };
-
-    // typeConverter4Func.addConversion(convT2T);
-    // typeConverter4Func.addConversion(convNDArray2MemRef);
-    // typeConverter4Func.addSourceMaterialization(materializeCast);
-    // typeConverter4Func.addTargetMaterialization(materializeCast);
-
     ::mlir::ConversionTarget target(ctxt);
     // We convert all NDArray stuff...
     target.addIllegalDialect<::imex::ndarray::NDArrayDialect>();
Expand All @@ -616,38 +494,12 @@ struct ConvertNDArrayToLinalgPass
             ::mlir::memref::MemRefDialect, ::mlir::tensor::TensorDialect,
             ::mlir::bufferization::BufferizationDialect,
             ::imex::region::RegionDialect>();
-    // target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME
-
-    // // make sure function boundaries use tensors (not NDArrays)
-    // target.addDynamicallyLegalOp<::mlir::func::FuncOp>(
-    //     [&](::mlir::func::FuncOp op) {
-    //       return typeConverter4Func.isSignatureLegal(op.getFunctionType()) &&
-    //              typeConverter4Func.isLegal(&op.getBody());
-    //     });
-    // target.addDynamicallyLegalOp<::mlir::func::ReturnOp, mlir::func::CallOp>(
-    //     [&](mlir::Operation *op) { return typeConverter4Func.isLegal(op); });
-
-    // target.addDynamicallyLegalOp<::imex::region::EnvironmentRegionOp,
-    //                              ::imex::region::EnvironmentRegionYieldOp>(
-    //     [&](mlir::Operation *op) { return typeConverter.isLegal(op); });
 
     ::mlir::RewritePatternSet patterns(&ctxt);
     patterns
         .insert<SubviewLowering, ExtractSliceLowering, InsertSliceLowering,
                 ImmutableInsertSliceLowering, LinSpaceLowering, CreateLowering,
                 CopyLowering, DeleteLowering, CastElemTypeLowering>(&ctxt);
-    // ::imex::populateRegionTypeConversionPatterns(patterns, typeConverter);
-
-    // // populate function boundaries using our special type converter
-    // ::mlir::populateFunctionOpInterfaceTypeConversionPattern<
-    //     ::mlir::func::FuncOp>(patterns, typeConverter4Func);
-    // ::mlir::populateReturnOpTypeConversionPattern(patterns,
-    //                                               typeConverter4Func);
-    // ::mlir::populateCallOpTypeConversionPattern(patterns,
-    //                                             typeConverter4Func);
-
-    // ::mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-    //     typeConverter, patterns, target);
 
     if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target,
                                                       ::std::move(patterns)))) {
