Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 2, 2024
1 parent e0aaf81 commit 00a9a72
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 69 deletions.
119 changes: 51 additions & 68 deletions src/enzyme_ad/jax/clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,22 @@
//===----------------------------------------------------------------------===//

#include "clang_compile.h"
#include "llvm/IRReader/IRReader.h"

#include <cstring>
#include <iostream>
#include <memory>
#include <Python.h>
#include <pybind11/pybind11.h>
#include <setjmp.h>
#include <signal.h>
#include <stdlib.h>
#include <string>
#include <sys/time.h>
#include <unistd.h>

#include "clang/CodeGen/CodeGenAction.h"
#include "llvm-c/Core.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/AsmParser/LLLexer.h"
#include "llvm/AsmParser/LLParser.h"
#include "llvm/AsmParser/LLToken.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/AsmParser/SlotMapping.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

#include "Enzyme/Enzyme.h"
#include "Enzyme/Utils.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/DiagnosticOptions.h"
#include "clang/Basic/FileManager.h"
Expand All @@ -46,6 +32,7 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TargetOptions.h"
#include "clang/Basic/Version.h"
#include "clang/CodeGen/CodeGenAction.h"
#include "clang/Driver/Compilation.h"
#include "clang/Driver/Driver.h"
#include "clang/Driver/Tool.h"
Expand All @@ -60,21 +47,30 @@
#include "clang/Parse/Parser.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "llvm-c/Core.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/AsmParser/LLLexer.h"
#include "llvm/AsmParser/LLParser.h"
#include "llvm/AsmParser/LLToken.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/AsmParser/SlotMapping.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"

#include <Python.h>
#include <pybind11/pybind11.h>

#include "Enzyme/Enzyme.h"
#include "Enzyme/Utils.h"
#include "llvm/TargetParser/Host.h"

namespace clang {
namespace driver {
Expand All @@ -87,21 +83,21 @@ namespace tools {
void addDirectoryList(const llvm::opt::ArgList &Args,
llvm::opt::ArgStringList &CmdArgs, const char *ArgName,
const char *EnvVar);
} // namespace tools
} // namespace driver
} // namespace clang
} // namespace tools
} // namespace driver
} // namespace clang

using namespace clang;
using namespace llvm;

class ArgumentList {
private:
private:
/// Helper storage.
llvm::SmallVector<llvm::SmallString<0>> Storage;
/// List of arguments
llvm::opt::ArgStringList Args;

public:
public:
/// Add argument.
///
/// The element stored will not be owned by this.
Expand Down Expand Up @@ -167,10 +163,10 @@ static TargetMachine *GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr,
codegen::getExplicitCodeModel(), level);
}

std::unique_ptr<llvm::Module>
GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp,
ArrayRef<std::string> pyargv, LLVMContext *Context,
std::unique_ptr<llvm::Module> linkMod) {
std::unique_ptr<llvm::Module> GetLLVMFromJob(
std::string filename, std::string filecontents, bool cpp,
ArrayRef<std::string> pyargv, LLVMContext *Context,
std::unique_ptr<llvm::Module> linkMod) {
const llvm::opt::InputArgList Args;
const char *binary = cpp ? "clang++" : "clang";
// Buffer diagnostics from argument parsing so that we can output them using a
Expand All @@ -189,8 +185,7 @@ GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp,
ArgumentList Argv;

Argv.emplace_back(StringRef(filename));
for (auto v : pyargv)
Argv.emplace_back(v);
for (auto v : pyargv) Argv.emplace_back(v);

SmallVector<const char *> PreArgs;
PreArgs.push_back(binary);
Expand Down Expand Up @@ -227,7 +222,7 @@ GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp,
auto &TC = compilation->getDefaultToolChain();
if (cpp) {
bool HasStdlibxxIsystem =
false; // Args.hasArg(options::OPT_stdlibxx_isystem);
false; // Args.hasArg(options::OPT_stdlibxx_isystem);
HasStdlibxxIsystem
? TC.AddClangCXXStdlibIsystemArgs(Args, Argv.getArguments())
: TC.AddClangCXXStdlibIncludeArgs(Args, Argv.getArguments());
Expand Down Expand Up @@ -525,10 +520,8 @@ struct tensor<T, n0, N...>
}

for (auto &f : *mod) {
if (f.empty())
continue;
if (f.getName() == "entry")
continue;
if (f.empty()) continue;
if (f.getName() == "entry") continue;
f.setLinkage(Function::LinkageTypes::InternalLinkage);
}

Expand All @@ -545,7 +538,7 @@ struct tensor<T, n0, N...>
llvm::driver::createTLII(triple, Clang->getCodeGenOpts().getVecLib()));
FAM.registerPass([&] { return TargetLibraryAnalysis(*TLII); });

auto level = CodeGenOptLevel::Aggressive; // OptimizationLevel::O3;
auto level = CodeGenOptLevel::Aggressive; // OptimizationLevel::O3;

Triple ModuleTriple(mod->getTargetTriple());
std::string CPUStr, FeaturesStr;
Expand Down Expand Up @@ -579,53 +572,43 @@ struct tensor<T, n0, N...>
if (F) {
for (const auto user : llvm::make_early_inc_range(F->users())) {
auto CI = dyn_cast<CallInst>(user);
if (!CI)
continue;
if (!CI) continue;
std::deque<std::pair<llvm::Value *, llvm::Value *>> todo;
SmallVector<Value *, 1> cargs;
for (auto &arg : CI->args())
cargs.push_back(arg);
for (auto &arg : CI->args()) cargs.push_back(arg);
CI->eraseFromParent();
for (auto &arg : cargs) {
Value *cur = getBaseObject(arg);
assert(isa<LoadInst>(cur));
for (auto U : cur->users())
todo.emplace_back(U, cur);
for (auto U : cur->users()) todo.emplace_back(U, cur);
}
std::set<std::pair<Value *, Value *>> seen;
SmallPtrSet<Instruction *, 32> toErase;
while (todo.size()) {
auto pair = todo.back();
todo.pop_back();
auto [cur, prev] = pair;
if (seen.count(pair))
continue;
if (seen.count(pair)) continue;
seen.insert(pair);
if (isPointerArithmeticInst(cur)) {
for (auto u : cur->users())
todo.emplace_back(u, cur);
for (auto u : cur->users()) todo.emplace_back(u, cur);
continue;
}
if (isa<LoadInst>(cur))
continue;
if (isa<LoadInst>(cur)) continue;
if (auto MTI = dyn_cast<MemTransferInst>(cur)) {
if (MTI->getSource() == prev)
continue;
if (MTI->getSource() == prev) continue;
}
if (auto CI = dyn_cast<CallInst>(cur))
if (auto F = CI->getCalledFunction())
if (F->getName() == "memset_pattern16")
continue;
if (F->getName() == "memset_pattern16") continue;
if (auto MS = dyn_cast<MemSetInst>(cur)) {
toErase.insert(MS);
continue;
}
if (auto II = dyn_cast<IntrinsicInst>(cur)) {
if (II->getIntrinsicID() == llvm::Intrinsic::dbg_value)
continue;
if (II->getIntrinsicID() == llvm::Intrinsic::dbg_value) continue;
}
if (isa<ICmpInst>(cur))
continue;
if (isa<ICmpInst>(cur)) continue;
if (auto SI = dyn_cast<StoreInst>(cur)) {
assert(SI->getPointerOperand() == prev);
auto C = dyn_cast<Constant>(SI->getValueOperand());
Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output) {

xla::ExecutableBuildOptions build_options;
build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(true);
// build_options.mutable_debug_options()->set_xla_cpu_use_xla_runtime(true);
build_options.mutable_debug_options()->set_xla_cpu_use_xla_runtime(true);

if (build_options.device_ordinal() == -1) {
build_options.set_device_ordinal(local_client->default_device_ordinal());
Expand Down

0 comments on commit 00a9a72

Please sign in to comment.