diff --git a/include/substrait-mlir/Dialect/Substrait/IR/Substrait.h b/include/substrait-mlir/Dialect/Substrait/IR/Substrait.h index 710e8539..1d5a602e 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/Substrait.h +++ b/include/substrait-mlir/Dialect/Substrait/IR/Substrait.h @@ -15,19 +15,38 @@ #include "mlir/IR/SymbolTable.h" // IWYU: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU: keep -#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export +//===----------------------------------------------------------------------===// +// Substrait dialect +//===----------------------------------------------------------------------===// #include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export -#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export -#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export +//===----------------------------------------------------------------------===// +// Substrait enums +//===----------------------------------------------------------------------===// + +#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export + +//===----------------------------------------------------------------------===// +// Substrait types +//===----------------------------------------------------------------------===// +#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export #define GET_TYPEDEF_CLASSES #include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsTypes.h.inc" // IWYU: export +//===----------------------------------------------------------------------===// +// Substrait attributes +//===----------------------------------------------------------------------===// + #define GET_ATTRDEF_CLASSES #include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsAttrs.h.inc" // IWYU: export 
+//===----------------------------------------------------------------------===// +// Substrait ops +//===----------------------------------------------------------------------===// + +#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export #define GET_OP_CLASSES #include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td index 141d1a5a..1c6f795c 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td @@ -21,6 +21,26 @@ def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> { let cppNamespace = "::mlir::substrait"; } +def Substrait_ExtensibleOpInterface : OpInterface<"ExtensibleOpInterface"> { + let description = [{ + Interface for ops with the `advanced_extension` attribute. Several relations + and other message types of the Substrait specification have a field with the + same name (or the variant `advanced_extensions`, which has the same meaning) + and the interface enables handling all of them transparently. + }]; + let cppNamespace = "::mlir::substrait"; + let methods = [ + InterfaceMethod< + "Get the `advanced_extension` attribute", + "std::optional<::mlir::substrait::AdvancedExtensionAttr>", + "getAdvancedExtension">, + InterfaceMethod< + "Set the `advanced_extension` attribute", + "void", "setAdvancedExtensionAttr", + (ins "::mlir::substrait::AdvancedExtensionAttr":$attr)>, + ]; +} + def Substrait_RelOpInterface : OpInterface<"RelOpInterface"> { let description = [{ Interface for any relational operation in a Substrait plan. 
This corresponds diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index ad33a1c1..98afce96 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -150,6 +150,7 @@ def PlanBodyOp : AnyOf<[ def Substrait_PlanOp : Substrait_Op<"plan", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, NoTerminator, NoRegionArguments, SingleBlock, SymbolTable ]> { let summary = "Represents a Substrait plan"; @@ -177,9 +178,13 @@ def Substrait_PlanOp : Substrait_Op<"plan", [ let builders = [ OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{ build($_builder, $_state, major, minor, patch, - /*git_hash=*/StringAttr(), /*producer*/StringAttr(), + /*git_hash=*/"", /*producer=*/""); + }]>, + OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch, + "std::string":$git_hash, "std::string":$producer), [{ + build($_builder, $_state, major, minor, patch, git_hash, producer, /*advanced_extension=*/AdvancedExtensionAttr()); - }]> + }]>, ]; let extraClassDefinition = [{ /// Implement OpAsmOpInterface. @@ -526,6 +531,7 @@ def Substrait_NamedTableOp : Substrait_RelOp<"named_table", [ def Substrait_ProjectOp : Substrait_RelOp<"project", [ SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { let summary = "Project operation"; @@ -550,14 +556,18 @@ def Substrait_ProjectOp : Substrait_RelOp<"project", [ } ``` }]; - let arguments = (ins Substrait_Relation:$input); + let arguments = (ins + Substrait_Relation:$input, + OptionalAttr:$advanced_extension + ); let regions = (region AnyRegion:$expressions); let results = (outs Substrait_Relation:$result); // TODO(ingomueller): We could elide/shorten the block argument from the // assembly by writing custom printers/parsers similar to // `scf.for` etc. 
let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) $expressions + $input (`advanced_extension` `` $advanced_extension^)? + attr-dict `:` type($input) `->` type($result) $expressions }]; let hasRegionVerifier = 1; let hasFolder = 1; diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 5373fd29..7cef1839 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -23,6 +23,7 @@ using namespace mlir; using namespace mlir::substrait; +using namespace mlir::substrait::protobuf_utils; using namespace ::substrait; using namespace ::substrait::proto; @@ -59,6 +60,8 @@ class SubstraitExporter { DECLARE_EXPORT_FUNC(RelOpInterface, Rel) DECLARE_EXPORT_FUNC(SetOp, Rel) + template + void exportAdvancedExtension(ExtensibleOpInterface op, MessageType &message); std::unique_ptr exportAny(StringAttr attr); FailureOr> exportOperation(Operation *op); FailureOr> exportType(Location loc, @@ -90,6 +93,36 @@ class SubstraitExporter { std::unique_ptr symbolTable; // Symbol table cache. }; +template +void SubstraitExporter::exportAdvancedExtension(ExtensibleOpInterface op, + MessageType &message) { + if (!op.getAdvancedExtension()) + return; + + // Build the base `AdvancedExtension` message. + AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value(); + auto extension = std::make_unique(); + + StringAttr optimizationAttr = extensionAttr.getOptimization(); + StringAttr enhancementAttr = extensionAttr.getEnhancement(); + + // Set `optimization` field if present. + if (optimizationAttr) { + std::unique_ptr optimization = exportAny(optimizationAttr); + extension->set_allocated_optimization(optimization.release()); + } + + // Set `enhancement` field if present. + if (enhancementAttr) { + std::unique_ptr enhancement = exportAny(enhancementAttr); + extension->set_allocated_enhancement(enhancement.release()); + } + + // Set the `advanced_extension` field in the provided message. 
+ using Trait = advanced_extension_trait; + Trait::set_allocated_advanced_extension(message, extension.release()); +} + std::unique_ptr SubstraitExporter::exportAny(StringAttr attr) { auto any = std::make_unique(); auto anyType = mlir::cast(attr.getType()); @@ -837,26 +870,8 @@ FailureOr> SubstraitExporter::exportOperation(PlanOp op) { version->set_git_hash(op.getGitHash().str()); plan->set_allocated_version(version.release()); - // Build `AdvancedExtension` message. - if (op.getAdvancedExtension()) { - AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value(); - auto extension = std::make_unique(); - - StringAttr optimizationAttr = extensionAttr.getOptimization(); - StringAttr enhancementAttr = extensionAttr.getEnhancement(); - - if (optimizationAttr) { - std::unique_ptr optimization = exportAny(optimizationAttr); - extension->set_allocated_optimization(optimization.release()); - } - - if (enhancementAttr) { - std::unique_ptr enhancement = exportAny(enhancementAttr); - extension->set_allocated_enhancement(enhancement.release()); - } - - plan->set_allocated_advanced_extensions(extension.release()); - } + // Attach the `AdvancedExtension` message if the attribute exists. + exportAdvancedExtension(op, *plan); // Add `extension_uris` to plan. { @@ -987,6 +1002,9 @@ SubstraitExporter::exportOperation(ProjectOp op) { *projectRel->add_expressions() = *expression.value(); } + // Attach the `AdvancedExtension` message if the attribute exists. + exportAdvancedExtension(op, *projectRel); + // Build `Rel` message. 
auto rel = std::make_unique(); rel->set_allocated_project(projectRel.release()); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 0e4fab3d..ebb631c5 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -25,6 +25,7 @@ using namespace mlir; using namespace mlir::substrait; +using namespace mlir::substrait::protobuf_utils; using namespace ::substrait; using namespace ::substrait::proto; @@ -64,6 +65,46 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp) +/// If present, imports the `advanced_extension` or `advanced_extensions` field +/// from the given message and sets the obtained attribute on the given op. +template +void importAdvancedExtension(ImplicitLocOpBuilder builder, + ExtensibleOpInterface op, + const MessageType &message); + +template +void importAdvancedExtension(ImplicitLocOpBuilder builder, + ExtensibleOpInterface op, + const MessageType &message) { + using Trait = advanced_extension_trait; + if (!Trait::has_advanced_extension(message)) + return; + + // Get the `advanced_extension(s)` field. + const extensions::AdvancedExtension &advancedExtension = + Trait::advanced_extension(message); + + // Import `optimization` field if present. + StringAttr optimizationAttr; + if (advancedExtension.has_optimization()) { + const pb::Any &optimization = advancedExtension.optimization(); + optimizationAttr = importAny(builder, optimization).value(); + } + + // Import `enhancement` field if present. + StringAttr enhancementAttr; + if (advancedExtension.has_enhancement()) { + const pb::Any &enhancement = advancedExtension.enhancement(); + enhancementAttr = importAny(builder, enhancement).value(); + } + + // Build attribute and set it on the op. 
+ MLIRContext *context = builder.getContext(); + auto advancedExtensionAttr = + AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr); + op.setAdvancedExtensionAttr(advancedExtensionAttr); +} + FailureOr importAny(ImplicitLocOpBuilder builder, const pb::Any &message) { MLIRContext *context = builder.getContext(); @@ -480,34 +521,15 @@ static FailureOr importPlan(ImplicitLocOpBuilder builder, // Import version. const Version &version = message.version(); - // Import advanced extension. - AdvancedExtensionAttr advancedExtensionAttr; - if (message.has_advanced_extensions()) { - const extensions::AdvancedExtension &advancedExtension = - message.advanced_extensions(); - - StringAttr optimizationAttr; - if (advancedExtension.has_optimization()) { - const pb::Any &optimization = advancedExtension.optimization(); - optimizationAttr = importAny(builder, optimization).value(); - } - - StringAttr enhancementAttr; - if (advancedExtension.has_enhancement()) { - const pb::Any &enhancement = advancedExtension.enhancement(); - enhancementAttr = importAny(builder, enhancement).value(); - } - - advancedExtensionAttr = - AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr); - } - // Build `PlanOp`. auto planOp = builder.create( version.major_number(), version.minor_number(), version.patch_number(), - version.git_hash(), version.producer(), advancedExtensionAttr); + version.git_hash(), version.producer()); planOp.getBody().push_back(new Block()); + // Import advanced extension if it is present. + importAdvancedExtension(builder, planOp, message); + OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToEnd(&planOp.getBody().front()); @@ -668,6 +690,9 @@ static mlir::FailureOr importProjectRel(ImplicitLocOpBuilder builder, builder.create(resultType, inputOp.value()->getResult(0)); projectOp.getExpressions().push_back(conditionBlock.release()); + // Import advanced extension if it is present. 
+ importAdvancedExtension(builder, projectOp, projectRel); + + return projectOp; } diff --git a/lib/Target/SubstraitPB/ProtobufUtils.h b/lib/Target/SubstraitPB/ProtobufUtils.h index 426d0c8e..4ea76a34 100644 --- a/lib/Target/SubstraitPB/ProtobufUtils.h +++ b/lib/Target/SubstraitPB/ProtobufUtils.h @@ -9,6 +9,8 @@ #ifndef LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H #define LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H +#include + #include "mlir/IR/Location.h" namespace substrait::proto { @@ -28,6 +30,59 @@ getCommon(const ::substrait::proto::Rel &rel, Location loc); FailureOr<::substrait::proto::RelCommon *> getMutableCommon(::substrait::proto::Rel *rel, Location loc); +/// SFINAE-based template that checks if the given (message) type has a field +/// called `advanced_extensions`: the `value` member is `true` iff it has. This +/// is useful to deal with the two different names, `advanced_extension` and +/// `advanced_extensions`, that are used for the same thing across different +/// message types in the Substrait spec. +template +class has_advanced_extensions { + template + static std::true_type test(decltype(&C::advanced_extensions)); + template + static std::false_type test(...); + +public: + static constexpr bool value = decltype(test(0))::value; +}; + +/// Trait class for accessing the `advanced_extension` field. The default +/// instance is automatically used for message types that call this field +/// `advanced_extension`; the specialization below is automatically used for +/// message types that call it `advanced_extensions`. 
+template +struct advanced_extension_trait { + static auto has_advanced_extension(const T &message) { + return message.has_advanced_extension(); + } + static decltype(auto) advanced_extension(const T &message) { + return message.advanced_extension(); // Forwarded as-is, not copied. + } + template + static auto set_allocated_advanced_extension(T &message, + S &&advanced_extensions) { + message.set_allocated_advanced_extension( + std::forward(advanced_extensions)); + } +}; + +template +struct advanced_extension_trait< + T, std::enable_if_t::value>> { + static auto has_advanced_extension(const T &message) { + return message.has_advanced_extensions(); + } + static decltype(auto) advanced_extension(const T &message) { + return message.advanced_extensions(); // Forwarded as-is, not copied. + } + template + static auto set_allocated_advanced_extension(T &message, + S &&advanced_extensions) { + message.set_allocated_advanced_extensions( + std::forward(advanced_extensions)); + } +}; + } // namespace mlir::substrait::protobuf_utils #endif // LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H diff --git a/test/Dialect/Substrait/project.mlir b/test/Dialect/Substrait/project.mlir index e5918f53..d537aba5 100644 --- a/test/Dialect/Substrait/project.mlir +++ b/test/Dialect/Substrait/project.mlir @@ -27,7 +27,7 @@ substrait.plan version 0 : 42 : 1 { // ----- -// CHECK: substrait.plan version 0 : 42 : 1 { +// CHECK: substrait.plan // CHECK-NEXT: relation // CHECK: %[[V0:.*]] = named_table // CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple -> tuple { @@ -44,3 +44,25 @@ substrait.plan version 0 : 42 : 1 { yield %1 : tuple } } + +// ----- + +// CHECK: substrait.plan version +// CHECK-NEXT: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] +// CHECK-SAME: advanced_extension optimization = "foo" : !substrait.any<"bar"> +// CHECK-SAME: tuple -> tuple { + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = project %0 + advanced_extension optimization = "foo" : !substrait.any<"bar"> + : tuple -> 
tuple { + ^bb0(%arg0: tuple): + yield + } + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Export/project.mlir b/test/Target/SubstraitPB/Export/project.mlir index ca1be2df..ee33d55e 100644 --- a/test/Target/SubstraitPB/Export/project.mlir +++ b/test/Target/SubstraitPB/Export/project.mlir @@ -1,9 +1,12 @@ -// RUN: substrait-translate -substrait-to-protobuf %s \ +// RUN: substrait-translate -substrait-to-protobuf --split-input-file %s \ // RUN: | FileCheck %s // RUN: substrait-translate -substrait-to-protobuf %s \ +// RUN: --split-input-file --output-split-marker="# -----" \ // RUN: | substrait-translate -protobuf-to-substrait \ +// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \ // RUN: | substrait-translate -substrait-to-protobuf \ +// RUN: --split-input-file --output-split-marker="# -----" \ // RUN: | FileCheck %s // CHECK-LABEL: relations { @@ -36,3 +39,30 @@ substrait.plan version 0 : 42 : 1 { yield %1 : tuple } } + +// ----- + +// CHECK-LABEL: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: project { +// CHECK-NEXT: common { +// CHECK: input { +// CHECK: advanced_extension { +// CHECK-NEXT: optimization { +// CHECK-NEXT: type_url: "bar" +// CHECK-NEXT: value: "foo" +// CHECK-NEXT: } +// CHECK-NEXT: } + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = project %0 + advanced_extension optimization = "foo" : !substrait.any<"bar"> + : tuple -> tuple { + ^bb0(%arg0: tuple): + yield + } + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Import/project.textpb b/test/Target/SubstraitPB/Import/project.textpb index e22072b6..d054f3c5 100644 --- a/test/Target/SubstraitPB/Import/project.textpb +++ b/test/Target/SubstraitPB/Import/project.textpb @@ -1,9 +1,13 @@ # RUN: substrait-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" \ # RUN: | FileCheck %s # RUN: substrait-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" 
--output-split-marker="// -----" \ # RUN: | substrait-translate -substrait-to-protobuf \ +# RUN: --split-input-file --output-split-marker="# ""-----" \ # RUN: | substrait-translate -protobuf-to-substrait \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ # RUN: | FileCheck %s # CHECK: substrait.plan version 0 : 42 : 1 { @@ -63,3 +67,54 @@ version { minor_number: 42 patch_number: 1 } + +# ----- + +# CHECK: substrait.plan version +# CHECK-NEXT: relation +# CHECK: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] +# CHECK-SAME: advanced_extension optimization = "foo" : !substrait.any<"bar"> + +relations { + rel { + project { + common { + direct { + } + } + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + advanced_extension { + optimization { + type_url: "bar" + value: "foo" + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +}