Skip to content

Commit

Permalink
Add stablehlo-translate interpreter arg passing, complex types in ref…
Browse files Browse the repository at this point in the history
…erence APIs (#2600)

Two main changes:
1. Support complex types in Reference APIs
2. Add support for passing arguments when using `stablehlo-translate`

To avoid inventing anything fancy, I've leveraged existing MLIR assembly
printing/parsing:

```
stablehlo-translate myfile.mlir --interpret --args="[dense<1> : tensor<2xi32>, dense<2> : tensor<2xi32>]"
```
  • Loading branch information
GleasonK authored Oct 30, 2024
1 parent 7920ed0 commit 9edd9cf
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 4 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ cc_binary(
"//stablehlo/tests:test_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
28 changes: 26 additions & 2 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,36 @@ DenseElementsAttr makeDenseElementsAttr(Tensor tensor) {
std::vector<llvm::APInt> values;
for (auto it = tensor.index_begin(); it != tensor.index_end(); ++it) {
Element element = tensor.get(*it);
values.push_back(element.getIntegerValue());
if (isSupportedBooleanType(elementType)) {
values.push_back(APInt(1, element.getBooleanValue() ? 1 : 0));
} else {
values.push_back(element.getIntegerValue());
}
}
return DenseIntElementsAttr::get(tensor.getType(), values);
}
if (isa<ComplexType>(elementType)) {
auto complexElemTy = cast<ComplexType>(elementType).getElementType();

if (complexElemTy.isF32()) {
auto elementData =
reinterpret_cast<const std::complex<float> *>(tensor.getData());
ArrayRef<std::complex<float>> elementDataRef(elementData,
tensor.getNumElements());
return DenseElementsAttr::get(tensor.getType(), elementDataRef);
}

if (complexElemTy.isF64()) {
auto elementData =
reinterpret_cast<const std::complex<double> *>(tensor.getData());
ArrayRef<std::complex<double>> elementDataRef(elementData,
tensor.getNumElements());
return DenseElementsAttr::get(tensor.getType(), elementDataRef);
}
}

llvm::report_fatal_error("Only FloatType and IntType are handled currently.");
llvm::report_fatal_error(
"Only FloatType, IntType, and Complex<f32,f64> are handled currently.");
}

Sizes makeSizes(Tensor tensor) {
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/interpret/api_input_arguments.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: stablehlo-translate %s --interpret --args="[dense<1> : tensor<2xi32>, dense<2> : tensor<2xi32>]" | FileCheck %s

// RUN: not stablehlo-translate %s --interpret --args="not_array" 2>&1 | FileCheck %s --check-prefixes=CHECK-ERROR-NOT-ARRAY
// CHECK-ERROR-NOT-ARRAY: expectected array attribute string for args, i.e. --args=[dense<1> : tensor<2xi32>, ...]

// RUN: not stablehlo-translate %s --interpret --args="[4.0 : f32]" 2>&1 | FileCheck %s --check-prefixes=CHECK-ERROR-NOT-DENSE
// CHECK-ERROR-NOT-DENSE: expected dense elements attribute for args elements, i.e. --args=[dense<1> : tensor<2xi32>, ...]

func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<2xi32>
return %0 : tensor<2xi32>
}

// CHECK: tensor<2xi32> {
// CHECK-NEXT: [3, 3]
// CHECK-NEXT: }
43 changes: 41 additions & 2 deletions stablehlo/tools/StablehloTranslateMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ limitations under the License.
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -66,6 +69,10 @@ llvm::cl::opt<std::string> targetOption(
"target", llvm::cl::desc("Target version for serialization"),
llvm::cl::init(""));

llvm::cl::opt<std::string> argsOption(
"args", llvm::cl::desc("Arguments to pass to the interpreter"),
llvm::cl::init(""));

namespace {

stablehlo::Tensor makeBooleanTensor(MLIRContext *context, bool value) {
Expand All @@ -75,6 +82,35 @@ stablehlo::Tensor makeBooleanTensor(MLIRContext *context, bool value) {
return stablehlo::makeTensor(res);
}

// Parse `--args` option into a list of interpreter arguments.
// The format is:
// --args=[dense<1> : tensor<2xi32>, ...], where each dense attribute is
// interpreted as a tensor.
mlir::FailureOr<SmallVector<stablehlo::InterpreterValue>>
parseInterpreterArguments(std::string argsStr, MLIRContext *context) {
llvm::SmallVector<stablehlo::InterpreterValue> inputs;
auto parseError = [&](llvm::StringRef msg) {
std::string usage = "--args=[dense<1> : tensor<2xi32>, ...]";
return emitError(UnknownLoc::get(context), msg) << ", i.e. " << usage;
};
if (!argsStr.empty()) {
auto arrayAttr =
dyn_cast_or_null<ArrayAttr>(mlir::parseAttribute(argsStr, context));
if (!arrayAttr) {
return parseError("expectected array attribute string for args");
}
for (auto attr : arrayAttr.getValue()) {
auto denseAttr = dyn_cast<DenseElementsAttr>(attr);
if (!denseAttr) {
return parseError(
"expected dense elements attribute for args elements");
}
inputs.push_back(stablehlo::makeTensor(denseAttr));
}
}
return inputs;
}

llvm::Error evalCustomCallCheckEq(stablehlo::CustomCallOp op,
stablehlo::Scope &scope) {
if (op->getNumOperands() != 2)
Expand Down Expand Up @@ -224,8 +260,11 @@ TranslateFromMLIRRegistration interpretRegistration(
config.fallback = std::make_unique<StablehloTranslateInterpreterFallback>(
config.probeInstrumentationDir);

llvm::SmallVector<stablehlo::InterpreterValue> inputs;
auto results = evalModule(module, inputs, config);
auto inputs = parseInterpreterArguments(argsOption.getValue(),
module->getContext());
if (failed(inputs)) return failure();

auto results = evalModule(module, inputs.value(), config);
if (failed(results)) return failure();

for (auto &result : *results) {
Expand Down

0 comments on commit 9edd9cf

Please sign in to comment.