Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 7, 2023
1 parent 323d2c8 commit df9d226
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 39 deletions.
8 changes: 4 additions & 4 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ http_archive(
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project", targets = LLVM_TARGETS)

XLA_COMMIT = "10b834edcabcaa78750368efc7556da22482c57e"
XLA_SHA256 = "f5a29a7236b486c2cc185509f2d80faefcaab5c900b822cdda7ba809a930a8b2"
XLA_COMMIT = "a6e6c1f6a53d4a23451c649110519c7ba8581bf9"
XLA_SHA256 = "5fe6dfa30621bd50b022a6cab026d6f4cde9883a3e150ce1b6fd52822a57c59a"

http_archive(
name = "xla",
Expand Down Expand Up @@ -70,8 +70,8 @@ http_archive(
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)

JAX_COMMIT = "32a317f7a43440800e1e39e00ed5f2980e088ab1"
JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1"
JAX_COMMIT = "f691fe468a8e1f8545f7d624055d58b823ee3201"
JAX_SHA256 = ""

http_archive(
name = "jax",
Expand Down
2 changes: 2 additions & 0 deletions enzyme_jax/clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@ struct tensor<T, n0, N...>

ModulePassManager MPM;
PB.parsePassPipeline(MPM, "default<O3>");
llvm::errs() << " pre: " << *mod << "\n";
MPM.run(*mod, MAM);
llvm::errs() << " post: " << *mod << "\n";

return mod;
}
27 changes: 18 additions & 9 deletions enzyme_jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class CpuKernel {

ss << " __attribute__((always_inline)) static inline void abi_wrap(";
bool comma = false;
for (size_t i = 0, off = 0; i < out_shapes.size(); i++) {
for (size_t i = 0; i < out_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " " << make_type(out_names[i], out_shapes[i], false, lang)
Expand All @@ -259,7 +259,7 @@ class CpuKernel {
ss << " enzyme::tensor<char, " << tmpBuf << "> & __restrict__ tmpBuf";
comma = true;
}
for (size_t i = 0, off = 0; i < in_shapes.size(); i++) {
for (size_t i = 0; i < in_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " " << make_type(in_names[i], in_shapes[i], true, lang) << "& in_"
Expand Down Expand Up @@ -341,6 +341,10 @@ class CpuKernel {
}
}
for (auto &buf : assignment.Allocations()) {
if (buf.is_thread_local()) {
ss << " char local_" << buf.index() << "[" << buf.size() << "];\n";
continue;
}
if (!buf.maybe_live_out())
continue;
if (!buf.is_tuple())
Expand Down Expand Up @@ -387,6 +391,9 @@ class CpuKernel {
} else if (buf.is_constant()) {
ss << " "
<< "(void*)&const_" << buf.index();
} else if (buf.is_thread_local()) {
ss << " "
<< "(void*)&local_" << buf.index();
} else {
std::string err;
llvm::raw_string_ostream ess(err);
Expand All @@ -400,14 +407,14 @@ class CpuKernel {
}
} else {
comma = false;
for (size_t i = 0, off = 0; i < out_shapes.size(); i++) {
for (size_t i = 0; i < out_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " "
<< "(void*)&out_" << i;
comma = true;
}
for (size_t i = 0, off = 0; i < in_shapes.size(); i++) {
for (size_t i = 0; i < in_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " "
Expand All @@ -431,7 +438,7 @@ class CpuKernel {
if (mode != ABI::Primal) {
ss << " void entry_wrap(";
bool comma = false;
for (size_t i = 0, off = 0; i < out_shapes.size(); i++) {
for (size_t i = 0; i < out_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " " << make_type(out_names[i], out_shapes[i], false, lang)
Expand All @@ -444,7 +451,7 @@ class CpuKernel {
ss << " enzyme::tensor<char, " << tmpBuf << "> & __restrict__ tmpBuf";
comma = true;
}
for (size_t i = 0, off = 0; i < in_shapes.size(); i++) {
for (size_t i = 0; i < in_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " " << make_type(in_names[i], in_shapes[i], true, lang) << "& in_"
Expand All @@ -454,7 +461,7 @@ class CpuKernel {
ss << ") {\n";
ss << " " << fn << "(";
comma = false;
for (size_t i = 0, off = 0; i < out_shapes.size(); i++) {
for (size_t i = 0; i < out_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " "
Expand All @@ -468,7 +475,7 @@ class CpuKernel {
<< "tmpBuf";
comma = true;
}
for (size_t i = 0, off = 0; i < in_shapes.size(); i++) {
for (size_t i = 0; i < in_shapes.size(); i++) {
if (comma)
ss << ", ";
ss << " "
Expand Down Expand Up @@ -517,7 +524,7 @@ class CpuKernel {
}
}

for (size_t i = 0, off = 0; i < in_shapes.size(); i++) {
for (size_t i = 0; i < in_shapes.size(); i++) {
if (mode != ABI::Reverse && mode != ABI::Tape) {
ss << " " << make_type(in_names[i], in_shapes[i], true, lang) << "& in_"
<< i << " = "
Expand Down Expand Up @@ -713,6 +720,8 @@ class CpuKernel {
#endif
}

llvm::errs() << " inp: " << ss.str() << "\n";

auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true,
pyargv_strs, llvm_ctx.get(), std::move(linkMod));
if (!mod)
Expand Down
64 changes: 55 additions & 9 deletions enzyme_jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def _enzyme_aug_abstract_eval(
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect='mhlo')
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

argv = argv + ( "-resource-dir", resource_dir()) + cflags()

Expand Down Expand Up @@ -175,7 +177,10 @@ def maketup(ty):
tystr = {'f32':'float','f64':'double','i32':'int32_t','i64':'int64_t'}[tystr]
return (tystr, ty.shape)


def to_jax(ty):
tystr = ty.__str__()
return {'f32':jnp.float32,'f64':jnp.float64}[tystr]

def _enzyme_primal_lowering(
ctx: jax_mlir.LoweringRuleContext,
*args_flat: ir.Value,
Expand Down Expand Up @@ -203,8 +208,9 @@ def _enzyme_primal_lowering(
mhlo = lowered_func.compiler_ir(dialect='mhlo')
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = [arg for (i, arg) in enumerate(in_args) if i in kept]
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
print(kept, in_args, in_shapes)

argv = argv + ( "-resource-dir", resource_dir() ) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Primal, lang)
Expand Down Expand Up @@ -246,20 +252,25 @@ def _enzyme_fwd_lowering(
out_shapes = list(map(maketup, out_types[::2]))

in_shapes = list(map(lambda x: maketup(x.type), args_flat[::2]))

in_args = (*args_flat,)

if lang == LANG_MHLO:
(in_tree, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2])
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect='mhlo')
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i//2 in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

argv = argv + ( "-resource-dir", resource_dir() ) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Forward, lang)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)

mlir_args = (identifier_op, *args_flat)
mlir_args = (identifier_op,) + in_args

if tmpBuf != 0:
sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8))
Expand Down Expand Up @@ -295,19 +306,25 @@ def _enzyme_aug_lowering(

in_shapes = list(map(lambda x: maketup(x.type), args_flat))

in_args = (*args_flat,)

if lang == LANG_MHLO:
(in_tree, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect='mhlo')
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
print(kept, in_args, in_shapes)

argv = argv + ( "-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Augmented, lang)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)

mlir_args = (identifier_op, *args_flat)
mlir_args = (identifier_op,) + in_args
custom_call = stablehlo.CustomCallOp(
out_types, mlir_args, call_target_name="jaxzyme.aug"
)
Expand All @@ -328,36 +345,65 @@ def _enzyme_rev_lowering(
) -> Sequence[ir.Value]:
del in_shapes

in_types = tuple(
pre_in_types = tuple(
itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out))
)

in_shapes = list(map(maketup, in_types))
in_shapes = list(map(maketup, pre_in_types))
pre_in_shapes = in_shapes

out_shapes = list(map(lambda x: maketup(x.type), args_flat[1:]))

in_args = (*args_flat,)

rev_return_types = pre_in_types

kept = None
if lang == LANG_MHLO:
(in_tree, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect='mhlo')
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
# in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
rev_return_types = [retty for (i, retty) in enumerate(rev_return_types) if i in kept]

argv = tuple(argv) + ( "-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Reverse, lang)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)

mlir_args = (identifier_op, *args_flat)
mlir_args = (identifier_op,) + in_args

if tmpBuf != 0:
mlir_args += (stablehlo.ZeroOp(tmpBuf),)

rev_return_types = in_types

custom_call = stablehlo.CustomCallOp(
rev_return_types, mlir_args, call_target_name="jaxzyme.rev"
)
return custom_call.results
results = custom_call.results
if kept != None:
print("results", results)
results = []
cur_idx = 0
for i, ty in enumerate(pre_in_types):
if i in kept:
results.append(custom_call.results[cur_idx])
cur_idx += 1
else:
ty = ir.RankedTensorType(ty)
print(type(ty), ty, dir(ty))
shape = ty.shape
print(type(shape), shape, dir(shape))
element_type = ty.element_type
print(type(element_type), element_type, dir(element_type))
import numpy as np
results.append(stablehlo.ConstantOp(ir.DenseElementsAttr.get(np.zeros(shape, dtype=to_jax(element_type)))).results[0])
print("results", results)
return results

def ffi_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source, fn:str="f", argv: tuple[str]=(), lang:int=LANG_CPP):
return _enzyme_primal_p.bind(
Expand Down
Loading

0 comments on commit df9d226

Please sign in to comment.