Skip to content

Commit

Permalink
fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 13, 2024
1 parent fb7179e commit 21e5d84
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from enzyme_ad.jax.primitives import cpp_call, enzyme_jax_ir
from enzyme_ad.jax.primitives import cpp_call, enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline
11 changes: 8 additions & 3 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pass_pipeline(self):
raise NotImplementedError()

# MLIR pass pipeline
def mlir_ad(self)
def mlir_ad(self):
raise NotImplementedError()

class OldXLAPipeline:
Expand All @@ -48,7 +48,7 @@ def mlir_ad(self):
return False

class NewXLAPipeline:
def __init__(self; passes=None, mlirad=False):
def __init__(self, passes=None, mlirad=False):
if passes is None:
passes = """
inline{default-pipeline=canonicalize max-iterations=4},
Expand Down Expand Up @@ -777,10 +777,15 @@ def make_zero(tan, prim):
pipeline_options = kwargs["pipeline_options"]

shadconv = None
if pipeline_options.mlir_ad()
if pipeline_options.mlir_ad():
act_tup = (",".join(["enzyme_dup" for a in args]))
newpasses = "enzyme-wrap{infn=main outfn=main retTy=enzyme_dup argTys="+act_tup+" mode=ForwardMode}," + pipeline_options.pass_pipeline()
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
outshapes2 = []
for o in kwargs["out_shapes"]:
outshapes2.append(o)
outshapes2.append(o)
shadconv = ffi_call(*args, out_shapes=outshapes2, source=kwargs["source"], fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options)
else:
shadconv = _enzyme_fwd_p.bind(
*args,
Expand Down
45 changes: 27 additions & 18 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import jax
import jax.numpy as jnp
from enzyme_ad.jax import enzyme_jax_ir
from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline
from absl.testing import absltest
import timeit

AllPipelines = [("NewXLAMLIR", enzyme_ad.jax.NewXLAPipeline(mlirad=True)), ("NewXLA", enzyme_ad.jax.NewXLAPipeline()), ("OldXLA", enzyme_ad.jax.OldXLAPipeline())]
AllPipelines = [("NewXLAMLIR", NewXLAPipeline(mlirad=True)), ("NewXLA", NewXLAPipeline()), ("OldXLA", OldXLAPipeline())]
PrimalPipelines = AllPipelines[1:]
FwdPipelines = AllPipelines
RevPipelines = AllPipelines[1:]
Expand All @@ -15,7 +15,7 @@
def splatjvp(in_fn):
def fwd(*args):
assert len(args) % 2 == 0
return jax.jvp(in_fn, tuple(args[:len(args)/2]), tuple(args[len(args)/2:]))
return jax.jvp(in_fn, tuple(args[:len(args)//2]), tuple(args[len(args)//2:]))
return fwd

# @jax.jit
Expand All @@ -31,16 +31,20 @@ def rev(dout, *args):
return rev

class EnzymeJaxTest(absltest.TestCase):
def setUp(self):
self.name = None

def test(self):
if self.name is None:
return
self.harness(self.name, self.fn, self.ins, self.dins, self.douts)

def harness(self, name, in_fn, ins, dins, douts):
assert len(ins) == len(dins)
rfn_jax = jax.jit(in_fn)

aop = rfn_jax(*ins)

assert len(aop) == len(douts)
assert 1 == len(douts)

primalstr = "fn(" + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")"
primalins = {("in" + str(i)):ins[0] for i in range(len(ins))}
Expand All @@ -50,60 +54,65 @@ def harness(self, name, in_fn, ins, dins, douts):
primalstr,
globals={
"fn": rfn_jax,
} + primalins,
} | primalins,
).timeit()
)

fwd_jax = jax.jit(splatjvp(rfn_jax))

primals_p, tangents_p = fwd_jax(*(ins+dins))
print(primals_p)
print((jnp.abs(aop - primals_p) < 1e-6).all())
self.assertTrue((jnp.abs(aop - primals_p) < 1e-6).all())

fwdstr = "fwd(" + (", ".join(["in" + str(i) for i in range(len(ins))])) + ", " + (", ".join(["din" + str(i) for i in range(len(dins))])) + ")"
fwdins = primalins + {("din" + str(i)):dins[0] for i in range(len(dins))}
fwdins = primalins | {("din" + str(i)):dins[0] for i in range(len(dins))}
print(name + " JaX Fwd: ",
timeit.Timer(
fwdstr,
globals={
"fwd": fwd_jax,
} + fwdins,
} | fwdins,
).timeit()
)

assert len(douts) == 1

rev_jax = jax.jit(splatvjp(rfn_jax))

primals_p, grads_p = rev_jax(dout, *ins)

primals_p, grads_p = rev_jax(*douts, *ins)

print(primals_p)
print((jnp.abs(aop - primals_p) < 1e-6).all())
self.assertTrue((jnp.abs(aop - primals_p) < 1e-6).all())

revstr = "rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")"
revins = primalins + {"dout":douts[0]}
revins = primalins | {"dout":douts[0]}

print(name + " JaX Rev: ",
timeit.Timer(
revstr,
globals={
"rev": rev_jax,
} + revins,
} | revins,
).timeit()
)

for (name, pipeline) in AllPipelines:
rfn_enzyme = enzyme_jax_ir(pipeline_options=pipeline)(in_fn)


if (name, pipeline) in PrimalPipelines:
ao = rfn_enzyme(*ins)
print(aop)
print((jnp.abs(aop - aop) < 1e-6).all())
self.assertTrue((jnp.abs(ao - aop) < 1e-6).all())

print(name + " EnzymeMLIR(",name,") Primal: ",
timeit.Timer(
primalstr,
globals={
"fn": rfn_enzyme,
} + primalins,
} | primalins,
).timeit()
)

Expand All @@ -122,14 +131,14 @@ def harness(self, name, in_fn, ins, dins, douts):
fwdstr,
globals={
"fwd": fwd_enzyme,
} + fwdins,
} | fwdins,
).timeit()
)

if (name, pipeline) in RevPipelines:
rev_enzyme = jax.jit(splatvjp(rfn_enzyme))

primals, grads = rev_jax(dout, *ins)
primals, grads = rev_enzyme(*douts, *ins)
self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())

for i, (g, g_p) in enumerate(zip(grads, grads_p)):
Expand All @@ -140,8 +149,8 @@ def harness(self, name, in_fn, ins, dins, douts):
timeit.Timer(
revstr,
globals={
"rev": rev_jax,
} + revins,
"rev": rev_enzyme,
} | revins,
).timeit()
)

Expand Down

0 comments on commit 21e5d84

Please sign in to comment.