diff --git a/BUILD b/BUILD index 880271eed..ba46ff49a 100644 --- a/BUILD +++ b/BUILD @@ -10,24 +10,25 @@ package( py_package( name = "enzyme_jax_data", + # Only include these Python packages. + packages = [ + "@//enzyme_jax:enzyme_call.so", + "@llvm-project//clang:builtin_headers_gen", + ], + prefix = "enzyme_jax/", deps = [ "//enzyme_jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen", ], - # Only include these Python packages. - packages = ["@//enzyme_jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], - prefix = "enzyme_jax/" ) py_wheel( name = "enzyme_jax", + author = "Enzyme Authors", + author_email = "wmoses@mit.edu, zinenko@google.com", # Package data. We're building "example_minimal_package-0.0.1-py3-none-any.whl" distribution = "enzyme_jax", - author="Enzyme Authors", - license='LLVM', - author_email="wmoses@mit.edu, zinenko@google.com", - python_tag = "py3", - version = "0.0.5", + license = "LLVM", platform = select({ "@bazel_tools//src/conditions:windows_x64": "win_amd64", "@bazel_tools//src/conditions:darwin_arm64": "macosx_11_0_arm64", @@ -36,5 +37,10 @@ py_wheel( "@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64", "@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le", }), - deps = ["//enzyme_jax:enzyme_jax_internal", ":enzyme_jax_data"] + python_tag = "py3", + version = "0.0.5", + deps = [ + ":enzyme_jax_data", + "//enzyme_jax:enzyme_jax_internal", + ], ) diff --git a/enzyme_jax/BUILD b/enzyme_jax/BUILD index e00831138..e7ac864e8 100644 --- a/enzyme_jax/BUILD +++ b/enzyme_jax/BUILD @@ -11,7 +11,7 @@ cc_library( srcs = ["clang_compile.cc"], hdrs = ["clang_compile.h"], deps = [ - "@pybind11", + "@enzyme//:EnzymeStatic", "@llvm-project//clang:ast", "@llvm-project//clang:basic", "@llvm-project//clang:driver", @@ -19,18 +19,21 @@ cc_library( "@llvm-project//clang:frontend_tool", "@llvm-project//clang:lex", "@llvm-project//clang:serialization", - "@llvm-project//llvm:Support", "@llvm-project//llvm:Core", "@llvm-project//llvm:IRReader", "@llvm-project//llvm:OrcJIT", - "@enzyme//:EnzymeStatic" + "@llvm-project//llvm:Support", + "@pybind11", ], ) py_library( name = "enzyme_jax_internal", - srcs = ["primitives.py", "__init__.py"], - visibility = ["//visibility:public"] + srcs = [ + "__init__.py", + "primitives.py", + ], + visibility = ["//visibility:public"], ) pybind_library( @@ -95,12 +98,12 @@ pybind_library( pybind_extension( name = "enzyme_call", srcs = ["enzyme_call.cc"], + visibility = ["//visibility:public"], deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//llvm:OrcJIT", ":clang_compile", ":compile_with_xla", - "@com_google_absl//absl/status" + "@com_google_absl//absl/status", + "@llvm-project//llvm:OrcJIT", + "@llvm-project//llvm:Support", ], - visibility = ["//visibility:public"], ) diff --git a/enzyme_jax/clang_compile.cc b/enzyme_jax/clang_compile.cc index fce4dfcc1..1aee540e5 100644 --- a/enzyme_jax/clang_compile.cc +++ b/enzyme_jax/clang_compile.cc @@ -7,36 +7,21 @@ //===----------------------------------------------------------------------===// #include "clang_compile.h" -#include "llvm/IRReader/IRReader.h" -#include -#include -#include -#include -#include -#include -#include +#include +#include #include +#include +#include #include +#include -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/AsmParser/LLLexer.h" -#include "llvm/AsmParser/LLParser.h" -#include "llvm/AsmParser/LLToken.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/AsmParser/SlotMapping.h" -#include "llvm-c/Core.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "clang/CodeGen/CodeGenAction.h" +#include +#include +#include +#include +#include "Enzyme/Enzyme.h" #include "clang/AST/Decl.h" #include "clang/Basic/DiagnosticOptions.h" #include "clang/Basic/FileManager.h" @@ -46,34 +31,45 @@ #include "clang/Basic/TargetInfo.h" #include "clang/Basic/TargetOptions.h" #include "clang/Basic/Version.h" +#include "clang/CodeGen/CodeGenAction.h" #include "clang/Driver/Compilation.h" #include "clang/Driver/Driver.h" #include "clang/Driver/Tool.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/CompilerInvocation.h" #include "clang/Frontend/FrontendOptions.h" +#include "clang/Frontend/TextDiagnosticBuffer.h" #include "clang/Frontend/TextDiagnosticPrinter.h" #include "clang/Frontend/Utils.h" +#include "clang/FrontendTool/Utils.h" #include "clang/Parse/ParseAST.h" #include "clang/Parse/Parser.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" +#include "llvm-c/Core.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/AsmParser/LLLexer.h" +#include "llvm/AsmParser/LLParser.h" +#include "llvm/AsmParser/LLToken.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/AsmParser/SlotMapping.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "clang/Frontend/TextDiagnosticBuffer.h" #include "llvm/Support/Host.h" -#include "clang/FrontendTool/Utils.h" -#include "llvm/MC/TargetRegistry.h" -#include "llvm/CodeGen/CommandFlags.h" #include "llvm/Support/MemoryBufferRef.h" -#include "llvm/Linker/Linker.h" - -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" - -#include -#include - -#include "Enzyme/Enzyme.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" namespace clang { namespace driver { @@ -86,21 +82,21 @@ namespace tools { void addDirectoryList(const llvm::opt::ArgList &Args, llvm::opt::ArgStringList &CmdArgs, const char *ArgName, const char *EnvVar); -} -} -} +} // namespace tools +} // namespace driver +} // namespace clang using namespace clang; using namespace llvm; class ArgumentList { -private: + private: /// Helper storage. llvm::SmallVector> Storage; /// List of arguments llvm::opt::ArgStringList Args; -public: + public: /// Add argument. /// /// The element stored will not be owned by this. @@ -127,7 +123,7 @@ class ArgumentList { /// /// The return value of this operation could be invalidated by subsequent /// calls to push_back() or emplace_back(). - llvm::opt::ArgStringList& getArguments() { return Args; } + llvm::opt::ArgStringList &getArguments() { return Args; } }; /* @@ -148,9 +144,10 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, ptr_wrapper, true); */ // Returns the TargetMachine instance or zero if no triple is provided. -static TargetMachine* GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, +static TargetMachine *GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, StringRef FeaturesStr, - const llvm::TargetOptions &Options, CodeGenOptLevel level) { + const llvm::TargetOptions &Options, + CodeGenOptLevel level) { std::string Error; const Target *TheTarget = TargetRegistry::lookupTarget(codegen::getMArch(), TheTriple, Error); @@ -165,9 +162,12 @@ static TargetMachine* GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, codegen::getExplicitCodeModel(), level); } -std::unique_ptr GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, ArrayRef pyargv, LLVMContext* Context, std::unique_ptr linkMod) { - const llvm::opt::InputArgList Args; - const char *binary = cpp ? "clang++" : "clang"; +std::unique_ptr GetLLVMFromJob( + std::string filename, std::string filecontents, bool cpp, + ArrayRef pyargv, LLVMContext *Context, + std::unique_ptr linkMod) { + const llvm::opt::InputArgList Args; + const char *binary = cpp ? "clang++" : "clang"; // Buffer diagnostics from argument parsing so that we can output them using a // well formed diagnostic object. IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); @@ -179,15 +179,14 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f IntrusiveRefCntPtr DiagOpts0 = new DiagnosticOptions(); IntrusiveRefCntPtr DiagID0(new DiagnosticIDs()); DiagnosticsEngine Diags0(DiagID0, &*DiagOpts0, DiagsBuffer0); - const std::unique_ptr driver( - new clang::driver::Driver(binary, llvm::sys::getDefaultTargetTriple(), Diags0)); + const std::unique_ptr driver(new clang::driver::Driver( + binary, llvm::sys::getDefaultTargetTriple(), Diags0)); ArgumentList Argv; - + Argv.emplace_back(StringRef(filename)); - for (auto v : pyargv) - Argv.emplace_back(v); + for (auto v : pyargv) Argv.emplace_back(v); - SmallVector PreArgs; + SmallVector PreArgs; PreArgs.push_back(binary); PreArgs.append(Argv.getArguments()); PreArgs[1] = "-"; @@ -204,27 +203,34 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f // frontend into the driver. It will allow deleting 4 otherwise unused flags. // CPATH - included following the user specified includes (but prior to // builtin and standard includes). - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-I", "CPATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-I", + "CPATH"); // C_INCLUDE_PATH - system includes enabled when compiling C. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-c-isystem", "C_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-c-isystem", "C_INCLUDE_PATH"); // CPLUS_INCLUDE_PATH - system includes enabled when compiling C++. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-cxx-isystem", "CPLUS_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-cxx-isystem", "CPLUS_INCLUDE_PATH"); // OBJC_INCLUDE_PATH - system includes enabled when compiling ObjC. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-objc-isystem", "OBJC_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-objc-isystem", "OBJC_INCLUDE_PATH"); // OBJCPLUS_INCLUDE_PATH - system includes enabled when compiling ObjC++. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-objcxx-isystem", "OBJCPLUS_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList( + Args, Argv.getArguments(), "-objcxx-isystem", "OBJCPLUS_INCLUDE_PATH"); auto &TC = compilation->getDefaultToolChain(); if (cpp) { - bool HasStdlibxxIsystem = false; // Args.hasArg(options::OPT_stdlibxx_isystem); - HasStdlibxxIsystem ? TC.AddClangCXXStdlibIsystemArgs(Args, Argv.getArguments()) - : TC.AddClangCXXStdlibIncludeArgs(Args, Argv.getArguments()); + bool HasStdlibxxIsystem = + false; // Args.hasArg(options::OPT_stdlibxx_isystem); + HasStdlibxxIsystem + ? TC.AddClangCXXStdlibIsystemArgs(Args, Argv.getArguments()) + : TC.AddClangCXXStdlibIncludeArgs(Args, Argv.getArguments()); } - TC.AddClangSystemIncludeArgs(Args, Argv.getArguments()); - + TC.AddClangSystemIncludeArgs(Args, Argv.getArguments()); + SmallVector outputvec; - + std::unique_ptr Clang(new CompilerInstance()); // Register the support for object-file-wrapped Clang modules. @@ -232,20 +238,27 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f // PCHOps->registerWriter(std::make_unique()); // PCHOps->registerReader(std::make_unique()); + auto baseFS = createVFSFromCompilerInvocation(Clang->getInvocation(), Diags); - auto baseFS = createVFSFromCompilerInvocation(Clang->getInvocation(), - Diags); - - IntrusiveRefCntPtr fs(new llvm::vfs::InMemoryFileSystem()); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); struct tm y2k = {}; - y2k.tm_hour = 0; y2k.tm_min = 0; y2k.tm_sec = 0; - y2k.tm_year = 100; y2k.tm_mon = 0; y2k.tm_mday = 1; + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; time_t timer = mktime(&y2k); - fs->addFile(filename, timer, llvm::MemoryBuffer::getMemBuffer(filecontents, filename, /*RequiresNullTerminator*/false)); - fs->addFile("/enzyme/enzyme/utils", timer, llvm::MemoryBuffer::getMemBuffer(R"( + fs->addFile(filename, timer, + llvm::MemoryBuffer::getMemBuffer( + filecontents, filename, /*RequiresNullTerminator*/ false)); + fs->addFile("/enzyme/enzyme/utils", timer, + llvm::MemoryBuffer::getMemBuffer( + R"( namespace enzyme { template RT __enzyme_fwddiff(Args...); @@ -264,8 +277,11 @@ extern "C" int enzyme_dupnoneed; extern "C" int enzyme_nooverwrite; extern "C" int enzyme_tape; extern "C" int enzyme_allocated; - )", "/enzyme/enzyme/utils", /*RequiresNullTerminator*/false)); - fs->addFile("/enzyme/enzyme/tensor", timer, llvm::MemoryBuffer::getMemBuffer(R"( + )", + "/enzyme/enzyme/utils", /*RequiresNullTerminator*/ false)); + fs->addFile("/enzyme/enzyme/tensor", timer, + llvm::MemoryBuffer::getMemBuffer( + R"( #include #include namespace enzyme { @@ -445,26 +461,28 @@ struct tensor }; } - )", "/enzyme/enzyme/tensor", /*RequiresNullTerminator*/false)); + )", + "/enzyme/enzyme/tensor", /*RequiresNullTerminator*/ false)); - std::unique_ptr outputStream(new llvm::raw_svector_ostream(outputvec)); + std::unique_ptr outputStream( + new llvm::raw_svector_ostream(outputvec)); Clang->setOutputStream(std::move(outputStream)); - IntrusiveRefCntPtr fuseFS(new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); fuseFS->pushOverlay(fs); fuseFS->pushOverlay(baseFS); Clang->createFileManager(fuseFS); - - bool Success = CompilerInvocation::CreateFromArgs(Clang->getInvocation(), - Argv.getArguments(), Diags, binary); + bool Success = CompilerInvocation::CreateFromArgs( + Clang->getInvocation(), Argv.getArguments(), Diags, binary); // Infer the builtin include path if unspecified. if (Clang->getHeaderSearchOpts().UseBuiltinIncludes && Clang->getHeaderSearchOpts().ResourceDir.empty()) Clang->getHeaderSearchOpts().ResourceDir = - CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/0x0); + CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/ 0x0); // Create the actual diagnostics engine. Clang->createDiagnostics(); @@ -504,7 +522,7 @@ struct tensor if (f.getName() == "entry") continue; f.setLinkage(Function::LinkageTypes::InternalLinkage); } - + PipelineTuningOptions PTO; LoopAnalysisManager LAM; FunctionAnalysisManager FAM; @@ -518,13 +536,14 @@ struct tensor llvm::driver::createTLII(triple, Clang->getCodeGenOpts().getVecLib())); FAM.registerPass([&] { return TargetLibraryAnalysis(*TLII); }); - - auto level = CodeGenOptLevel::Aggressive; //OptimizationLevel::O3; + auto level = CodeGenOptLevel::Aggressive; // OptimizationLevel::O3; Triple ModuleTriple(mod->getTargetTriple()); std::string CPUStr, FeaturesStr; - auto ETM = llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple())).createTargetMachine (); + auto ETM = + llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple())) + .createTargetMachine(); if (!ETM) { throw pybind11::value_error("failed to create targetmachine"); } @@ -548,4 +567,3 @@ struct tensor MPM.run(*mod, MAM); return mod; } - diff --git a/enzyme_jax/clang_compile.h b/enzyme_jax/clang_compile.h index ff9a32c14..73050b3c7 100644 --- a/enzyme_jax/clang_compile.h +++ b/enzyme_jax/clang_compile.h @@ -10,9 +10,14 @@ #define ENZYME_JAX_CLANG_COMPILE_H #include + #include + #include "llvm/IR/Module.h" -std::unique_ptr GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, llvm::ArrayRef pyargv, llvm::LLVMContext*ctx=nullptr, std::unique_ptr linkMod=nullptr); +std::unique_ptr GetLLVMFromJob( + std::string filename, std::string filecontents, bool cpp, + llvm::ArrayRef pyargv, llvm::LLVMContext* ctx = nullptr, + std::unique_ptr linkMod = nullptr); #endif // ENZYME_JAX_CLANG_COMPILE_H diff --git a/enzyme_jax/enzyme_call.cc b/enzyme_jax/enzyme_call.cc index ae13383b3..2b457ca9e 100644 --- a/enzyme_jax/enzyme_call.cc +++ b/enzyme_jax/enzyme_call.cc @@ -29,40 +29,38 @@ #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Instructions.h" +#include "llvm/IRReader/IRReader.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/RWMutex.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "pybind11/pybind11.h" -#include "llvm/IRReader/IRReader.h" -#include "llvm/Support/SourceMgr.h" absl::StatusOr compile_mhlo_to_llvm_with_xla( llvm::StringRef mhlo_text); -enum class Language : int { - CPP = 0, - LLVM = 1, - MHLO = 2 -}; +enum class Language : int { CPP = 0, LLVM = 1, MHLO = 2 }; namespace { class CpuKernel { // static llvm::orc::ExecutionSession ES; static std::unique_ptr DL; - static std::unique_ptr JIT; + static std::unique_ptr JIT; int64_t identifier; size_t num_out; uint64_t addr; - public: - CpuKernel(int64_t identifier, - size_t num_out, uint64_t addr) - : identifier(identifier), num_out(num_out), addr(addr) { - } - static std::string make_type(std::string typenam, llvm::ArrayRef shape, bool constv, Language lang) { - std::string s = std::string(constv ? "const " : "") + "enzyme::tensor<" + typenam; + public: + CpuKernel(int64_t identifier, size_t num_out, uint64_t addr) + : identifier(identifier), num_out(num_out), addr(addr) {} + + static std::string make_type(std::string typenam, + llvm::ArrayRef shape, bool constv, + Language lang) { + std::string s = + std::string(constv ? "const " : "") + "enzyme::tensor<" + typenam; for (auto v : shape) { s += ", " + std::to_string(v); } @@ -90,98 +88,110 @@ class CpuKernel { std::string stringbuf; switch (lang) { - case Language::CPP: - ss << source << "\n"; - break; - - - case Language::MHLO:{ - absl::StatusOr llvm_ir = - compile_mhlo_to_llvm_with_xla(source); - if (!llvm_ir.ok()) { - throw std::runtime_error("failed to compile to LLVM IR with XLA:" + - llvm_ir.status().ToString()); - } - stringbuf = *llvm_ir; - source = stringbuf; - // explicitly fall through - } - case Language::LLVM: - llvm::SMDiagnostic Err; - linkMod = llvm::parseIR(llvm::MemoryBufferRef(source, ""), Err, *llvm_ctx); - if (!linkMod) { - std::string err_str; - llvm::raw_string_ostream ss(err_str); - Err.print("llvmsource", ss, false); - throw pybind11::value_error("failed to compile LLVM: " + ss.str()); - } - assert(linkMod); - if (lang == Language::MHLO) { - for (auto &lfn : linkMod->functions()) { - if (lfn.empty()) continue; - assert(fn != "mhlo_main"); - fn = "mhlo_main"; - lfn.setName(fn); - lfn.addFnAttr(llvm::Attribute::AlwaysInline); + case Language::CPP: + ss << source << "\n"; + break; + + case Language::MHLO: { + absl::StatusOr llvm_ir = + compile_mhlo_to_llvm_with_xla(source); + if (!llvm_ir.ok()) { + throw std::runtime_error("failed to compile to LLVM IR with XLA:" + + llvm_ir.status().ToString()); } + stringbuf = *llvm_ir; + source = stringbuf; + // explicitly fall through } - ss << " extern \"C\" void " << fn << "(void* retval, void* run_options, void* params, void* buffer_table, void* status, void* prof_counters);\n\n"; + case Language::LLVM: + llvm::SMDiagnostic Err; + linkMod = llvm::parseIR(llvm::MemoryBufferRef(source, ""), Err, + *llvm_ctx); + if (!linkMod) { + std::string err_str; + llvm::raw_string_ostream ss(err_str); + Err.print("llvmsource", ss, false); + throw pybind11::value_error("failed to compile LLVM: " + ss.str()); + } + assert(linkMod); + if (lang == Language::MHLO) { + for (auto &lfn : linkMod->functions()) { + if (lfn.empty()) continue; + assert(fn != "mhlo_main"); + fn = "mhlo_main"; + lfn.setName(fn); + lfn.addFnAttr(llvm::Attribute::AlwaysInline); + } + } + ss << " extern \"C\" void " << fn + << "(void* retval, void* run_options, void* params, void* " + "buffer_table, void* status, void* prof_counters);\n\n"; - ss << " __attribute__((always_inline)) static inline void abi_wrap("; - bool comma = false; - for (size_t i=0, off=0; i(" << fn << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; - for (size_t i=0; i(" << fn + << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; + for (size_t i = 0; i < out_shapes.size(); i++) { + ss << ", enzyme_dup, &out_" << i << ", nullptr"; } - for (size_t i=0; i(" << fn << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; - for (size_t i=0; i(" << fn + << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; + for (size_t i = 0; i < out_shapes.size(); i++) { + ss << ", enzyme_dup, nullptr, &dout_" << i; } - for (size_t i=0; i pyargv_strs; - assert (PySequence_Check(pyargv)); - auto sz = PySequence_Size(pyargv); + assert(PySequence_Check(pyargv)); + auto sz = PySequence_Size(pyargv); for (Py_ssize_t i = 0; i < sz; ++i) { - PyObject* item = PySequence_GetItem(pyargv, i); + PyObject *item = PySequence_GetItem(pyargv, i); #if PY_VERSION_HEX < 0x03000000 - auto argv = PyString_AsString(item); + auto argv = PyString_AsString(item); #else - auto argv = PyUnicode_AsUTF8(item); + auto argv = PyUnicode_AsUTF8(item); #endif - Py_DECREF(item); - assert(argv); - pyargv_strs.emplace_back(argv); + Py_DECREF(item); + assert(argv); + pyargv_strs.emplace_back(argv); #if PY_VERSION_HEX < 0x03000000 - free(argv); + free(argv); #else - // should not free py3+ + // should not free py3+ #endif } - auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/true, pyargv_strs, llvm_ctx.get(), std::move(linkMod)); - if (!mod) - throw pybind11::value_error("failed to compile C++"); + auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true, + pyargv_strs, llvm_ctx.get(), std::move(linkMod)); + if (!mod) throw pybind11::value_error("failed to compile C++"); return std::make_tuple(std::move(mod), std::move(llvm_ctx), num_out); } static size_t tapeSize(llvm::StringRef fn, llvm::StringRef source, - llvm::ArrayRef> out_shapes, - llvm::ArrayRef out_names, - llvm::ArrayRef> in_shapes, - llvm::ArrayRef in_names, - PyObject* pyargv, Language lang) { + llvm::ArrayRef> out_shapes, + llvm::ArrayRef out_names, + llvm::ArrayRef> in_shapes, + llvm::ArrayRef in_names, PyObject *pyargv, + Language lang) { int mode = 4; - auto [mod, llvm_ctx, num_out] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, pyargv, mode, lang); + auto [mod, llvm_ctx, num_out] = + createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, + pyargv, mode, lang); auto lfn = mod->getFunction("entry"); - auto RI = llvm::cast(lfn->getEntryBlock().getTerminator()); + auto RI = + llvm::cast(lfn->getEntryBlock().getTerminator()); auto val = llvm::cast(RI->getReturnValue()); size_t res = val->getZExtValue(); // force deletion of mod first explicitly @@ -368,24 +404,34 @@ class CpuKernel { return res; } - static int64_t create(llvm::StringRef fn, llvm::StringRef source, llvm::ArrayRef> out_shapes, llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, - llvm::ArrayRef in_names, - PyObject* pyargv, int mode, Language lang) { + llvm::ArrayRef in_names, PyObject *pyargv, + int mode, Language lang) { llvm::sys::SmartScopedWriter lock(kernel_mutex); int64_t identifier = last_identifier++; - auto [mod, llvm_ctx, num_out] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, pyargv, mode, lang); + auto [mod, llvm_ctx, num_out] = + createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, + pyargv, mode, lang); if (!JIT) { DL = std::make_unique(mod.get()); - auto tJIT = llvm::orc::LLJITBuilder().setDataLayout(*DL.get()).setLinkProcessSymbolsByDefault(true).setObjectLinkingLayerCreator( - [](llvm::orc::ExecutionSession & ES, const llvm::Triple &OLL) -> llvm::Expected> { - return std::make_unique(ES); - }).setJITTargetMachineBuilder(llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple()))).create(); + auto tJIT = + llvm::orc::LLJITBuilder() + .setDataLayout(*DL.get()) + .setLinkProcessSymbolsByDefault(true) + .setObjectLinkingLayerCreator( + [](llvm::orc::ExecutionSession &ES, const llvm::Triple &OLL) + -> llvm::Expected< + std::unique_ptr> { + return std::make_unique(ES); + }) + .setJITTargetMachineBuilder(llvm::orc::JITTargetMachineBuilder( + llvm::Triple(mod->getTargetTriple()))) + .create(); if (!tJIT) { llvm::errs() << tJIT.takeError() << "\n"; throw pybind11::value_error("failed to create jit"); @@ -394,12 +440,16 @@ class CpuKernel { assert(JIT); } - auto LibA = JIT->createJITDylib("enzymedl_"+std::to_string(identifier)); + auto LibA = JIT->createJITDylib("enzymedl_" + std::to_string(identifier)); // Add the module. - // if (auto Err = JIT->addIRModule(llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { - if (auto Err = JIT->addIRModule(LibA.get(), llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { - llvm::errs() <<" error " << Err << "\n"; + // if (auto Err = + // JIT->addIRModule(llvm::orc::ThreadSafeModule(std::move(mod), + // std::move(llvm_ctx)))) { + if (auto Err = JIT->addIRModule( + LibA.get(), + llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { + llvm::errs() << " error " << Err << "\n"; throw pybind11::value_error("failed to add IR module"); } @@ -412,10 +462,9 @@ class CpuKernel { // Cast the entry point address to a function pointer. auto Entry = EntrySym->getValue(); - + kernels.try_emplace( - identifier, - std::make_unique(identifier, num_out, Entry)); + identifier, std::make_unique(identifier, num_out, Entry)); return identifier; } @@ -428,11 +477,11 @@ class CpuKernel { void call(void *out, void **ins) const { void **outs = num_out > 1 ? reinterpret_cast(out) : &out; - for(int i=0; i kernel_mutex; }; -llvm::DenseMap> - CpuKernel::kernels; +llvm::DenseMap> CpuKernel::kernels; int64_t CpuKernel::last_identifier = 1; llvm::sys::SmartRWMutex CpuKernel::kernel_mutex; std::unique_ptr CpuKernel::DL; -std::unique_ptr CpuKernel::JIT = nullptr; -// llvm::orc::ExecutionSession CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create())); +std::unique_ptr CpuKernel::JIT = nullptr; +// llvm::orc::ExecutionSession +// CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create())); } // namespace void CpuCallback(void *out, void **ins) { @@ -468,14 +517,15 @@ PYBIND11_MODULE(enzyme_call, m) { llvm::InitializeAllAsmParsers(); pybind11::enum_(m, "Language") - .value("CPP", Language::CPP) - .value("LLVM", Language::LLVM) - .value("MHLO", Language::MHLO); + .value("CPP", Language::CPP) + .value("LLVM", Language::LLVM) + .value("MHLO", Language::MHLO); m.def("create_enzyme_cpu_kernel", - [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, - const pybind11::list &py_in_shapes, - pybind11::object pyargv, int mode, Language lang) -> int64_t { + [](const std::string &source, const std::string &fn, + const pybind11::list &py_out_shapes, + const pybind11::list &py_in_shapes, pybind11::object pyargv, + int mode, Language lang) -> int64_t { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -509,13 +559,16 @@ PYBIND11_MODULE(enzyme_call, m) { target.push_back(nested_element.cast()); } } - return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), mode, (Language)lang); + return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, + in_types, pyargv.ptr(), mode, + (Language)lang); }); m.def("tape_size", - [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, - const pybind11::list &py_in_shapes, - pybind11::object pyargv, Language lang) -> int64_t { + [](const std::string &source, const std::string &fn, + const pybind11::list &py_out_shapes, + const pybind11::list &py_in_shapes, pybind11::object pyargv, + Language lang) -> int64_t { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -549,7 +602,9 @@ PYBIND11_MODULE(enzyme_call, m) { target.push_back(nested_element.cast()); } } - return (int64_t)CpuKernel::tapeSize(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), (Language)lang); + return (int64_t)CpuKernel::tapeSize(fn, source, out_shapes, out_types, + in_shapes, in_types, pyargv.ptr(), + (Language)lang); }); m.def("get_cpu_callback", []() { @@ -567,4 +622,3 @@ PYBIND11_MODULE(enzyme_call, m) { return *llvm_ir; }); } - diff --git a/enzyme_jax/primitives.py b/enzyme_jax/primitives.py index 31ebfad71..400bf3c27 100644 --- a/enzyme_jax/primitives.py +++ b/enzyme_jax/primitives.py @@ -1,20 +1,19 @@ """JAX primitives for Enzyme connection.""" -from functools import partial from collections.abc import Callable, Sequence -from typing import Any +from functools import partial import itertools import sys +from typing import Any import jax from jax import lax -from jax.interpreters import mlir as jax_mlir from jax.interpreters import ad -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import stablehlo +from jax.interpreters import mlir as jax_mlir from jax.lib import xla_client - import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import stablehlo from . import enzyme_call @@ -22,81 +21,103 @@ LANG_LLVM = enzyme_call.Language.LLVM LANG_MHLO = enzyme_call.Language.MHLO + def resource_dir(): import os + dn = os.path.dirname(enzyme_call.__file__) return os.path.join(dn, "..", "clang", "staging") + def cflags(): import platform import os - if platform.system() == 'Darwin': - return ('-isysroot', '/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk', "-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1", "-internal-isystem", os.path.join(resource_dir(), "include"), "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include", "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include", "-fgnuc-version=4.2.1") + + if platform.system() == "Darwin": + return ( + "-isysroot", + "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk", + "-isystem", + "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1", + "-internal-isystem", + os.path.join(resource_dir(), "include"), + "-internal-externc-isystem", + "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include", + "-internal-externc-isystem", + "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include", + "-fgnuc-version=4.2.1", + ) else: return () + def _enzyme_primal_impl( *args_flat: jax.Array, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") + def _enzyme_fwd_impl( *args_flat: jax.Array, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") + def _enzyme_aug_impl( *args_flat: jax.Array, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") + def _enzyme_shadow_aug_impl( *args_flat: jax.Array, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") + def _enzyme_rev_impl( *args_flat: jax.Array, source, fn: str, argv: Sequence[str], in_shapes, - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") + def _enzyme_primal_abstract_eval( *args_flat: jax.core.ShapedArray, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: del source, fn, args_flat @@ -104,6 +125,7 @@ def _enzyme_primal_abstract_eval( # result types instead. return tuple(out_shapes) + def _enzyme_fwd_abstract_eval( *args_flat: jax.core.ShapedArray, source, @@ -117,20 +139,21 @@ def _enzyme_fwd_abstract_eval( # each return is duplicated return tuple(o for o in out_shapes for _ in range(2)) + def absmaketup(ty): tystr = ty.dtype.__str__() - tystr = {'float32':'float','float64':'double'}[tystr] + tystr = {"float32": "float", "float64": "double"}[tystr] return (tystr, ty.shape) + def _enzyme_aug_abstract_eval( *args_flat: jax.core.ShapedArray, source, fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang : enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: - in_shapes = args_flat prev_out_shapes = out_shapes @@ -143,13 +166,17 @@ def _enzyme_aug_abstract_eval( (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') + mhlo = lowered_func.compiler_ir(dialect="mhlo") source = str(mhlo) - argv = argv + ( "-resource-dir", resource_dir()) + cflags() + argv = argv + ("-resource-dir", resource_dir()) + cflags() - tapeSize = enzyme_call.tape_size(source, fn, out_shapes, in_shapes, argv, lang) - res = tuple(prev_out_shapes) + (jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)),) + tapeSize = enzyme_call.tape_size( + source, fn, out_shapes, in_shapes, argv, lang + ) + res = tuple(prev_out_shapes) + ( + jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)), + ) return res @@ -159,26 +186,30 @@ def _enzyme_shadow_aug_abstract_eval( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: return out_shapes + def _enzyme_rev_abstract_eval( *args_flat: jax.core.ShapedArray, source, fn: str, argv: Sequence[str], in_shapes, - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: del source, fn, args_flat - return tuple(jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes) + return tuple( + jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes + ) + def maketup(ty): ty = ir.RankedTensorType(ty) tystr = ty.element_type.__str__() - tystr = {'f32':'float','f64':'double'}[tystr] + tystr = {"f32": "float", "f64": "double"}[tystr] return (tystr, ty.shape) @@ -189,7 +220,7 @@ def _enzyme_primal_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: del out_shapes @@ -204,18 +235,20 @@ def _enzyme_primal_lowering( (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') + mhlo = lowered_func.compiler_ir(dialect="mhlo") source = str(mhlo) - argv = argv + ( "-resource-dir", resource_dir() ) + cflags() + argv = argv + ("-resource-dir", resource_dir()) + cflags() mode = 0 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, mode, lang + ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) mlir_args = (identifier_op, *args_flat) custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.primal" + out_types, mlir_args, call_target_name="jaxzyme.primal" ) return custom_call.results @@ -228,7 +261,7 @@ def _enzyme_fwd_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: del out_shapes @@ -244,18 +277,20 @@ def _enzyme_fwd_lowering( (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') + mhlo = lowered_func.compiler_ir(dialect="mhlo") source = str(mhlo) - argv = argv + ( "-resource-dir", resource_dir() ) + cflags() + argv = argv + ("-resource-dir", resource_dir()) + cflags() mode = 1 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, mode, lang + ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) mlir_args = (identifier_op, *args_flat) custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.fwd" + out_types, mlir_args, call_target_name="jaxzyme.fwd" ) return custom_call.results @@ -268,7 +303,7 @@ def _enzyme_aug_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: del out_shapes @@ -276,7 +311,7 @@ def _enzyme_aug_lowering( itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) ) - out_shapes = list(map(maketup, out_types[:len(out_types)-1])) + out_shapes = list(map(maketup, out_types[: len(out_types) - 1])) in_shapes = list(map(lambda x: maketup(x.type), args_flat)) @@ -284,22 +319,25 @@ def _enzyme_aug_lowering( (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') + mhlo = lowered_func.compiler_ir(dialect="mhlo") source = str(mhlo) - argv = argv + ( "-resource-dir", resource_dir()) + cflags() + argv = argv + ("-resource-dir", resource_dir()) + cflags() mode = 2 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, mode, lang + ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) mlir_args = (identifier_op, *args_flat) custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.aug" + out_types, mlir_args, call_target_name="jaxzyme.aug" ) return custom_call.results + def _enzyme_rev_lowering( ctx: jax_mlir.LoweringRuleContext, *args_flat: ir.Value, @@ -307,7 +345,7 @@ def _enzyme_rev_lowering( fn: str, argv: Sequence[str], in_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: del in_shapes @@ -323,33 +361,61 @@ def _enzyme_rev_lowering( (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') + mhlo = lowered_func.compiler_ir(dialect="mhlo") source = str(mhlo) - argv = tuple(argv) + ( "-resource-dir", resource_dir()) + cflags() + argv = tuple(argv) + ("-resource-dir", resource_dir()) + cflags() mode = 3 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, mode, lang + ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) mlir_args = (identifier_op, *args_flat) custom_call = stablehlo.CustomCallOp( - in_types, mlir_args, call_target_name="jaxzyme.rev" + in_types, mlir_args, call_target_name="jaxzyme.rev" ) return custom_call.results -def ffi_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source, fn:str="f", argv: tuple[str]=(), lang:int=LANG_CPP): + +def ffi_call( + *args, + out_shapes: Sequence[jax.core.ShapedArray], + source, + fn: str = "f", + argv: tuple[str] = (), + lang: int = LANG_CPP, +): return _enzyme_primal_p.bind( - *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang) + *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang + ) + + +def cpp_call( + *args, + out_shapes: Sequence[jax.core.ShapedArray], + source: str, + fn: str = "f", + argv: tuple[str] = (), +): + return ffi_call( + *args, + source=source, + fn=fn, + argv=argv, + out_shapes=out_shapes, + lang=LANG_CPP, + ) -def cpp_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source: str, fn:str="f", argv: tuple[str]=()): - return ffi_call(*args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP) _enzyme_primal_p = jax.core.Primitive("enzyme_primal") _enzyme_primal_p.multiple_results = True _enzyme_primal_p.def_impl(_enzyme_primal_impl) _enzyme_primal_p.def_abstract_eval(_enzyme_primal_abstract_eval) -jax_mlir.register_lowering(_enzyme_primal_p, _enzyme_primal_lowering, platform="cpu") +jax_mlir.register_lowering( + _enzyme_primal_p, _enzyme_primal_lowering, platform="cpu" +) xla_client.register_custom_call_target( "jaxzyme.primal", enzyme_call.get_cpu_callback(), platform="cpu" @@ -365,26 +431,39 @@ def cpp_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source: str, fn: "jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu" ) + def enzyme_jvp(arg_primals, arg_tangents, **kwargs): - + # TODO propagate activity info rather than make_zero def make_zero(tan, prim): - return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan + return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan - arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)) + arg_tangents = tuple( + make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals) + ) args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t) shadconv = _enzyme_fwd_p.bind( - *args, source=kwargs['source'], fn=kwargs['fn'], argv=kwargs['argv'], out_shapes=kwargs['out_shapes'], lang=kwargs['lang']) + *args, + source=kwargs["source"], + fn=kwargs["fn"], + argv=kwargs["argv"], + out_shapes=kwargs["out_shapes"], + lang=kwargs["lang"], + ) res = (shadconv[0::2], shadconv[1::2]) return res + ad.primitive_jvps[_enzyme_primal_p] = enzyme_jvp + def jaxify(x): - return {'float32':0, 'float64':1}[x.__str__()] + return {"float32": 0, "float64": 1}[x.__str__()] + def dejaxify(x): - return {0:jnp.float32, 1:jnp.float64}[x] + return {0: jnp.float32, 1: jnp.float64}[x] + _enzyme_aug_p = jax.core.Primitive("enzyme_aug") _enzyme_aug_p.multiple_results = True @@ -414,6 +493,7 @@ def dejaxify(x): from jax._src.interpreters import partial_eval as pe + def fwd_partial_eval(trace, *args, **kwargs): assert len(args) % 2 == 0 nr_primals = len(args) // 2 @@ -424,36 +504,41 @@ def fwd_partial_eval(trace, *args, **kwargs): if not (all_primals_known and some_tangents_unknown): return trace.default_process_primitive(_enzyme_fwd_p, args, kwargs) - outs_known = trace.default_process_primitive( - _enzyme_aug_p, primals, kwargs) + outs_known = trace.default_process_primitive(_enzyme_aug_p, primals, kwargs) shadow_aug_args = (trace.full_raise(outs_known[-1]),) + primals + tangents shadows_known = trace.default_process_primitive( - _enzyme_shadow_aug_p, shadow_aug_args, - kwargs) + _enzyme_shadow_aug_p, shadow_aug_args, kwargs + ) outs = tuple(v for tup in zip(outs_known[:-1], shadows_known) for v in tup) return outs + pe.custom_partial_eval_rules[_enzyme_fwd_p] = fwd_partial_eval + def enzyme_vjp(shadow_rets, *prim_args, **kwargs): - out_shapes = kwargs['out_shapes'] - del kwargs['out_shapes'] + out_shapes = kwargs["out_shapes"] + del kwargs["out_shapes"] shadows = [ad.is_undefined_primal(x) for x in prim_args] tape = prim_args[0] - prim_args = prim_args[1:1+(len(prim_args)-1)//2] - prim_args = tuple(jnp.ones(x.aval.shape, x.aval.dtype) if ad.is_undefined_primal(x) else x for x in prim_args) + prim_args = prim_args[1 : 1 + (len(prim_args) - 1) // 2] + prim_args = tuple( + jnp.ones(x.aval.shape, x.aval.dtype) if ad.is_undefined_primal(x) else x + for x in prim_args + ) in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args) - args = (tape, ) + tuple(shadow_rets) - shadconv = _enzyme_rev_p.bind( - *args, **kwargs, in_shapes=in_shapes) + args = (tape,) + tuple(shadow_rets) + shadconv = _enzyme_rev_p.bind(*args, **kwargs, in_shapes=in_shapes) res = (None,) + tuple(None for _ in range(len(shadconv))) + tuple(shadconv) return res + ad.primitive_transposes[_enzyme_shadow_aug_p] = enzyme_vjp + def enzyme_jax_ir(argv=()): def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @jax.jit @@ -461,8 +546,19 @@ def wrapped(*args: Any): args_flat, in_tree = jax.tree_util.tree_flatten(args) out_shape = jax.eval_shape(func, *args) out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape) - out_shape_flat = [jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat] - out_flat = ffi_call(*args_flat, source=(in_tree, func), fn="", out_shapes=out_shape_flat, argv=argv, lang=LANG_MHLO) + out_shape_flat = [ + jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat + ] + out_flat = ffi_call( + *args_flat, + source=(in_tree, func), + fn="", + out_shapes=out_shape_flat, + argv=argv, + lang=LANG_MHLO, + ) return jax.tree_util.tree_unflatten(out_tree, out_flat) + return wrapped - return decorator \ No newline at end of file + + return decorator