Skip to content

Commit

Permalink
feat: factor out advanced_extension logic and add project op
Browse files Browse the repository at this point in the history
This PR factors out the handling of the `shared_extension` field from
the `plan` op and adds that logic to the `project` op. This mainly
consisted of moving the import and export logic from the functions
related to the `plan` op to dedicated functions. The PR also introduces
the new `ExtensibleOpInterface` that enforces an attribute called
`advanced_extension` on the op that implement it and allows to deal with
all such ops transparently. Since that interface depends on an
attribute, the include order of the generated code of interfaces and
attributes also had to be adapted. Unfortunately, the field names in the
Substrait spec also vary (singular or plural), so the PR also introduces
some template magic to be able to deal with protobuf message types with
both spellings. With this PR, message types with an `advanced_extension`
field should be able to support it by (1) adding the
`ExtensibleOpInterface` to their traits, (2) adding an
`advanced_extension` parameter, and (3) adding that parameter to their
assembly format (although that's technically optional; otherwise, the
attribute is set through the `attributes` dictionary).
  • Loading branch information
ingomueller-net committed Jan 22, 2025
1 parent c01f183 commit 8c74e7b
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 52 deletions.
25 changes: 22 additions & 3 deletions include/substrait-mlir/Dialect/Substrait/IR/Substrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,38 @@
#include "mlir/IR/SymbolTable.h" // IWYU: keep
#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU: keep

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export
//===----------------------------------------------------------------------===//
// Substrait dialect
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
//===----------------------------------------------------------------------===//
// Substrait enums
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait types
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
#define GET_TYPEDEF_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsTypes.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsAttrs.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait ops
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#define GET_OP_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export

Expand Down
20 changes: 20 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
let cppNamespace = "::mlir::substrait";
}

def Substrait_ExtensibleOpInterface : OpInterface<"ExtensibleOpInterface"> {
let description = [{
Interface for ops with the `advanced_extension` attribute. Several relations
and other message types of the Substrait specification have a field with the
same name (or the variant `advanced_extensions`, which has the same meaning)
and the interface enables handling all of them transparently.
}];
let cppNamespace = "::mlir::substrait";
let methods = [
InterfaceMethod<
"Get the `advanced_extension` attribute",
"std::optional<::mlir::substrait::AdvancedExtensionAttr>",
"getAdvancedExtension">,
InterfaceMethod<
"Get the `advanced_extension` attribute",
"void", "setAdvancedExtensionAttr",
(ins "::mlir::substrait::AdvancedExtensionAttr":$attr)>,
];
}

def Substrait_RelOpInterface : OpInterface<"RelOpInterface"> {
let description = [{
Interface for any relational operation in a Substrait plan. This corresponds
Expand Down
18 changes: 14 additions & 4 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def PlanBodyOp : AnyOf<[

def Substrait_PlanOp : Substrait_Op<"plan", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
NoTerminator, NoRegionArguments, SingleBlock, SymbolTable
]> {
let summary = "Represents a Substrait plan";
Expand Down Expand Up @@ -177,9 +178,13 @@ def Substrait_PlanOp : Substrait_Op<"plan", [
let builders = [
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{
build($_builder, $_state, major, minor, patch,
/*git_hash=*/StringAttr(), /*producer*/StringAttr(),
/*git_hash=*/"", /*producer*/"");
}]>,
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch,
"std::string":$git_hash, "std::string":$producer), [{
build($_builder, $_state, major, minor, patch, git_hash, producer,
/*advanced_extension=*/AdvancedExtensionAttr());
}]>
}]>,
];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
Expand Down Expand Up @@ -526,6 +531,7 @@ def Substrait_NamedTableOp : Substrait_RelOp<"named_table", [

def Substrait_ProjectOp : Substrait_RelOp<"project", [
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>
]> {
let summary = "Project operation";
Expand All @@ -550,14 +556,18 @@ def Substrait_ProjectOp : Substrait_RelOp<"project", [
}
```
}];
let arguments = (ins Substrait_Relation:$input);
let arguments = (ins
Substrait_Relation:$input,
OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
);
let regions = (region AnyRegion:$expressions);
let results = (outs Substrait_Relation:$result);
// TODO(ingomueller): We could elide/shorten the block argument from the
// assembly by writing custom printers/parsers similar to
// `scf.for` etc.
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($result) $expressions
$input (`advanced_extension` `` $advanced_extension^)?
attr-dict `:` type($input) `->` type($result) $expressions
}];
let hasRegionVerifier = 1;
let hasFolder = 1;
Expand Down
58 changes: 38 additions & 20 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

using namespace mlir;
using namespace mlir::substrait;
using namespace mlir::substrait::protobuf_utils;
using namespace ::substrait;
using namespace ::substrait::proto;

Expand Down Expand Up @@ -59,6 +60,8 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
DECLARE_EXPORT_FUNC(SetOp, Rel)

template <typename MessageType>
void exportAdvancedExtension(ExtensibleOpInterface op, MessageType &message);
std::unique_ptr<pb::Any> exportAny(StringAttr attr);
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
Expand Down Expand Up @@ -90,6 +93,36 @@ class SubstraitExporter {
std::unique_ptr<SymbolTable> symbolTable; // Symbol table cache.
};

template <typename MessageType>
void SubstraitExporter::exportAdvancedExtension(ExtensibleOpInterface op,
MessageType &message) {
if (!op.getAdvancedExtension())
return;

// Build the base `AdvancedExtension` message.
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
auto extension = std::make_unique<extensions::AdvancedExtension>();

StringAttr optimizationAttr = extensionAttr.getOptimization();
StringAttr enhancementAttr = extensionAttr.getEnhancement();

// Set `optimization` field if present.
if (optimizationAttr) {
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
extension->set_allocated_optimization(optimization.release());
}

// Set `enhancement` field if present.
if (enhancementAttr) {
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
extension->set_allocated_enhancement(enhancement.release());
}

// Set the `advanced_extension` field in the provided message.
using Trait = advanced_extension_trait<MessageType>;
Trait::set_allocated_advanced_extension(message, extension.release());
}

std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
auto any = std::make_unique<pb::Any>();
auto anyType = mlir::cast<AnyType>(attr.getType());
Expand Down Expand Up @@ -837,26 +870,8 @@ FailureOr<std::unique_ptr<Plan>> SubstraitExporter::exportOperation(PlanOp op) {
version->set_git_hash(op.getGitHash().str());
plan->set_allocated_version(version.release());

// Build `AdvancedExtension` message.
if (op.getAdvancedExtension()) {
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
auto extension = std::make_unique<extensions::AdvancedExtension>();

StringAttr optimizationAttr = extensionAttr.getOptimization();
StringAttr enhancementAttr = extensionAttr.getEnhancement();

if (optimizationAttr) {
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
extension->set_allocated_optimization(optimization.release());
}

if (enhancementAttr) {
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
extension->set_allocated_enhancement(enhancement.release());
}

plan->set_allocated_advanced_extensions(extension.release());
}
// Attach the `AdvancedExtension` message if the attribute exists.
exportAdvancedExtension(op, *plan);

// Add `extension_uris` to plan.
{
Expand Down Expand Up @@ -987,6 +1002,9 @@ SubstraitExporter::exportOperation(ProjectOp op) {
*projectRel->add_expressions() = *expression.value();
}

// Attach the `AdvancedExtension` message if the attribute exists.
exportAdvancedExtension(op, *projectRel);

// Build `Rel` message.
auto rel = std::make_unique<Rel>();
rel->set_allocated_project(projectRel.release());
Expand Down
71 changes: 48 additions & 23 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

using namespace mlir;
using namespace mlir::substrait;
using namespace mlir::substrait::protobuf_utils;
using namespace ::substrait;
using namespace ::substrait::proto;

Expand Down Expand Up @@ -64,6 +65,46 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)

/// If present, imports the `advanced_extension` or `advanced_extensions` field
/// from the given message and sets the obtained attribute on the given op.
template <typename MessageType>
void importAdvancedExtension(ImplicitLocOpBuilder builder,
ExtensibleOpInterface op,
const MessageType &message);

template <typename MessageType>
void importAdvancedExtension(ImplicitLocOpBuilder builder,
ExtensibleOpInterface op,
const MessageType &message) {
using Trait = advanced_extension_trait<MessageType>;
if (!Trait::has_advanced_extension(message))
return;

// Get the `advanced_extension(s)` field.
const extensions::AdvancedExtension &advancedExtension =
Trait::advanced_extension(message);

// Import `optimization` field if present.
StringAttr optimizationAttr;
if (advancedExtension.has_optimization()) {
const pb::Any &optimization = advancedExtension.optimization();
optimizationAttr = importAny(builder, optimization).value();
}

// Import `enhancement` field if present.
StringAttr enhancementAttr;
if (advancedExtension.has_enhancement()) {
const pb::Any &enhancement = advancedExtension.enhancement();
enhancementAttr = importAny(builder, enhancement).value();
}

// Build attribute and set it on the op.
MLIRContext *context = builder.getContext();
auto advancedExtensionAttr =
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
op.setAdvancedExtensionAttr(advancedExtensionAttr);
}

FailureOr<StringAttr> importAny(ImplicitLocOpBuilder builder,
const pb::Any &message) {
MLIRContext *context = builder.getContext();
Expand Down Expand Up @@ -480,34 +521,15 @@ static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
// Import version.
const Version &version = message.version();

// Import advanced extension.
AdvancedExtensionAttr advancedExtensionAttr;
if (message.has_advanced_extensions()) {
const extensions::AdvancedExtension &advancedExtension =
message.advanced_extensions();

StringAttr optimizationAttr;
if (advancedExtension.has_optimization()) {
const pb::Any &optimization = advancedExtension.optimization();
optimizationAttr = importAny(builder, optimization).value();
}

StringAttr enhancementAttr;
if (advancedExtension.has_enhancement()) {
const pb::Any &enhancement = advancedExtension.enhancement();
enhancementAttr = importAny(builder, enhancement).value();
}

advancedExtensionAttr =
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
}

// Build `PlanOp`.
auto planOp = builder.create<PlanOp>(
version.major_number(), version.minor_number(), version.patch_number(),
version.git_hash(), version.producer(), advancedExtensionAttr);
version.git_hash(), version.producer());
planOp.getBody().push_back(new Block());

// Import advanced extension if it is present.
importAdvancedExtension(builder, planOp, message);

OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToEnd(&planOp.getBody().front());

Expand Down Expand Up @@ -668,6 +690,9 @@ static mlir::FailureOr<ProjectOp> importProjectRel(ImplicitLocOpBuilder builder,
builder.create<ProjectOp>(resultType, inputOp.value()->getResult(0));
projectOp.getExpressions().push_back(conditionBlock.release());

// Import advanced extension if it is present.
importAdvancedExtension(builder, projectOp, projectRel);

return projectOp;
}

Expand Down
55 changes: 55 additions & 0 deletions lib/Target/SubstraitPB/ProtobufUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#ifndef LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
#define LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H

#include <type_traits>

#include "mlir/IR/Location.h"

namespace substrait::proto {
Expand All @@ -28,6 +30,59 @@ getCommon(const ::substrait::proto::Rel &rel, Location loc);
FailureOr<::substrait::proto::RelCommon *>
getMutableCommon(::substrait::proto::Rel *rel, Location loc);

/// SFINAE-based template that checks if the given (message) type has an field
/// called `advanced_extension`: the `value` member is `true` iff it has. This
/// is useful to deal with the two different names, `advanced_extension` and
/// `advanced_extensions`, that are used for the same thing across different
/// message types in the Substrait spec.
template <typename T>
class has_advanced_extensions {
template <typename C>
static std::true_type test(decltype(&C::advanced_extensions));
template <typename C>
static std::false_type test(...);

public:
static constexpr bool value = decltype(test<T>(0))::value;
};

/// Trait class for accessing the `advanced_extension` field. The default
/// instances is automatically used for message types that call this field
/// `advanced_extension`; the specialization below is automatically used for
/// message types that call it `advanced_extensions`.
template <typename T, typename = void>
struct advanced_extension_trait {
static auto has_advanced_extension(const T &message) {
return message.has_advanced_extension();
}
static auto advanced_extension(const T &message) {
return message.advanced_extension();
}
template <typename S>
static auto set_allocated_advanced_extension(T &message,
S &&advanced_extensions) {
message.set_allocated_advanced_extension(
std::forward<S>(advanced_extensions));
}
};

template <typename T>
struct advanced_extension_trait<
T, std::enable_if_t<has_advanced_extensions<T>::value>> {
static auto has_advanced_extension(const T &message) {
return message.has_advanced_extensions();
}
static auto advanced_extension(const T &message) {
return message.advanced_extensions();
}
template <typename S>
static auto set_allocated_advanced_extension(T &message,
S &&advanced_extensions) {
message.set_allocated_advanced_extensions(
std::forward<S>(advanced_extensions));
}
};

} // namespace mlir::substrait::protobuf_utils

#endif // LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
Loading

0 comments on commit 8c74e7b

Please sign in to comment.