Support shape_propagate for more ops #1302

Merged · 1 commit · Jun 17, 2024
234 changes: 232 additions & 2 deletions tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
@@ -67,6 +67,9 @@ struct DiscShapePropagatePass
    DiscShapePropagatePassBase<DiscShapePropagatePass>::getDependentDialects(
        registry);
    registry.insert<shape::ShapeDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<mhlo::MhloDialect>();
  }
  void runOnOperation() override;
};
@@ -75,7 +78,9 @@ bool isBinaryOp(Operation* op) {
         isa<mhlo::SelectOp>(*op) || isa<mhlo::ConvertOp>(*op);
}

bool isUnaryOp(Operation* op) { return isa<mhlo::ConvertOp>(op); }
bool isUnaryOp(Operation* op) {
  return isa<mhlo::ConvertOp, mhlo::ScatterOp>(op);
}
bool isConcreteShape(ShapeContext& ctx) {
  for (auto dim : ctx.shape) {
    if (dim == ShapedType::kDynamic) return false;
@@ -136,10 +141,24 @@ std::optional<ShapeContext> propagateHelper(OpBuilder& b, Operation* op,
                                            ShapeContext& inputCtx) {
  return std::nullopt;
}

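// tensor.dim: the propagated context is rebuilt from whatever shape is
// currently recorded on the op's result, i.e. the incoming context does not
// change it; this just lets the walk continue through tensor.dim users.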
template <>
std::optional<ShapeContext> propagateHelper<tensor::DimOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto dim_op = dyn_cast_or_null<tensor::DimOp>(op);
  if (!dim_op) return std::nullopt;

  SmallVector<int64_t> new_shape(
      op->getResult(0).getType().cast<RankedTensorType>().getShape());
  return ShapeContext(op->getResult(0), new_shape);
}

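// mhlo.dot (pre-existing helper): the only change is relaxing cast to
// dyn_cast_or_null so this specialization can be probed safely from the
// generic propagationFunctions table below and bail out with std::nullopt
// when handed a different op.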
template <>
std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto dot_op = cast<mhlo::DotOp>(op);
  auto dot_op = dyn_cast_or_null<mhlo::DotOp>(op);
  if (!dot_op) return std::nullopt;

  auto lhs_shape =
      dot_op.getOperand(0).getType().cast<RankedTensorType>().getShape();
  auto rhs_shape =
@@ -152,6 +171,210 @@ std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
  return ShapeContext(op->getResult(0), new_shape);
}

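// mhlo.concatenate: start from an all-dynamic result shape (keeping the
// current static size on the concat dimension), then walk every operand,
// substituting inputCtx.shape for the operand whose context is being updated.
// Non-concat dimensions adopt any static operand size, while the concat
// dimension is forced dynamic as soon as one operand is dynamic along it.
// Hypothetical example: concatenating tensor<?x4xf32> with tensor<2x4xf32>
// along dimension 0 propagates a result shape of [?, 4].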
template <>
std::optional<ShapeContext> propagateHelper<mhlo::ConcatenateOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto concat_op = dyn_cast_or_null<mhlo::ConcatenateOp>(op);
  if (!concat_op) return std::nullopt;

  auto operands = op->getOperands();
  SmallVector<int64_t> new_shape(
      op->getResult(0).getType().cast<RankedTensorType>().getRank(),
      ShapedType::kDynamic);
  new_shape[concat_op.getDimension()] =
      op->getResult(0)
          .getType()
          .cast<RankedTensorType>()
          .getShape()[concat_op.getDimension()];

  for (auto operand : operands) {
    auto shape = operand.getType().cast<RankedTensorType>().getShape();
    if (inputCtx.value == operand) {
      shape = inputCtx.shape;
    }

    for (int dim_idx = 0; dim_idx < new_shape.size(); dim_idx++) {
      if (dim_idx == concat_op.getDimension() &&
          shape[dim_idx] == ShapedType::kDynamic) {
        new_shape[dim_idx] = ShapedType::kDynamic;
      } else if (dim_idx != concat_op.getDimension() &&
                 shape[dim_idx] != ShapedType::kDynamic) {
        new_shape[dim_idx] = shape[dim_idx];
      }
    }
  }

  return ShapeContext(op->getResult(0), new_shape);
}

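// mhlo.transpose: result dimension i takes the propagated size of input
// dimension permutation[i]. Hypothetical example: permutation = [1, 0] maps
// an input context of [?, 4] to a result context of [4, ?].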
template <>
std::optional<ShapeContext> propagateHelper<mhlo::TransposeOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto transpose_op = dyn_cast_or_null<mhlo::TransposeOp>(op);
  if (!transpose_op) return std::nullopt;

  SmallVector<int64_t> new_shape;

  for (auto it = transpose_op.getPermutation().begin();
       it != transpose_op.getPermutation().end(); it++) {
    int64_t src_dim = (*it).getSExtValue();
    new_shape.push_back(inputCtx.shape[src_dim]);
  }

  return ShapeContext(op->getResult(0), new_shape);
}

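// mhlo.reduce: dimensions listed in the reduction are dropped and every other
// input dimension keeps its (possibly dynamic) propagated size. Hypothetical
// example: reducing a [?, 128] input over dimensions = [1] yields [?].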
template <>
std::optional<ShapeContext> propagateHelper<mhlo::ReduceOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto reduce_op = dyn_cast_or_null<mhlo::ReduceOp>(op);
  if (!reduce_op) return std::nullopt;

  SmallVector<int64_t> new_shape;

  for (int dim = 0; dim < inputCtx.shape.size(); dim++) {
    bool add_dim = true;
    for (auto it = reduce_op.getDimensions().begin();
         it != reduce_op.getDimensions().end(); it++) {
      int64_t src_dim = (*it).getSExtValue();
      add_dim = add_dim && !(dim == src_dim);
    }
    if (add_dim) {
      new_shape.push_back(inputCtx.shape[dim]);
    }
  }

  return ShapeContext(op->getResult(0), new_shape);
}

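// mhlo.dynamic_gather: a dynamic size can arrive through either operand. If
// the context comes from start_indices (operand 1), its dimensions (skipping
// index_vector_dim) overwrite the leading batch dimensions of the result; if
// it comes from slice_sizes (operand 2), the non-collapsed entries are written
// into the offset portion of the result shape. Dimensions not touched by the
// incoming context keep whatever sizes the result type already carries.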
template <>
std::optional<ShapeContext> propagateHelper<mhlo::DynamicGatherOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto dynamic_gather_op = dyn_cast_or_null<mhlo::DynamicGatherOp>(op);
  if (!dynamic_gather_op) return std::nullopt;

  SmallVector<int64_t> new_shape(dynamic_gather_op.getResult()
                                     .getType()
                                     .cast<RankedTensorType>()
                                     .getShape());

  auto attr = dynamic_gather_op.getDimensionNumbers();
  auto slice_sizes =
      op->getOperand(2).getType().cast<RankedTensorType>().getShape();

  auto offset_dims = attr.getOffsetDims();
  auto index_vector_dim = attr.getIndexVectorDim();
  auto collapsed_slice_dims = attr.getCollapsedSliceDims();

  if (inputCtx.value == op->getOperand(1)) {
    // start_indices
    int shape_dim_idx = 0;
    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
      if (dim_idx != index_vector_dim) {
        new_shape[shape_dim_idx++] = inputCtx.shape[dim_idx];
      }
    }
  } else if (inputCtx.value == op->getOperand(2)) {
    int shape_dim_idx =
        op->getOperand(0).getType().cast<RankedTensorType>().getRank() - 1;
    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
      bool include_this_dim = true;
      for (auto collapsed_slice_dim : collapsed_slice_dims) {
        if (dim_idx == collapsed_slice_dim) {
          include_this_dim = false;
        }
      }
      if (include_this_dim) {
        // need to decide whether it is a constant value or value from operand
        new_shape[shape_dim_idx++] = inputCtx.shape[dim_idx];
      }
    }
  }

  return ShapeContext(op->getResult(0), new_shape);
}

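// mhlo.gather: static slice_sizes cannot describe a slice that spans a
// soon-to-be-dynamic dimension, so the op is rewritten into
// mhlo.dynamic_gather. Every slice size that equals its source dimension is
// materialized as a runtime i64 value (tensor.dim + arith.index_cast) and the
// corresponding output dimension becomes dynamic; all other slice sizes stay
// arith.constant values. The sizes are packed with tensor.from_elements, the
// original gather is replaced, and propagation is re-run on the new op.
// Hypothetical example: with a source of static shape [8, 256] and
// slice_sizes = [1, 256], the 256 entry equals its source dimension, so it is
// rebuilt from tensor.dim and that output dimension is marked dynamic, while
// the 1 stays a constant.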
template <>
std::optional<ShapeContext> propagateHelper<mhlo::GatherOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto gather_op = dyn_cast_or_null<mhlo::GatherOp>(op);
  if (!gather_op) return std::nullopt;

  // batch_dims = [d for d in axes(result) and d not in offset_dims].
  auto attr = gather_op.getDimensionNumbers();
  auto offset_dims = attr.getOffsetDims();
  auto index_vector_dim = attr.getIndexVectorDim();
  auto slice_sizes = gather_op.getSliceSizes();
  auto collapsed_slice_dims = attr.getCollapsedSliceDims();
  auto src_shape =
      op->getOperand(0).getType().cast<RankedTensorType>().getShape();
  SmallVector<Value> slice_sizes_vec;
  SmallVector<int64_t> new_shape;
  auto start_indices_shape =
      op->getOperand(1).getType().cast<RankedTensorType>().getShape();

  b.setInsertionPoint(op);
  // process offset_dim_sizes, offset dims
  for (int dim_idx = 0; dim_idx < start_indices_shape.size(); dim_idx++) {
    if (dim_idx != index_vector_dim) {
      new_shape.push_back(start_indices_shape[dim_idx]);
    }
  }

  int dim_idx = 0;
  for (auto dim_size : slice_sizes) {
    bool include_this_dim = true;
    for (auto collapsed_slice_dim : collapsed_slice_dims) {
      if (dim_idx == collapsed_slice_dim) {
        include_this_dim = false;
      }
    }
    // need to decide whether it is a constant value or value from operand
    if (src_shape[dim_idx] == dim_size.getSExtValue()) {
      auto dim_value = b.create<tensor::DimOp>(op->getLoc(), op->getOperand(0),
                                               b.create<arith::ConstantIndexOp>(
                                                    op->getLoc(), dim_idx)
                                                   .getResult())
                           .getResult();
      slice_sizes_vec.push_back(
          b.create<arith::IndexCastOp>(op->getLoc(), b.getI64Type(), dim_value)
              .getResult());
    } else {
      slice_sizes_vec.push_back(b.create<arith::ConstantIntOp>(
          op->getLoc(), dim_size.getSExtValue(), b.getI64Type()));
    }

    if (include_this_dim && src_shape[dim_idx] == dim_size.getSExtValue()) {
      new_shape.push_back(ShapedType::kDynamic);
    } else if (include_this_dim &&
               src_shape[dim_idx] != dim_size.getSExtValue()) {
      new_shape.push_back(dim_size.getSExtValue());
    }

    dim_idx += 1;
  }

  // create a dynamic gather op
  auto dynamic_gather_op = b.create<mhlo::DynamicGatherOp>(
      op->getLoc(),
      RankedTensorType::get(new_shape, gather_op.getResult()
                                           .getType()
                                           .cast<RankedTensorType>()
                                           .getElementType()),
      op->getOperand(0), op->getOperand(1),
      b.create<tensor::FromElementsOp>(op->getLoc(), slice_sizes_vec)
          .getResult(),
      mhlo::GatherDimensionNumbersAttr::get(
          attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(),
          attr.getStartIndexMap(), attr.getIndexVectorDim()),
      gather_op.getIndicesAreSorted());
  gather_op.getResult().replaceAllUsesWith(dynamic_gather_op.getResult());

  // Update DynamicGatherOp result shape information
  return propagateHelper<mhlo::DynamicGatherOp>(
      b, dynamic_gather_op.getOperation(), inputCtx);
}

LogicalResult parseInputDynamicDims(
    func::FuncOp main,
    std::vector<std::pair<int, std::vector<int>>>& input_dynamic_dims) {
Expand Down Expand Up @@ -203,7 +426,13 @@ std::optional<ShapeContext> propagateOpShape(OpBuilder& rewriter, Operation* op,
using PropagationFunc =
std::optional<ShapeContext> (*)(OpBuilder&, Operation*, ShapeContext&);
const std::vector<PropagationFunc> propagationFunctions = {
propagateHelper<mhlo::ConcatenateOp>,
propagateHelper<mhlo::DotOp>,
propagateHelper<mhlo::DynamicGatherOp>,
propagateHelper<mhlo::GatherOp>,
propagateHelper<mhlo::ReduceOp>,
propagateHelper<mhlo::TransposeOp>,
propagateHelper<tensor::DimOp>,
};
// Iterate over the propagation functions and apply each one
for (const auto& propagate : propagationFunctions) {
@@ -226,6 +455,7 @@ void visitOperator(ModuleOp& m, OpBuilder& rewriter, Operation* op,
                   op->getName().stripDialect().str());
    return;
  }

  for (auto user : op->getResult(0).getUsers()) {
    visitOperator(m, rewriter, user, resultShapeCtx.value());
  }