From c01f1835e0433a22222c2b2600a2fa1ff05d9627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 21 Jan 2025 10:03:08 +0100 Subject: [PATCH] feat: implement `AdvancedExtension` and at it to `Plan` (#58) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements the `AdvanceExtension` message type as a custom attribute with two `StringAttr`s as well as an `AnyType` with a parameter for the type URL to represent the `Any` messages as typed `StringAttr`s. The PR also adds those attributes to the `PlanOp` in order to represent the `advanced_extensions` message of the `Plan` message. We do not currently need this attribute but it is one of the few missing fields of the `Plan` message as well as many other messages we have otherwise implemented, so it allows us to get to higher coverage with relatively little effort. Signed-off-by: Ingo Müller --- .../Dialect/Substrait/IR/SubstraitOps.td | 7 ++- .../Dialect/Substrait/IR/SubstraitTypes.td | 26 ++++++++ lib/Dialect/Substrait/IR/Substrait.cpp | 14 +++++ lib/Target/SubstraitPB/Export.cpp | 38 +++++++++++- lib/Target/SubstraitPB/Import.cpp | 36 ++++++++++- test/Dialect/Substrait/plan-invalid.mlir | 16 +++++ test/Dialect/Substrait/plan.mlir | 14 +++++ test/Target/SubstraitPB/Export/plan.mlir | 44 ++++++++++++++ test/Target/SubstraitPB/Import/plan.textpb | 59 +++++++++++++++++++ 9 files changed, 248 insertions(+), 6 deletions(-) diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index 78ec7c36..ad33a1c1 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -164,18 +164,21 @@ def Substrait_PlanOp : Substrait_Op<"plan", [ UI32Attr:$minor_number, UI32Attr:$patch_number, DefaultValuedAttr:$git_hash, - DefaultValuedAttr:$producer + DefaultValuedAttr:$producer, + OptionalAttr:$advanced_extension ); let regions = (region RegionOf:$body); let assemblyFormat = [{ `version` $major_number `:` $minor_number `:` $patch_number (`git_hash` $git_hash^)? (`producer` $producer^)? + (`advanced_extension` `` $advanced_extension^)? attr-dict-with-keyword $body }]; let builders = [ OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{ build($_builder, $_state, major, minor, patch, - StringAttr(), StringAttr()); + /*git_hash=*/StringAttr(), /*producer*/StringAttr(), + /*advanced_extension=*/AdvancedExtensionAttr()); }]> ]; let extraClassDefinition = [{ diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td index b4f7fa35..d5dfc768 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td @@ -120,6 +120,32 @@ def Substrait_AtomicAttributes { ]; } +def Substrait_AnyType : Substrait_Type<"Any", "any"> { + let summary = "Represents the `type_url` of a `google.protobuf.Any` message"; + let description = [{ + This type represents the `type_url` fields of a `google.protobuf.Any` + message. These messages consist of an opaque byte array and a string holding + the URL identifying the type of what is contained in the byte array. + }]; + let parameters = (ins "StringAttr":$type_url); + let assemblyFormat = "`<` $type_url `>`"; + +} + +def Substrait_AdvancedExtensionAttr + : Substrait_Attr<"AdvancedExtension", "advanced_extension"> { + let summary = "Represents the `AdvancedExtenssion` message of Substrait"; + let parameters = (ins + OptionalParameter<"StringAttr">:$optimization, // XXX: verify type + OptionalParameter<"StringAttr">:$enhancement + ); + let assemblyFormat = [{ + ( `optimization` `=` $optimization^ )? + ( `enhancement` `=` $enhancement^ )? + }]; + let genVerifyDecl = 1; +} + /// Attribute of one of the currently supported atomic types. def Substrait_AtomicAttribute : AnyAttrOf; diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index 507d822a..7a306bc6 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -38,6 +38,20 @@ void SubstraitDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// Substrait attributes +//===----------------------------------------------------------------------===// + +LogicalResult AdvancedExtensionAttr::verify( + llvm::function_ref emitError, + mlir::StringAttr optimization, mlir::StringAttr enhancement) { + if (optimization && !mlir::isa(optimization.getType())) + return emitError() << "has 'optimization' attribute of wrong type"; + if (enhancement && !mlir::isa(enhancement.getType())) + return emitError() << "has 'enhancement' attribute of wrong type"; + return success(); +} + //===----------------------------------------------------------------------===// // Substrait enums //===----------------------------------------------------------------------===// diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 15697eec..5373fd29 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -59,6 +59,7 @@ class SubstraitExporter { DECLARE_EXPORT_FUNC(RelOpInterface, Rel) DECLARE_EXPORT_FUNC(SetOp, Rel) + std::unique_ptr exportAny(StringAttr attr); FailureOr> exportOperation(Operation *op); FailureOr> exportType(Location loc, mlir::Type mlirType); @@ -89,6 +90,16 @@ class SubstraitExporter { std::unique_ptr symbolTable; // Symbol table cache. }; +std::unique_ptr SubstraitExporter::exportAny(StringAttr attr) { + auto any = std::make_unique(); + auto anyType = mlir::cast(attr.getType()); + std::string typeUrl = anyType.getTypeUrl().getValue().str(); + std::string value = attr.getValue().str(); + any->set_type_url(typeUrl); + any->set_value(value); + return any; +} + std::unique_ptr exportIntegerType(mlir::Type mlirType, MLIRContext *context) { // Function that handles `IntegerType`'s. @@ -814,6 +825,9 @@ FailureOr> SubstraitExporter::exportOperation(PlanOp op) { using extensions::SimpleExtensionDeclaration; using extensions::SimpleExtensionURI; + // Build `Plan` message. + auto plan = std::make_unique(); + // Build `Version` message. auto version = std::make_unique(); version->set_major_number(op.getMajorNumber()); @@ -821,11 +835,29 @@ FailureOr> SubstraitExporter::exportOperation(PlanOp op) { version->set_patch_number(op.getPatchNumber()); version->set_producer(op.getProducer().str()); version->set_git_hash(op.getGitHash().str()); - - // Build `Plan` message. - auto plan = std::make_unique(); plan->set_allocated_version(version.release()); + // Build `AdvancedExtension` message. + if (op.getAdvancedExtension()) { + AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value(); + auto extension = std::make_unique(); + + StringAttr optimizationAttr = extensionAttr.getOptimization(); + StringAttr enhancementAttr = extensionAttr.getEnhancement(); + + if (optimizationAttr) { + std::unique_ptr optimization = exportAny(optimizationAttr); + extension->set_allocated_optimization(optimization.release()); + } + + if (enhancementAttr) { + std::unique_ptr enhancement = exportAny(enhancementAttr); + extension->set_allocated_enhancement(enhancement.release()); + } + + plan->set_allocated_advanced_extensions(extension.release()); + } + // Add `extension_uris` to plan. { AnchorUniquer anchorUniquer("extension_uri.", anchorsByOp); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 0ac1fdf0..0e4fab3d 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -46,6 +46,7 @@ namespace { static FailureOr import##MESSAGE_TYPE(ImplicitLocOpBuilder builder, \ const ARG_TYPE &message); +DECLARE_IMPORT_FUNC(Any, pb::Any, StringAttr) DECLARE_IMPORT_FUNC(CrossRel, Rel, CrossOp) DECLARE_IMPORT_FUNC(FetchRel, Rel, FetchOp) DECLARE_IMPORT_FUNC(FilterRel, Rel, FilterOp) @@ -63,6 +64,14 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp) +FailureOr importAny(ImplicitLocOpBuilder builder, + const pb::Any &message) { + MLIRContext *context = builder.getContext(); + auto typeUrlAttr = StringAttr::get(context, message.type_url()); + auto anyType = AnyType::get(context, typeUrlAttr); + return StringAttr::get(message.value(), anyType); +} + // Helpers to build symbol names from anchors deterministically. This allows // to reate symbol references from anchors without look-up structure. Also, // the format is exploited by the export logic to recover the original anchor @@ -468,10 +477,35 @@ static FailureOr importPlan(ImplicitLocOpBuilder builder, MLIRContext *context = builder.getContext(); Location loc = UnknownLoc::get(context); + // Import version. const Version &version = message.version(); + + // Import advanced extension. + AdvancedExtensionAttr advancedExtensionAttr; + if (message.has_advanced_extensions()) { + const extensions::AdvancedExtension &advancedExtension = + message.advanced_extensions(); + + StringAttr optimizationAttr; + if (advancedExtension.has_optimization()) { + const pb::Any &optimization = advancedExtension.optimization(); + optimizationAttr = importAny(builder, optimization).value(); + } + + StringAttr enhancementAttr; + if (advancedExtension.has_enhancement()) { + const pb::Any &enhancement = advancedExtension.enhancement(); + enhancementAttr = importAny(builder, enhancement).value(); + } + + advancedExtensionAttr = + AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr); + } + + // Build `PlanOp`. auto planOp = builder.create( version.major_number(), version.minor_number(), version.patch_number(), - version.git_hash(), version.producer()); + version.git_hash(), version.producer(), advancedExtensionAttr); planOp.getBody().push_back(new Block()); OpBuilder::InsertionGuard insertGuard(builder); diff --git a/test/Dialect/Substrait/plan-invalid.mlir b/test/Dialect/Substrait/plan-invalid.mlir index 2add9259..45fa61ca 100644 --- a/test/Dialect/Substrait/plan-invalid.mlir +++ b/test/Dialect/Substrait/plan-invalid.mlir @@ -40,3 +40,19 @@ substrait.plan version 0 : 42 : 1 { // expected-error@+1 {{'substrait.extension_function' op refers to @function.1, which is not a valid 'uri' op}} extension_function @function.2 at @function.1["somefunc"] } + +// ----- + +// Test error if the `enhancement` attribute has the wrong/no type. +substrait.plan version 0 : 42 : 1 + // expected-error@+1 {{custom op 'substrait.plan' has 'enhancement' attribute of wrong type}} + advanced_extension enhancement = "blup" +{} + +// ----- + +// Test error if the `optimization` attribute has the wrong/no type. +substrait.plan version 0 : 42 : 1 + // expected-error@+1 {{custom op 'substrait.plan' has 'optimization' attribute of wrong type}} + advanced_extension optimization = "blup" +{} diff --git a/test/Dialect/Substrait/plan.mlir b/test/Dialect/Substrait/plan.mlir index 41bb27fd..ae043bad 100644 --- a/test/Dialect/Substrait/plan.mlir +++ b/test/Dialect/Substrait/plan.mlir @@ -86,3 +86,17 @@ substrait.plan version 0 : 42 : 1 { extension_uri @other.extension at "http://other.url/with/more/extensions.yml" extension_function @other.function at @other.extension["someotherfunc"] } + +// ----- + +// CHECK: substrait.plan +// CHECK-SAME: advanced_extension +// CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> +// CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +// CHECK-NEXT: } + +substrait.plan version 0 : 42 : 1 + advanced_extension + optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> + enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +{} diff --git a/test/Target/SubstraitPB/Export/plan.mlir b/test/Target/SubstraitPB/Export/plan.mlir index 7d47394a..9397b0f1 100644 --- a/test/Target/SubstraitPB/Export/plan.mlir +++ b/test/Target/SubstraitPB/Export/plan.mlir @@ -138,3 +138,47 @@ substrait.plan version 0 : 42 : 1 { // If not handled carefully, parsing this symbol into an anchor may clash. extension_uri @extension_uri.0 at "http://other.url/with/more/extensions.yml" } + +// ----- + +// CHECK: advanced_extensions { +// CHECK-NEXT: optimization { +// CHECK-NEXT: type_url: "http://some.url/with/type.proto" +// CHECK-NEXT: value: "protobuf message" +// CHECK-NEXT: } +// CHECK-NEXT: enhancement { +// CHECK-NEXT: type_url: "http://other.url/with/type.proto" +// CHECK-NEXT: value: "other protobuf message" +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: version + +substrait.plan version 0 : 42 : 1 + advanced_extension + optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> + enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +{} + +// ----- + +// CHECK: advanced_extensions { +// CHECK-NEXT: optimization { +// CHECK-NOT: enhancement { +// CHECK-: version + +substrait.plan version 0 : 42 : 1 + advanced_extension + optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> +{} + +// ----- + +// CHECK: advanced_extensions { +// CHECK-NEXT: enhancement { +// CHECK-NOT: optimization { +// CHECK-: version + +substrait.plan version 0 : 42 : 1 + advanced_extension + enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +{} diff --git a/test/Target/SubstraitPB/Import/plan.textpb b/test/Target/SubstraitPB/Import/plan.textpb index 97f2958e..6fe649f0 100644 --- a/test/Target/SubstraitPB/Import/plan.textpb +++ b/test/Target/SubstraitPB/Import/plan.textpb @@ -187,3 +187,62 @@ version { minor_number: 42 patch_number: 1 } + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-SAME: advanced_extension +# CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> +# CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +# CHECK-NEXT: } + +advanced_extensions { + optimization { + type_url: "http://some.url/with/type.proto" + value: "protobuf message" + } + enhancement { + type_url: "http://other.url/with/type.proto" + value: "other protobuf message" + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-SAME: advanced_extension +# CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> +# CHECK-NEXT: } + +advanced_extensions { + optimization { + type_url: "http://some.url/with/type.proto" + value: "protobuf message" + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-SAME: advanced_extension +# CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> +# CHECK-NEXT: } + +advanced_extensions { + enhancement { + type_url: "http://other.url/with/type.proto" + value: "other protobuf message" + } +} +version { + minor_number: 42 + patch_number: 1 +}