From 129964ec23591dea5baeb9a90844364da030eece Mon Sep 17 00:00:00 2001
From: Sergi Granell
Date: Sun, 21 May 2023 11:05:07 +0900
Subject: [PATCH] CodeGen_MLIR: Add initial MLIR CodeGen

---
 dependencies/llvm/CMakeLists.txt |   2 +
 src/CMakeLists.txt               |   2 +
 src/CodeGen_MLIR.cpp             | 520 ++++++++++++++++++++++++++++++++
 src/CodeGen_MLIR.h               | 101 ++++++
 4 files changed, 625 insertions(+)
 create mode 100644 src/CodeGen_MLIR.cpp
 create mode 100644 src/CodeGen_MLIR.h

diff --git a/dependencies/llvm/CMakeLists.txt b/dependencies/llvm/CMakeLists.txt
index 04fa53c06849..783f68e99203 100644
--- a/dependencies/llvm/CMakeLists.txt
+++ b/dependencies/llvm/CMakeLists.txt
@@ -10,7 +10,8 @@ set(CMAKE_MAP_IMPORTED_CONFIG_MINSIZEREL MinSizeRel Release RelWithDebInfo "")
 set(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO RelWithDebInfo Release MinSizeRel "")
 set(CMAKE_MAP_IMPORTED_CONFIG_RELEASE Release MinSizeRel RelWithDebInfo "")
 
 find_package(LLVM ${Halide_REQUIRE_LLVM_VERSION} REQUIRED)
+find_package(MLIR REQUIRED HINTS "${LLVM_DIR}/../mlir" "${LLVM_DIR}/../lib/cmake/mlir")
 find_package(Clang REQUIRED CONFIG HINTS "${LLVM_DIR}/../clang" "${LLVM_DIR}/../lib/cmake/clang")
 
 set(LLVM_PACKAGE_VERSION "${LLVM_PACKAGE_VERSION}"
@@ -19,6 +20,7 @@ set(LLVM_PACKAGE_VERSION "${LLVM_PACKAGE_VERSION}"
 message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
 message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
 message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}")
+message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
 
 if (LLVM_PACKAGE_VERSION VERSION_LESS 14.0)
   message(FATAL_ERROR "LLVM version must be 14.0 or newer")
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 84f98033adb5..da892879d3e1 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -33,6 +33,7 @@ set(HEADER_FILES
     CodeGen_Internal.h
     CodeGen_LLVM.h
     CodeGen_Metal_Dev.h
+    CodeGen_MLIR.h
     CodeGen_OpenCL_Dev.h
     CodeGen_OpenGLCompute_Dev.h
     CodeGen_Posix.h
@@ -201,6 +202,7 @@ set(SOURCE_FILES
     CodeGen_Internal.cpp
     CodeGen_LLVM.cpp
CodeGen_Metal_Dev.cpp + CodeGen_MLIR.cpp CodeGen_OpenCL_Dev.cpp CodeGen_OpenGLCompute_Dev.cpp CodeGen_Posix.cpp diff --git a/src/CodeGen_MLIR.cpp b/src/CodeGen_MLIR.cpp new file mode 100644 index 000000000000..fa7b811bb980 --- /dev/null +++ b/src/CodeGen_MLIR.cpp @@ -0,0 +1,520 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "CodeGen_MLIR.h" +#include "IROperator.h" + +namespace Halide { + +namespace Internal { + +bool CodeGen_MLIR::compile(mlir::LocationAttr &loc, mlir::ModuleOp &mlir_module, Stmt stmt, const std::string &name, + const std::vector &args) { + mlir::ImplicitLocOpBuilder builder = mlir::ImplicitLocOpBuilder::atBlockEnd(loc, mlir_module.getBody()); + + mlir::SmallVector inputs; + mlir::SmallVector results; + mlir::SmallVector funcAttrs; + mlir::SmallVector funcArgAttrs; + + for (const auto &arg : args) + inputs.push_back(arg.is_buffer ? mlir::UnrankedMemRefType::get(mlir_type_of(builder, arg.type), {}) : mlir_type_of(builder, arg.type)); + + mlir::FunctionType functionType = builder.getFunctionType(inputs, results); + mlir::func::FuncOp functionOp = builder.create(builder.getStringAttr(name), functionType, funcAttrs, funcArgAttrs); + builder.setInsertionPointToStart(functionOp.addEntryBlock()); + + CodeGen_MLIR::Visitor visitor(builder, args); + stmt.accept(&visitor); + builder.create(); + + return mlir::verify(mlir_module).succeeded(); +} + +mlir::Type CodeGen_MLIR::mlir_type_of(mlir::ImplicitLocOpBuilder &builder, Halide::Type t) { + if (t.lanes() == 1) { + if (t.is_int_or_uint()) { + return builder.getIntegerType(t.bits()); + } else if (t.is_bfloat()) { + return builder.getBF16Type(); + } else if (t.is_float()) { + switch (t.bits()) { + case 16: + return builder.getF16Type(); + case 32: + return builder.getF32Type(); + case 64: + return builder.getF64Type(); + default: + internal_error << "There is no MLIR type matching this floating-point bit width: " << t << "\n"; + return nullptr; + } + } else { + 
internal_error << "Type not supported: " << t << "\n"; + } + } else { + return mlir::VectorType::get(t.lanes(), mlir_type_of(builder, t.element_of())); + } + + return mlir::Type(); +} + +CodeGen_MLIR::Visitor::Visitor(mlir::ImplicitLocOpBuilder &builder, const std::vector &args) + : builder(builder) { + + mlir::func::FuncOp funcOp = cast(builder.getBlock()->getParentOp()); + for (auto [index, arg] : llvm::enumerate(args)) + sym_push(arg.name, funcOp.getArgument(index)); +} + +mlir::Value CodeGen_MLIR::Visitor::codegen(const Expr &e) { + internal_assert(e.defined()); + debug(4) << "Codegen (E): " << e.type() << ", " << e << "\n"; + value = mlir::Value(); + e.accept(this); + internal_assert(value) << "Codegen of an expr did not produce a MLIR value\n" + << e; + return value; +} + +void CodeGen_MLIR::Visitor::codegen(const Stmt &s) { + internal_assert(s.defined()); + debug(4) << "Codegen (S): " << s << "\n"; + value = mlir::Value(); + s.accept(this); +} + +void CodeGen_MLIR::Visitor::visit(const IntImm *op) { + mlir::Type type = mlir_type_of(op->type); + value = builder.create(type, builder.getIntegerAttr(type, op->value)); +} + +void CodeGen_MLIR::Visitor::visit(const UIntImm *op) { + mlir::Type type = mlir_type_of(op->type); + value = builder.create(type, builder.getIntegerAttr(type, op->value)); +} + +void CodeGen_MLIR::Visitor::visit(const FloatImm *op) { + mlir::Type type = mlir_type_of(op->type); + value = builder.create(type, builder.getFloatAttr(type, op->value)); +} + +void CodeGen_MLIR::Visitor::visit(const StringImm *op) { + internal_error << "String immediates are not supported\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Cast *op) { + Halide::Type src = op->value.type(); + Halide::Type dst = op->type; + mlir::Type mlir_type = mlir_type_of(dst); + + value = codegen(op->value); + + if (src.is_int_or_uint() && dst.is_int_or_uint()) { + if (dst.bits() > src.bits()) { + if (src.is_int()) + value = builder.create(mlir_type, value); + else + value = 
builder.create(mlir_type, value); + } else { + value = builder.create(mlir_type, value); + } + } else if (src.is_float() && dst.is_int()) { + value = builder.create(mlir_type, value); + } else if (src.is_float() && dst.is_uint()) { + value = builder.create(mlir_type, value); + } else if (src.is_int() && dst.is_float()) { + value = builder.create(mlir_type, value); + } else if (src.is_uint() && dst.is_float()) { + value = builder.create(mlir_type, value); + } else if (src.is_float() && dst.is_float()) { + if (dst.bits() > src.bits()) { + value = builder.create(mlir_type, value); + } else { + value = builder.create(mlir_type, value); + } + } else { + internal_error << "Cast of " << src << " to " << dst << " is not implemented\n"; + } +} + +void CodeGen_MLIR::Visitor::visit(const Reinterpret *op) { + value = builder.create(mlir_type_of(op->type), codegen(op->value)); +} + +void CodeGen_MLIR::Visitor::visit(const Variable *op) { + value = sym_get(op->name, true); +} + +void CodeGen_MLIR::Visitor::visit(const Add *op) { + if (op->type.is_int_or_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Sub *op) { + if (op->type.is_int_or_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Mul *op) { + if (op->type.is_int_or_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Div *op) { + if (op->type.is_int()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), 
codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Mod *op) { + if (op->type.is_int()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Min *op) { + if (op->type.is_int()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const Max *op) { + if (op->type.is_int()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_uint()) + value = builder.create(codegen(op->a), codegen(op->b)); + else if (op->type.is_float()) + value = builder.create(codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const EQ *op) { + if (op->a.type().is_int_or_uint()) + value = builder.create(mlir::arith::CmpIPredicate::eq, codegen(op->a), codegen(op->b)); + else if (op->a.type().is_float()) + value = builder.create(mlir::arith::CmpFPredicate::OEQ, codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const NE *op) { + if (op->a.type().is_int_or_uint()) + value = builder.create(mlir::arith::CmpIPredicate::ne, codegen(op->a), codegen(op->b)); + else if (op->a.type().is_float()) + value = builder.create(mlir::arith::CmpFPredicate::ONE, codegen(op->a), codegen(op->b)); +} + +void CodeGen_MLIR::Visitor::visit(const LT *op) { + if (op->a.type().is_int_or_uint()) { + mlir::arith::CmpIPredicate predicate = op->type.is_int() ? 
mlir::arith::CmpIPredicate::slt : + mlir::arith::CmpIPredicate::ult; + value = builder.create(predicate, codegen(op->a), codegen(op->b)); + } else if (op->a.type().is_float()) { + value = builder.create(mlir::arith::CmpFPredicate::OLT, codegen(op->a), codegen(op->b)); + } +} + +void CodeGen_MLIR::Visitor::visit(const LE *op) { + if (op->a.type().is_int_or_uint()) { + mlir::arith::CmpIPredicate predicate = op->a.type().is_int() ? mlir::arith::CmpIPredicate::sle : + mlir::arith::CmpIPredicate::ule; + value = builder.create(predicate, codegen(op->a), codegen(op->b)); + } else if (op->a.type().is_float()) { + value = builder.create(mlir::arith::CmpFPredicate::OLE, codegen(op->a), codegen(op->b)); + } +} + +void CodeGen_MLIR::Visitor::visit(const GT *op) { + if (op->a.type().is_int_or_uint()) { + mlir::arith::CmpIPredicate predicate = op->a.type().is_int() ? mlir::arith::CmpIPredicate::sgt : + mlir::arith::CmpIPredicate::ugt; + value = builder.create(predicate, codegen(op->a), codegen(op->b)); + } else if (op->a.type().is_float()) { + value = builder.create(mlir::arith::CmpFPredicate::OGT, codegen(op->a), codegen(op->b)); + } +} + +void CodeGen_MLIR::Visitor::visit(const GE *op) { + if (op->a.type().is_int_or_uint()) { + mlir::arith::CmpIPredicate predicate = op->a.type().is_int() ? 
mlir::arith::CmpIPredicate::sge : + mlir::arith::CmpIPredicate::uge; + value = builder.create(predicate, codegen(op->a), codegen(op->b)); + } else if (op->a.type().is_float()) { + value = builder.create(mlir::arith::CmpFPredicate::OGE, codegen(op->a), codegen(op->b)); + } +} + +void CodeGen_MLIR::Visitor::visit(const And *op) { + value = builder.create(codegen(NE::make(op->a, make_zero(op->a.type()))), + codegen(NE::make(op->b, make_zero(op->b.type())))); +} + +void CodeGen_MLIR::Visitor::visit(const Or *op) { + value = builder.create(codegen(NE::make(op->a, make_zero(op->a.type()))), + codegen(NE::make(op->b, make_zero(op->b.type())))); +} + +void CodeGen_MLIR::Visitor::visit(const Not *op) { + value = codegen(EQ::make(op->a, make_zero(op->a.type()))); +} + +void CodeGen_MLIR::Visitor::visit(const Select *op) { + value = builder.create(codegen(op->condition), + codegen(op->true_value), + codegen(op->false_value)); +} + +void CodeGen_MLIR::Visitor::visit(const Load *op) { + mlir::Value buffer = sym_get(op->name); + mlir::Type type = mlir_type_of(op->type); + mlir::Value index; + if (op->type.is_scalar()) { + index = codegen(op->index); + } else if (Expr ramp_base = strided_ramp_base(op->index); ramp_base.defined()) { + index = codegen(ramp_base); + } else { + internal_error << "Unsupported load\n"; + } + + if (op->type.is_scalar()) { + value = builder.create(type, buffer, mlir::ValueRange{index}); + } else { + value = builder.create(type, buffer, mlir::ValueRange{index}); + } +} + +void CodeGen_MLIR::Visitor::visit(const Ramp *op) { + mlir::Value base = codegen(op->base); + mlir::Value stride = codegen(op->stride); + mlir::Type elementType = mlir_type_of(op->base.type()); + mlir::VectorType vectorType = mlir::VectorType::get(op->lanes, elementType); + + mlir::SmallVector indicesAttrs(op->lanes); + for (int i = 0; i < op->lanes; i++) + indicesAttrs[i] = mlir::IntegerAttr::get(elementType, i); + + mlir::DenseElementsAttr indicesDenseAttr = 
mlir::DenseElementsAttr::get(vectorType, indicesAttrs); + mlir::Value indicesConst = builder.create(indicesDenseAttr); + mlir::Value splatStride = builder.create(vectorType, stride); + mlir::Value offsets = builder.create(splatStride, indicesConst); + mlir::Value splatBase = builder.create(vectorType, base); + value = builder.create(splatBase, offsets); +} + +void CodeGen_MLIR::Visitor::visit(const Broadcast *op) { + value = builder.create(mlir_type_of(op->type), codegen(op->value)); +} + +void CodeGen_MLIR::Visitor::visit(const Call *op) { + if (op->is_intrinsic(Call::bitwise_and)) { + value = builder.create(codegen(op->args[0]), codegen(op->args[1])); + } else if (op->is_intrinsic(Call::shift_left)) { + value = builder.create(codegen(op->args[0]), codegen(op->args[1])); + } else if (op->is_intrinsic(Call::shift_right)) { + if (op->type.is_int()) + value = builder.create(codegen(op->args[0]), codegen(op->args[1])); + else + value = builder.create(codegen(op->args[0]), codegen(op->args[1])); + } else if (op->is_intrinsic(Call::widen_right_mul)) { + mlir::Value a = codegen(op->args[0]); + mlir::Value b = codegen(op->args[1]); + if (op->type.is_int()) + b = builder.create(a.getType(), b); + else + b = builder.create(a.getType(), b); + value = builder.create(a, b); + } else { + internal_error << "Call " << op->name << " not implemented\n"; + } +} + +void CodeGen_MLIR::Visitor::visit(const Let *op) { + sym_push(op->name, codegen(op->value)); + value = codegen(op->body); + sym_pop(op->name); +} + +void CodeGen_MLIR::Visitor::visit(const LetStmt *op) { + sym_push(op->name, codegen(op->value)); + codegen(op->body); + sym_pop(op->name); +} + +void CodeGen_MLIR::Visitor::visit(const AssertStmt *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const ProducerConsumer *op) { + codegen(op->body); +} + +void CodeGen_MLIR::Visitor::visit(const For *op) { + mlir::Value min = codegen(op->min); + mlir::Value max = builder.create(min, 
codegen(op->extent)); + mlir::Value lb = builder.create(builder.getIndexType(), min); + mlir::Value ub = builder.create(builder.getIndexType(), max); + mlir::Value step = builder.create(1); + + mlir::scf::ForOp forOp = builder.create(lb, ub, step); + { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&forOp.getLoopBody().front()); + + mlir::Value i = forOp.getInductionVar(); + sym_push(op->name, builder.create(max.getType(), i)); + codegen(op->body); + sym_pop(op->name); + } +} + +void CodeGen_MLIR::Visitor::visit(const Store *op) { + mlir::Value buffer = sym_get(op->name); + mlir::Value value = codegen(op->value); + mlir::Value index; + if (op->value.type().is_scalar()) { + index = codegen(op->index); + } else if (Expr ramp_base = strided_ramp_base(op->index); ramp_base.defined()) { + index = codegen(ramp_base); + } else { + internal_error << "Unsupported store\n"; + } + + if (op->value.type().is_scalar()) + builder.create(value, buffer, mlir::ValueRange{index}); + else + builder.create(value, buffer, mlir::ValueRange{index}); +} + +void CodeGen_MLIR::Visitor::visit(const Provide *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Allocate *op) { + int32_t size = op->constant_allocation_size(); + mlir::MemRefType type = mlir::MemRefType::get({size}, mlir_type_of(op->type)); + mlir::memref::AllocOp alloc = builder.create(type); + + sym_push(op->name, alloc); + codegen(op->body); + sym_pop(op->name); +} + +void CodeGen_MLIR::Visitor::visit(const Free *op) { + builder.create(sym_get(op->name)); +} + +void CodeGen_MLIR::Visitor::visit(const Realize *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Block *op) { + // Peel blocks of assertions with pure conditions + const AssertStmt *a = op->first.as(); + if (a && is_pure(a->condition)) { + std::vector asserts; + asserts.push_back(a); + Stmt s = op->rest; + while ((op = s.as()) && (a = op->first.as()) && 
is_pure(a->condition) && asserts.size() < 63) { + asserts.push_back(a); + s = op->rest; + } + // TODO + // codegen_asserts(asserts); + codegen(s); + } else { + codegen(op->first); + codegen(op->rest); + } +} + +void CodeGen_MLIR::Visitor::visit(const IfThenElse *op) { + builder.create( + codegen(op->condition), + /*thenBuilder=*/[&](mlir::OpBuilder &b, mlir::Location) { codegen(op->then_case); }, + /*elseBuilder=*/[&](mlir::OpBuilder &b, mlir::Location) { + if (op->else_case.defined()) + codegen(op->else_case); }); +} + +void CodeGen_MLIR::Visitor::visit(const Evaluate *op) { + codegen(op->value); + // Discard result + value = mlir::Value(); +} + +void CodeGen_MLIR::Visitor::visit(const Shuffle *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const VectorReduce *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Prefetch *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Fork *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Acquire *op) { + internal_error << "Unimplemented\n"; +} + +void CodeGen_MLIR::Visitor::visit(const Atomic *op) { + internal_error << "Unimplemented\n"; +} + +mlir::Type CodeGen_MLIR::Visitor::mlir_type_of(Halide::Type t) const { + return CodeGen_MLIR::mlir_type_of(builder, t); +} + +void CodeGen_MLIR::Visitor::sym_push(const std::string &name, mlir::Value value) { + symbol_table.push(name, value); +} + +void CodeGen_MLIR::Visitor::sym_pop(const std::string &name) { + symbol_table.pop(name); +} + +mlir::Value CodeGen_MLIR::Visitor::sym_get(const std::string &name, bool must_succeed) const { + // look in the symbol table + if (!symbol_table.contains(name)) { + if (must_succeed) { + std::ostringstream err; + err << "Symbol not found: " << name << "\n"; + + if (debug::debug_level() > 0) { + err << "The following names are in scope:\n" + << symbol_table << "\n"; + } + + internal_error << 
err.str(); + } else { + return nullptr; + } + } + return symbol_table.get(name); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/CodeGen_MLIR.h b/src/CodeGen_MLIR.h new file mode 100644 index 000000000000..47e6de70b40d --- /dev/null +++ b/src/CodeGen_MLIR.h @@ -0,0 +1,101 @@ +#ifndef HALIDE_CODEGEN_MLIR_H +#define HALIDE_CODEGEN_MLIR_H + +/** \file + * Defines the code-generator for producing MLIR code + */ + +#include "DeviceArgument.h" +#include "IRVisitor.h" +#include "Scope.h" + +#include +#include + +namespace Halide { + +struct Target; + +namespace Internal { + +class CodeGen_MLIR { +public: + bool compile(mlir::LocationAttr &loc, mlir::ModuleOp &mlir_module, Stmt stmt, + const std::string &name, const std::vector &args); + + static mlir::Type mlir_type_of(mlir::ImplicitLocOpBuilder &builder, Halide::Type t); + +protected: + class Visitor : public IRVisitor { + public: + Visitor(mlir::ImplicitLocOpBuilder &builder, const std::vector &args); + + protected: + mlir::Value codegen(const Expr &); + void codegen(const Stmt &); + + void visit(const IntImm *) override; + void visit(const UIntImm *) override; + void visit(const FloatImm *) override; + void visit(const StringImm *) override; + void visit(const Cast *) override; + void visit(const Reinterpret *) override; + void visit(const Variable *) override; + void visit(const Add *) override; + void visit(const Sub *) override; + void visit(const Mul *) override; + void visit(const Div *) override; + void visit(const Mod *) override; + void visit(const Min *) override; + void visit(const Max *) override; + void visit(const EQ *) override; + void visit(const NE *) override; + void visit(const LT *) override; + void visit(const LE *) override; + void visit(const GT *) override; + void visit(const GE *) override; + void visit(const And *) override; + void visit(const Or *) override; + void visit(const Not *) override; + void visit(const Select *) override; + void visit(const Load *) override; + 
void visit(const Ramp *) override; + void visit(const Broadcast *) override; + void visit(const Call *) override; + void visit(const Let *) override; + void visit(const LetStmt *) override; + void visit(const AssertStmt *) override; + void visit(const ProducerConsumer *) override; + void visit(const For *) override; + void visit(const Store *) override; + void visit(const Provide *) override; + void visit(const Allocate *) override; + void visit(const Free *) override; + void visit(const Realize *) override; + void visit(const Block *) override; + void visit(const IfThenElse *) override; + void visit(const Evaluate *) override; + void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; + void visit(const Prefetch *) override; + void visit(const Fork *) override; + void visit(const Acquire *) override; + void visit(const Atomic *) override; + + mlir::Type mlir_type_of(Halide::Type t) const; + + void sym_push(const std::string &name, mlir::Value value); + void sym_pop(const std::string &name); + mlir::Value sym_get(const std::string &name, bool must_succeed = true) const; + + private: + mlir::ImplicitLocOpBuilder &builder; + mlir::Value value; + Scope symbol_table; + }; +}; + +} // namespace Internal +} // namespace Halide + +#endif