Skip to content

Commit

Permalink
feat: implement AdvancedExtension and add it to Plan (#58)
Browse files Browse the repository at this point in the history
This PR implements the `AdvancedExtension` message type as a custom
attribute with two `StringAttr`s as well as an `AnyType` with a
parameter for the type URL to represent the `Any` messages as typed
`StringAttr`s. The PR also adds those attributes to the `PlanOp` in
order to represent the `advanced_extensions` message of the `Plan`
message.

We do not currently need this attribute but it is one of the few missing
fields of the `Plan` message as well as many other messages we have
otherwise implemented, so it allows us to get to higher coverage with
relatively little effort.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net authored Jan 21, 2025
1 parent ad4e3b0 commit c01f183
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 6 deletions.
7 changes: 5 additions & 2 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,21 @@ def Substrait_PlanOp : Substrait_Op<"plan", [
UI32Attr:$minor_number,
UI32Attr:$patch_number,
DefaultValuedAttr<StrAttr, "\"\"">:$git_hash,
DefaultValuedAttr<StrAttr, "\"\"">:$producer
DefaultValuedAttr<StrAttr, "\"\"">:$producer,
OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
);
let regions = (region RegionOf<PlanBodyOp>:$body);
let assemblyFormat = [{
`version` $major_number `:` $minor_number `:` $patch_number
(`git_hash` $git_hash^)? (`producer` $producer^)?
(`advanced_extension` `` $advanced_extension^)?
attr-dict-with-keyword $body
}];
let builders = [
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{
build($_builder, $_state, major, minor, patch,
StringAttr(), StringAttr());
/*git_hash=*/StringAttr(), /*producer*/StringAttr(),
/*advanced_extension=*/AdvancedExtensionAttr());
}]>
];
let extraClassDefinition = [{
Expand Down
26 changes: 26 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ def Substrait_AtomicAttributes {
];
}

// Type that carries the `type_url` of a `google.protobuf.Any` message as its
// single `StringAttr` parameter. With the `any` mnemonic and the assembly
// format below, it is written as `!substrait.any<"<type_url>">`.
def Substrait_AnyType : Substrait_Type<"Any", "any"> {
let summary = "Represents the `type_url` of a `google.protobuf.Any` message";
let description = [{
This type represents the `type_url` fields of a `google.protobuf.Any`
message. These messages consist of an opaque byte array and a string holding
the URL identifying the type of what is contained in the byte array.
}];
// The URL identifying the type of the payload.
let parameters = (ins "StringAttr":$type_url);
// Printed/parsed as the type URL enclosed in angle brackets.
let assemblyFormat = "`<` $type_url `>`";

}

// Attribute mirroring the `AdvancedExtension` message of Substrait. Both
// parameters are optional. The C++ verifier declared through `genVerifyDecl`
// rejects any parameter that is present but whose attribute type is not
// `AnyType`.
def Substrait_AdvancedExtensionAttr
    : Substrait_Attr<"AdvancedExtension", "advanced_extension"> {
  let summary = "Represents the `AdvancedExtension` message of Substrait";
  let parameters = (ins
    // Each parameter is a `StringAttr`; its type is checked by the verifier
    // rather than by the parameter declaration.
    OptionalParameter<"StringAttr">:$optimization,
    OptionalParameter<"StringAttr">:$enhancement
  );
  // Each field is printed as `<name> = <attr>` and omitted when absent.
  let assemblyFormat = [{
    ( `optimization` `=` $optimization^ )?
    ( `enhancement` `=` $enhancement^ )?
  }];
  // Declare `verify`; its definition lives in the dialect's C++ sources.
  let genVerifyDecl = 1;
}

/// Attribute of one of the currently supported atomic types.
def Substrait_AtomicAttribute : AnyAttrOf<Substrait_AtomicAttributes.attrs>;

Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ void SubstraitDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//

/// Verifies the parameters of an `AdvancedExtensionAttr`.
///
/// Either parameter may be null (absent). A parameter that is present must
/// be a `StringAttr` whose type is `AnyType`; otherwise an error naming the
/// offending field is emitted through `emitError` and returned.
/// `optimization` is checked before `enhancement`.
LogicalResult AdvancedExtensionAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    mlir::StringAttr optimization, mlir::StringAttr enhancement) {
  // An absent field is trivially fine; a present one needs the `AnyType`.
  const bool optimizationOk =
      !optimization || mlir::isa<AnyType>(optimization.getType());
  if (!optimizationOk)
    return emitError() << "has 'optimization' attribute of wrong type";

  const bool enhancementOk =
      !enhancement || mlir::isa<AnyType>(enhancement.getType());
  if (!enhancementOk)
    return emitError() << "has 'enhancement' attribute of wrong type";

  return success();
}

//===----------------------------------------------------------------------===//
// Substrait enums
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 35 additions & 3 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
DECLARE_EXPORT_FUNC(SetOp, Rel)

std::unique_ptr<pb::Any> exportAny(StringAttr attr);
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
mlir::Type mlirType);
Expand Down Expand Up @@ -89,6 +90,16 @@ class SubstraitExporter {
std::unique_ptr<SymbolTable> symbolTable; // Symbol table cache.
};

/// Builds a `pb::Any` message from a typed `StringAttr`.
///
/// The attribute's type is cast to `AnyType` unconditionally, so callers
/// must pass an attribute that actually has that type. The type's
/// `type_url` parameter becomes the message's `type_url`, and the string
/// value of the attribute becomes the message's `value`.
std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
  auto urlCarrier = mlir::cast<AnyType>(attr.getType());

  auto message = std::make_unique<pb::Any>();
  message->set_type_url(urlCarrier.getTypeUrl().getValue().str());
  message->set_value(attr.getValue().str());
  return message;
}

std::unique_ptr<proto::Type> exportIntegerType(mlir::Type mlirType,
MLIRContext *context) {
// Function that handles `IntegerType`'s.
Expand Down Expand Up @@ -814,18 +825,39 @@ FailureOr<std::unique_ptr<Plan>> SubstraitExporter::exportOperation(PlanOp op) {
using extensions::SimpleExtensionDeclaration;
using extensions::SimpleExtensionURI;

// Build `Plan` message.
auto plan = std::make_unique<Plan>();

// Build `Version` message.
auto version = std::make_unique<Version>();
version->set_major_number(op.getMajorNumber());
version->set_minor_number(op.getMinorNumber());
version->set_patch_number(op.getPatchNumber());
version->set_producer(op.getProducer().str());
version->set_git_hash(op.getGitHash().str());

// Build `Plan` message.
auto plan = std::make_unique<Plan>();
plan->set_allocated_version(version.release());

// Build `AdvancedExtension` message.
if (op.getAdvancedExtension()) {
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
auto extension = std::make_unique<extensions::AdvancedExtension>();

StringAttr optimizationAttr = extensionAttr.getOptimization();
StringAttr enhancementAttr = extensionAttr.getEnhancement();

if (optimizationAttr) {
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
extension->set_allocated_optimization(optimization.release());
}

if (enhancementAttr) {
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
extension->set_allocated_enhancement(enhancement.release());
}

plan->set_allocated_advanced_extensions(extension.release());
}

// Add `extension_uris` to plan.
{
AnchorUniquer anchorUniquer("extension_uri.", anchorsByOp);
Expand Down
36 changes: 35 additions & 1 deletion lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace {
static FailureOr<OP_TYPE> import##MESSAGE_TYPE(ImplicitLocOpBuilder builder, \
const ARG_TYPE &message);

DECLARE_IMPORT_FUNC(Any, pb::Any, StringAttr)
DECLARE_IMPORT_FUNC(CrossRel, Rel, CrossOp)
DECLARE_IMPORT_FUNC(FetchRel, Rel, FetchOp)
DECLARE_IMPORT_FUNC(FilterRel, Rel, FilterOp)
Expand All @@ -63,6 +64,14 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)

/// Converts a `pb::Any` message into a `StringAttr`.
///
/// The message's `type_url` is wrapped in an `AnyType`, which becomes the
/// type of the returned attribute; the message's `value` becomes the
/// attribute's string content.
FailureOr<StringAttr> importAny(ImplicitLocOpBuilder builder,
                                const pb::Any &message) {
  MLIRContext *ctx = builder.getContext();
  AnyType urlType =
      AnyType::get(ctx, StringAttr::get(ctx, message.type_url()));
  return StringAttr::get(message.value(), urlType);
}

// Helpers to build symbol names from anchors deterministically. This allows
// to create symbol references from anchors without look-up structure. Also,
// the format is exploited by the export logic to recover the original anchor
Expand Down Expand Up @@ -468,10 +477,35 @@ static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
MLIRContext *context = builder.getContext();
Location loc = UnknownLoc::get(context);

// Import version.
const Version &version = message.version();

// Import advanced extension.
AdvancedExtensionAttr advancedExtensionAttr;
if (message.has_advanced_extensions()) {
const extensions::AdvancedExtension &advancedExtension =
message.advanced_extensions();

StringAttr optimizationAttr;
if (advancedExtension.has_optimization()) {
const pb::Any &optimization = advancedExtension.optimization();
optimizationAttr = importAny(builder, optimization).value();
}

StringAttr enhancementAttr;
if (advancedExtension.has_enhancement()) {
const pb::Any &enhancement = advancedExtension.enhancement();
enhancementAttr = importAny(builder, enhancement).value();
}

advancedExtensionAttr =
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
}

// Build `PlanOp`.
auto planOp = builder.create<PlanOp>(
version.major_number(), version.minor_number(), version.patch_number(),
version.git_hash(), version.producer());
version.git_hash(), version.producer(), advancedExtensionAttr);
planOp.getBody().push_back(new Block());

OpBuilder::InsertionGuard insertGuard(builder);
Expand Down
16 changes: 16 additions & 0 deletions test/Dialect/Substrait/plan-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,19 @@ substrait.plan version 0 : 42 : 1 {
// expected-error@+1 {{'substrait.extension_function' op refers to @function.1, which is not a valid 'uri' op}}
extension_function @function.2 at @function.1["somefunc"]
}

// -----

// Test error if the `enhancement` attribute has the wrong/no type.
substrait.plan version 0 : 42 : 1
// expected-error@+1 {{custom op 'substrait.plan' has 'enhancement' attribute of wrong type}}
advanced_extension enhancement = "blup"
{}

// -----

// Test error if the `optimization` attribute has the wrong/no type.
substrait.plan version 0 : 42 : 1
// expected-error@+1 {{custom op 'substrait.plan' has 'optimization' attribute of wrong type}}
advanced_extension optimization = "blup"
{}
14 changes: 14 additions & 0 deletions test/Dialect/Substrait/plan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,17 @@ substrait.plan version 0 : 42 : 1 {
extension_uri @other.extension at "http://other.url/with/more/extensions.yml"
extension_function @other.function at @other.extension["someotherfunc"]
}

// -----

// CHECK: substrait.plan
// CHECK-SAME: advanced_extension
// CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
// CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
// CHECK-NEXT: }

substrait.plan version 0 : 42 : 1
advanced_extension
optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
{}
44 changes: 44 additions & 0 deletions test/Target/SubstraitPB/Export/plan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,47 @@ substrait.plan version 0 : 42 : 1 {
// If not handled carefully, parsing this symbol into an anchor may clash.
extension_uri @extension_uri.0 at "http://other.url/with/more/extensions.yml"
}

// -----

// CHECK: advanced_extensions {
// CHECK-NEXT: optimization {
// CHECK-NEXT: type_url: "http://some.url/with/type.proto"
// CHECK-NEXT: value: "protobuf message"
// CHECK-NEXT: }
// CHECK-NEXT: enhancement {
// CHECK-NEXT: type_url: "http://other.url/with/type.proto"
// CHECK-NEXT: value: "other protobuf message"
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: version

substrait.plan version 0 : 42 : 1
advanced_extension
optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
{}

// -----

// Test that only `optimization` is exported if `enhancement` is absent. The
// final directive bounds the negative match to the `advanced_extensions`
// message of this test case.
// CHECK:      advanced_extensions {
// CHECK-NEXT:   optimization {
// CHECK-NOT:    enhancement {
// CHECK:      version

substrait.plan version 0 : 42 : 1
  advanced_extension
    optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
  {}

// -----

// Test that only `enhancement` is exported if `optimization` is absent. The
// final directive bounds the negative match to the `advanced_extensions`
// message of this test case.
// CHECK:      advanced_extensions {
// CHECK-NEXT:   enhancement {
// CHECK-NOT:    optimization {
// CHECK:      version

substrait.plan version 0 : 42 : 1
  advanced_extension
    enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
  {}
59 changes: 59 additions & 0 deletions test/Target/SubstraitPB/Import/plan.textpb
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,62 @@ version {
minor_number: 42
patch_number: 1
}

# -----

# CHECK-LABEL: substrait.plan
# CHECK-SAME: advanced_extension
# CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
# CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
# CHECK-NEXT: }

advanced_extensions {
optimization {
type_url: "http://some.url/with/type.proto"
value: "protobuf message"
}
enhancement {
type_url: "http://other.url/with/type.proto"
value: "other protobuf message"
}
}
version {
minor_number: 42
patch_number: 1
}

# -----

# CHECK-LABEL: substrait.plan
# CHECK-SAME: advanced_extension
# CHECK-SAME: optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto">
# CHECK-NEXT: }

advanced_extensions {
optimization {
type_url: "http://some.url/with/type.proto"
value: "protobuf message"
}
}
version {
minor_number: 42
patch_number: 1
}

# -----

# CHECK-LABEL: substrait.plan
# CHECK-SAME: advanced_extension
# CHECK-SAME: enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto">
# CHECK-NEXT: }

advanced_extensions {
enhancement {
type_url: "http://other.url/with/type.proto"
value: "other protobuf message"
}
}
version {
minor_number: 42
patch_number: 1
}

0 comments on commit c01f183

Please sign in to comment.