Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeGen_MLIR: Add initial MLIR CodeGen #7587

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cmake/FindHalide_LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,18 @@ if (Halide_LLVM_FOUND)
endif ()
endif ()
endforeach ()

find_package(MLIR CONFIG HINTS "${LLVM_INSTALL_PREFIX}" "${LLVM_DIR}/../mlir" "${LLVM_DIR}/../lib/cmake/mlir")
if (MLIR_FOUND)
target_include_directories(Halide_LLVM::Core INTERFACE "$<BUILD_INTERFACE:${MLIR_INCLUDE_DIRS}>")
target_link_libraries(Halide_LLVM::Core INTERFACE
MLIRAnalysis
MLIRIR
MLIRArithDialect
MLIRFuncDialect
MLIRMemRefDialect
MLIRSCFDialect
MLIRVectorDialect
)
endif ()
endif ()
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ void define_enums(py::module &m) {
.value("function_info_header", OutputFileType::function_info_header)
.value("hlpipe", OutputFileType::hlpipe)
.value("llvm_assembly", OutputFileType::llvm_assembly)
.value("mlir", OutputFileType::mlir)
.value("object", OutputFileType::object)
.value("python_extension", OutputFileType::python_extension)
.value("pytorch_wrapper", OutputFileType::pytorch_wrapper)
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ target_sources(
CodeGen_Internal.cpp
CodeGen_LLVM.cpp
CodeGen_Metal_Dev.cpp
CodeGen_MLIR.cpp
CodeGen_OpenCL_Dev.cpp
CodeGen_Posix.cpp
CodeGen_PowerPC.cpp
Expand Down
565 changes: 565 additions & 0 deletions src/CodeGen_MLIR.cpp

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions src/CodeGen_MLIR.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#ifndef HALIDE_CODEGEN_MLIR_H
#define HALIDE_CODEGEN_MLIR_H

/** \file
* Defines the code-generator for producing MLIR code
*/

#include "IRVisitor.h"
#include "Scope.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/ImplicitLocOpBuilder.h>

namespace Halide {

struct Target;

namespace Internal {

struct LoweredFunc;

class CodeGen_MLIR {
public:
CodeGen_MLIR(std::ostream &stream);

void compile(const Module &module);

protected:
void compile_func(mlir::ImplicitLocOpBuilder &builder, const LoweredFunc &func);

static mlir::Type mlir_type_of(mlir::ImplicitLocOpBuilder &builder, Halide::Type t);

class Visitor : public IRVisitor {
public:
Visitor(mlir::ImplicitLocOpBuilder &builder, const LoweredFunc &func);

protected:
mlir::Value codegen(const Expr &);
void codegen(const Stmt &);

void visit(const IntImm *) override;
void visit(const UIntImm *) override;
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Variable *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Mul *) override;
void visit(const Div *) override;
void visit(const Mod *) override;
void visit(const Min *) override;
void visit(const Max *) override;
void visit(const EQ *) override;
void visit(const NE *) override;
void visit(const LT *) override;
void visit(const LE *) override;
void visit(const GT *) override;
void visit(const GE *) override;
void visit(const And *) override;
void visit(const Or *) override;
void visit(const Not *) override;
void visit(const Select *) override;
void visit(const Load *) override;
void visit(const Ramp *) override;
void visit(const Broadcast *) override;
void visit(const Call *) override;
void visit(const Let *) override;
void visit(const LetStmt *) override;
void visit(const AssertStmt *) override;
void visit(const ProducerConsumer *) override;
void visit(const For *) override;
void visit(const Store *) override;
void visit(const Provide *) override;
void visit(const Allocate *) override;
void visit(const Free *) override;
void visit(const Realize *) override;
void visit(const Block *) override;
void visit(const IfThenElse *) override;
void visit(const Evaluate *) override;
void visit(const Shuffle *) override;
void visit(const VectorReduce *) override;
void visit(const Prefetch *) override;
void visit(const Fork *) override;
void visit(const Acquire *) override;
void visit(const Atomic *) override;
void visit(const HoistedStorage *) override;

mlir::Type mlir_type_of(Halide::Type t) const;

void sym_push(const std::string &name, mlir::Value value);
void sym_pop(const std::string &name);
mlir::Value sym_get(const std::string &name, bool must_succeed = true) const;

private:
mlir::ImplicitLocOpBuilder &builder;
mlir::Value value;
Scope<mlir::Value> symbol_table;
};

mlir::MLIRContext mlir_context;
std::ostream &stream;
};

} // namespace Internal
} // namespace Halide

#endif
10 changes: 10 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3577,6 +3577,16 @@ void Func::compile_to_llvm_assembly(const string &filename, const vector<Argumen
pipeline().compile_to_llvm_assembly(filename, args, "", target);
}

void Func::compile_to_mlir(const string &filename, const vector<Argument> &args, const string &fn_name,
const Target &target) {
pipeline().compile_to_mlir(filename, args, fn_name, target);
}

void Func::compile_to_mlir(const string &filename, const vector<Argument> &args,
const Target &target) {
pipeline().compile_to_mlir(filename, args, "", target);
}

void Func::compile_to_object(const string &filename, const vector<Argument> &args,
const string &fn_name, const Target &target) {
pipeline().compile_to_object(filename, args, fn_name, target);
Expand Down
8 changes: 8 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,14 @@ class Func {
const Target &target = get_target_from_environment());
// @}

/** Emit MLIR code. */
//@{
void compile_to_mlir(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
const Target &target = get_target_from_environment());
void compile_to_mlir(const std::string &filename, const std::vector<Argument> &,
const Target &target = get_target_from_environment());
// @}

/** Statically compile this function to an object file, with the
* given filename (which should probably end in .o or .obj), type
* signature, and C function name (which defaults to the same name
Expand Down
2 changes: 1 addition & 1 deletion src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ gengen
[assembly, bitcode, c_header, c_source, cpp_stub, featurization,
llvm_assembly, object, python_extension, pytorch_wrapper, registration,
schedule, static_library, stmt, stmt_html, conceptual_stmt,
conceptual_stmt_html, compiler_log, hlpipe, device_code].
conceptual_stmt_html, compiler_log, hlpipe, device_code, mlir].
If omitted, default value is [c_header, static_library, registration].

-p A comma-separated list of shared libraries that will be loaded before the
Expand Down
11 changes: 11 additions & 0 deletions src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "CodeGen_C.h"
#include "CodeGen_Internal.h"
#include "CodeGen_MLIR.h"
#include "CodeGen_PyTorch.h"
#include "CompilerLogger.h"
#include "Debug.h"
Expand Down Expand Up @@ -42,6 +43,7 @@ std::map<OutputFileType, const OutputInfo> get_output_info(const Target &target)
{OutputFileType::function_info_header, {"function_info_header", ".function_info.h", IsSingle}},
{OutputFileType::hlpipe, {"hlpipe", ".hlpipe", IsSingle}},
{OutputFileType::llvm_assembly, {"llvm_assembly", ".ll", IsMulti}},
{OutputFileType::mlir, {"mlir", ".mlir", IsSingle}},
{OutputFileType::object, {"object", is_windows_coff ? ".obj" : ".o", IsMulti}},
{OutputFileType::python_extension, {"python_extension", ".py.cpp", IsSingle}},
{OutputFileType::pytorch_wrapper, {"pytorch_wrapper", ".pytorch.h", IsSingle}},
Expand Down Expand Up @@ -774,6 +776,15 @@ void Module::compile(const std::map<OutputFileType, std::string> &output_files)
file.close();
internal_assert(!file.fail());
}
if (contains(output_files, OutputFileType::mlir)) {
debug(1) << "Module.compile(): mlir " << output_files.at(OutputFileType::mlir) << "\n";

std::ofstream file(output_files.at(OutputFileType::mlir));
Internal::CodeGen_MLIR cg(file);
cg.compile(*this);
file.close();
internal_assert(!file.fail());
}
if (contains(output_files, OutputFileType::compiler_log)) {
debug(1) << "Module.compile(): compiler_log " << output_files.at(OutputFileType::compiler_log) << "\n";
std::ofstream file(output_files.at(OutputFileType::compiler_log));
Expand Down
1 change: 1 addition & 0 deletions src/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ enum class OutputFileType {
function_info_header,
hlpipe,
llvm_assembly,
mlir,
object,
python_extension,
pytorch_wrapper,
Expand Down
8 changes: 8 additions & 0 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ void Pipeline::compile_to_llvm_assembly(const string &filename,
m.compile(single_output(filename, m, OutputFileType::llvm_assembly));
}

void Pipeline::compile_to_mlir(const string &filename,
const vector<Argument> &args,
const string &fn_name,
const Target &target) {
Module m = compile_to_module(args, fn_name, target);
m.compile(single_output(filename, m, OutputFileType::mlir));
}

void Pipeline::compile_to_object(const string &filename,
const vector<Argument> &args,
const string &fn_name,
Expand Down
6 changes: 6 additions & 0 deletions src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ class Pipeline {
const std::string &fn_name,
const Target &target = get_target_from_environment());

/** Emit MLIR code. */
void compile_to_mlir(const std::string &filename,
const std::vector<Argument> &args,
const std::string &fn_name,
const Target &target = get_target_from_environment());

/** Statically compile a pipeline with multiple output functions to an
* object file, with the given filename (which should probably end in
* .o or .obj), type signature, and C function name (which defaults to
Expand Down