Add missing ')' to enzyme_call and add tests for old pipeline #31

Merged: 3 commits, Jan 23, 2024

1 change: 1 addition & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
@@ -527,6 +527,7 @@ class CpuKernel {
         comma = true;
       }
     }
+    ss << "};\n";
     ss << " " << fn
        << "(nullptr, nullptr, nullptr, buffers, nullptr, nullptr);\n";
   }
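
For context, `CpuKernel` is emitting C++ source for a call wrapper through a string stream: the comma-tracking loop above writes the entries of a `buffers` initializer list, and before this fix the initializer was never closed. A minimal Python sketch of the same assembly logic (the emitter function, its names, and the exact `buffers` declaration are hypothetical, for illustration only):

```python
def emit_call_wrapper(fn: str, buffer_names: list[str]) -> str:
    # Hypothetical mirror of the std::stringstream logic above.
    ss = "void* buffers[] = {"
    ss += ", ".join(buffer_names)  # the comma-tracking loop in the C++ code
    ss += "};\n"  # the line this PR adds; without it the emitted source is unparseable
    ss += " " + fn + "(nullptr, nullptr, nullptr, buffers, nullptr, nullptr);\n"
    return ss

print(emit_call_wrapper("kernel_fn", ["in0", "out0"]))
```
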
15 changes: 13 additions & 2 deletions src/enzyme_ad/jax/primitives.py
@@ -22,13 +22,24 @@
 LANG_LLVM = enzyme_call.Language.LLVM
 LANG_MHLO = enzyme_call.Language.MHLO
 
+## options
+## True (default) -> new XLA pipeline, default passes
+## False -> old XLA pipeline, internal passes
+## string -> new XLA pipeline, using the passes as specified
+
+
 def xla_runtime(options):
-    return True
+    if type(options) == type(False) and options == False:
+        return False
+    else:
+        return True
 
 
 def pass_pipeline(options):
-    return """
+    if type(options) == type(""):
+        return options
+    else:
+        return """
 inline{default-pipeline=canonicalize max-iterations=4},
 expand-hlo-tuples{entry-function=main},
 func.func(mhlo-flatten-tuple),
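
Per the comment block above, `pipeline_options` accepts three shapes: `True` (the default) selects the new XLA pipeline with default passes, `False` falls back to the old XLA pipeline with internal passes, and a string selects the new pipeline running exactly the passes given. A minimal usage sketch (the import path is assumed from the repository layout; the custom pass string is an illustrative excerpt of the default list above):

```python
from enzyme_ad.jax import enzyme_jax_ir  # import path assumed

@enzyme_jax_ir()  # default: new XLA pipeline, default passes
def f_new(x):
    return x + 1

@enzyme_jax_ir(pipeline_options=False)  # old XLA pipeline, internal passes
def f_old(x):
    return x + 1

@enzyme_jax_ir(
    pipeline_options="inline{default-pipeline=canonicalize max-iterations=4}"
)  # new XLA pipeline, caller-specified passes
def f_custom(x):
    return x + 1
```
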
173 changes: 156 additions & 17 deletions test/bench_vs_xla.py
@@ -10,6 +10,11 @@ def add_one(x: jax.Array, y) -> jax.Array:
     return x + 1 + y
 
 
+@enzyme_jax_ir(pipeline_options=False)
+def add_one_old(x: jax.Array, y) -> jax.Array:
+    return x + 1 + y
+
+
 @jax.jit
 def add_one_plain(x: jax.Array, y) -> jax.Array:
     return x + 1 + y
@@ -20,6 +25,11 @@ def add_two(x: jax.Array, z, y) -> jax.Array:
     return x + y
 
 
+@enzyme_jax_ir(pipeline_options=False)
+def add_two_old(x: jax.Array, z, y) -> jax.Array:
+    return x + y
+
+
 @jax.jit
 def add_two_plain(x: jax.Array, z, y) -> jax.Array:
     return x + y
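
Each operation in this file now comes in three variants that the tests compare pairwise: the new-pipeline `enzyme_jax_ir()` version, the old-pipeline `pipeline_options=False` version, and a plain `jax.jit` baseline. The comparison is always the same elementwise tolerance check; a condensed sketch (the helper name is hypothetical, the tests inline this expression):

```python
import jax.numpy as jnp

def assert_close(a, b, tol=1e-6):
    # Elementwise agreement within tol, as asserted throughout these tests.
    assert (jnp.abs(a - b) < tol).all()
```
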
@@ -30,6 +40,11 @@ def fwd(in0, in1, din0, din1):
     return jax.jvp(add_one, (in0, in1), (din0, din1))
 
 
+@jax.jit
+def fwd_old(in0, in1, din0, din1):
+    return jax.jvp(add_one_old, (in0, in1), (din0, din1))
+
+
 @jax.jit
 def fwd_plain(in0, in1, din0, din1):
     return jax.jvp(add_one_plain, (in0, in1), (din0, din1))
@@ -40,6 +55,11 @@ def fwd2(in0, in1, in2, din0, din1, din2):
     return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2))
 
 
+@jax.jit
+def fwd2_old(in0, in1, in2, din0, din1, din2):
+    return jax.jvp(add_two_old, (in0, in1, in2), (din0, din1, din2))
+
+
 @jax.jit
 def fwd2_plain(in0, in1, in2, din0, din1, din2):
     return jax.jvp(add_two_plain, (in0, in1, in2), (din0, din1, din2))
@@ -52,6 +72,13 @@ def rev(in0, in1, dout):
     return primals, grads
 
 
+@jax.jit
+def rev_old(in0, in1, dout):
+    primals, f_vjp = jax.vjp(add_one_old, in0, in1)
+    grads = f_vjp(dout)
+    return primals, grads
+
+
 @jax.jit
 def rev_plain(in0, in1, dout):
     primals, f_vjp = jax.vjp(add_one_plain, in0, in1)
@@ -66,6 +93,13 @@ def rev2(in0, in1, in2, dout):
     return primals, grads
 
 
+@jax.jit
+def rev2_old(in0, in1, in2, dout):
+    primals, f_vjp = jax.vjp(add_two_old, in0, in1, in2)
+    grads = f_vjp(dout)
+    return primals, grads
+
+
 @jax.jit
 def rev2_plain(in0, in1, in2, dout):
     primals, f_vjp = jax.vjp(add_two_plain, in0, in1, in2)
@@ -83,9 +117,13 @@ def setUp(self):
         self.din2 = jnp.array([1300.0, 1700.0, 1900.0])
 
     def test_add_one_primal(self):
-        ao = add_one(self.in0, self.in1)
         aop = add_one_plain(self.in0, self.in1)
+
+        ao = add_one(self.in0, self.in1)
+        ao_old = add_one_old(self.in0, self.in1)
+
         self.assertTrue((jnp.abs(ao - aop) < 1e-6).all())
+        self.assertTrue((jnp.abs(ao_old - aop) < 1e-6).all())
 
         # Benchmark.
         print(
@@ -94,6 +132,12 @@ def test_add_one_primal(self):
                 globals={"add_one": add_one, "in0": self.in0, "in1": self.in1},
             ).timeit()
         )
+        print(
+            timeit.Timer(
+                "add_one_old(in0, in1)",
+                globals={"add_one_old": add_one_old, "in0": self.in0, "in1": self.in1},
+            ).timeit()
+        )
         print(
             timeit.Timer(
                 "add_one_plain(in0, in1)",
@@ -106,17 +150,25 @@
         )
 
     def test_add_two_deadarg(self):
-        at = add_two(self.in0, self.in1, self.in2)
         atp = add_two_plain(self.in0, self.in1, self.in2)
+
+        at = add_two(self.in0, self.in1, self.in2)
+        ato = add_two_old(self.in0, self.in1, self.in2)
+
         self.assertTrue((jnp.abs(at - atp) < 1e-6).all())
+        self.assertTrue((jnp.abs(ato - atp) < 1e-6).all())
 
     def test_add_one_forward(self):
-        primals, tangents = fwd(self.in0, self.in1, self.din0, self.din1)
         primals_p, tangents_p = fwd_plain(self.in0, self.in1, self.din0, self.din1)
+
+        primals, tangents = fwd(self.in0, self.in1, self.din0, self.din1)
+        primals_old, tangents_old = fwd_old(self.in0, self.in1, self.din0, self.din1)
+
         self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
-        for t, t_p in zip(tangents, tangents_p):
+        self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
+        for t, t_old, t_p in zip(tangents, tangents_old, tangents_p):
             self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
+            self.assertTrue((jnp.abs(t_old - t_p) < 1e-6).all())
 
         print(
             timeit.Timer(
@@ -130,6 +182,18 @@ def test_add_one_forward(self):
                 },
             ).timeit()
         )
+        print(
+            timeit.Timer(
+                "fwd_old(in0, in1, din0, din1)",
+                globals={
+                    "fwd_old": fwd_old,
+                    "in0": self.in0,
+                    "in1": self.in1,
+                    "din0": self.din0,
+                    "din1": self.din1,
+                },
+            ).timeit()
+        )
         print(
             timeit.Timer(
                 "fwd_plain(in0, in1, din0, din1)",
@@ -144,39 +208,56 @@
         )
 
     def test_add_two_deadarg_forward(self):
+        primals_p, tangents_p = fwd2_plain(
+            self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
+        )
+
         primals, tangents = fwd2(
             self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
         )
-        primals_p, tangents_p = fwd2_plain(
+
+        primals_o, tangents_o = fwd2_old(
             self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
         )
 
-        print(primals, primals_p)
+        print(primals, primals_o, primals_p)
         self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
-        for i, (t, t_p) in enumerate(zip(tangents, tangents_p)):
-            print(i, t, t_p)
+        for i, (t, t_o, t_p) in enumerate(zip(tangents, tangents_o, tangents_p)):
+            print(i, t, t_o, t_p)
             self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
+            self.assertTrue((jnp.abs(t_o - t_p) < 1e-6).all())
 
     def test_add_one_reverse(self):
-        dout = jnp.array([500.0, 700.0, 110.0])
-        primals_p, grads_p = rev_plain(self.in0, self.in1, dout)
-
-        primals, grads = rev(self.in0, self.in1, dout)
-        # TODO enzyme will in place 0 the gradient inputs, which may not be expected
-        print(dout)
+        # TODO enzyme will in place 0 the gradient inputs, which may not be expected
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_p, grads_p = rev_plain(self.in0, self.in1, dout)
+        primals, grads = rev(self.in0, self.in1, dout)
+
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_old, grads_old = rev_old(self.in0, self.in1, dout)
 
         self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
-        for i, (g, g_p) in enumerate(zip(grads, grads_p)):
-            print(i, g, g_p)
+        self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
+        for i, (g, g_old, g_p) in enumerate(zip(grads, grads_old, grads_p)):
+            print(i, g, g_old, g_p)
             self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
+            self.assertTrue((jnp.abs(g_old - g_p) < 1e-6).all())
 
         print(
             timeit.Timer(
                 "rev(in0, in1, dout)",
                 globals={"rev": rev, "in0": self.in0, "in1": self.in1, "dout": dout},
             ).timeit()
         )
+        print(
+            timeit.Timer(
+                "rev_old(in0, in1, dout)",
+                globals={"rev_old": rev_old, "in0": self.in0, "in1": self.in1, "dout": dout},
+            ).timeit()
+        )
         print(
             timeit.Timer(
                 "rev_plain(in0, in1, dout)",
@@ -191,28 +272,43 @@
 
     def test_add_two_deadarg_reverse(self):
-        dout = jnp.array([500.0, 700.0, 110.0])
-        primals, grads = rev2(self.in0, self.in1, self.in2, dout)
-        primals_p, grads_p = rev2_plain(self.in0, self.in1, self.in2, dout)
         # TODO enzyme will in place 0 the gradient inputs, which may not be expected
-        print(dout)
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_p, grads_p = rev2_plain(self.in0, self.in1, self.in2, dout)
+        primals, grads = rev2(self.in0, self.in1, self.in2, dout)
+
+        dout = jnp.array([500.0, 700.0, 110.0])
+        primals_old, grads_old = rev2_old(self.in0, self.in1, self.in2, dout)
 
         self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
-        for i, (g, g_p) in enumerate(zip(grads, grads_p)):
-            print(i, g, g_p)
+        self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
+        for i, (g, g_old, g_p) in enumerate(zip(grads, grads_old, grads_p)):
+            print(i, g, g_old, g_p)
             self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
+            self.assertTrue((jnp.abs(g_old - g_p) < 1e-6).all())
 
 
 @enzyme_jax_ir()
 def esum(x):
     return jnp.sum(x)
 
 
+@enzyme_jax_ir(pipeline_options=False)
+def esum_old(x):
+    return jnp.sum(x)
+
+
 @jax.jit
 def sumfwd(in0, din0):
     return jax.jvp(esum, (in0,), (din0,))
 
 
+@jax.jit
+def sumfwd_old(in0, din0):
+    return jax.jvp(esum_old, (in0,), (din0,))
+
+
 @jax.jit
 def sumrev_p(in0):
     primals, f_vjp = jax.vjp(jnp.sum, in0)
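
The TODO comments in the two reverse tests above explain why `dout` is constructed twice: the Enzyme-generated reverse pass may zero the incoming gradient buffer in place, so each `rev*` call gets a fresh seed rather than reusing one. A standalone sketch of that precaution (assuming the `enzyme_jax_ir` decorator and import path from above; the in-place behavior is the TODO's observation, not a guarantee):

```python
import jax
import jax.numpy as jnp
from enzyme_ad.jax import enzyme_jax_ir  # import path assumed

@enzyme_jax_ir(pipeline_options=False)
def f(x, y):
    return x + 1 + y

in0 = jnp.array([2.0, 3.0, 5.0])
in1 = jnp.array([7.0, 11.0, 13.0])
primals, f_vjp = jax.vjp(f, in0, in1)

dout = jnp.array([500.0, 700.0, 110.0])
grads = f_vjp(dout)  # per the TODO, dout may have been zeroed in place here

dout = jnp.array([500.0, 700.0, 110.0])  # rebuild the seed before reusing it
grads_again = f_vjp(dout)
```
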
@@ -227,6 +323,13 @@ def sumrev(in0):
     return primals, grads
 
 
+@jax.jit
+def sumrev_old(in0):
+    primals, f_vjp = jax.vjp(esum_old, in0)
+    grads = f_vjp(1.0)
+    return primals, grads
+
+
 class Sum(absltest.TestCase):
     def setUp(self):
         self.x = jnp.array(range(50), dtype=jnp.float32)
@@ -243,6 +346,12 @@ def test_forward(self):
         self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
         self.assertTrue(jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6)
 
+    def test_forward_old(self):
+        primals, tangents = sumfwd_old(self.x, self.dx)
+        print(primals, tangents)
+        self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
+        self.assertTrue(jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6)
+
     def test_reverse_p(self):
         primals, grads = sumrev_p(self.x)
         print(primals, grads)
@@ -253,19 +362,37 @@ def test_reverse(self):
         self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
         self.assertTrue((jnp.abs(grads[0] - 1) < 1e-6).all())
 
+    def test_reverse_old(self):
+        primals, grads = sumrev_old(self.x)
+        print(primals, grads)
+        self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
+        self.assertTrue((jnp.abs(grads[0] - 1) < 1e-6).all())
+
 
 @enzyme_jax_ir()
 def ecache(x):
     return x * x[0]
 
 
+@enzyme_jax_ir(pipeline_options=False)
+def ecache_old(x):
+    return x * x[0]
+
+
 @jax.jit
 def cacherev(in0, din0):
     primals, f_vjp = jax.vjp(ecache, in0)
     grads = f_vjp(din0)
     return grads
 
 
+@jax.jit
+def cacherev_old(in0, din0):
+    primals, f_vjp = jax.vjp(ecache_old, in0)
+    grads = f_vjp(din0)
+    return grads
+
+
 class Cache(absltest.TestCase):
     def test_reverse(self):
         dim = 288
@@ -279,6 +406,18 @@ def test_reverse(self):
         )
         self.assertTrue((jnp.abs(grads[0][1:]) < 1e-6).all())
 
+    def test_reverse_old(self):
+        dim = 288
+
+        x = jnp.array(range(dim), dtype=jnp.float32)
+        dx = jnp.array(range(dim), dtype=jnp.float32)
+
+        grads = cacherev_old(x, dx)
+        self.assertTrue(
+            jnp.abs(grads[0][0] - (dim - 1) * dim * (2 * (dim - 1) + 1) / 6) < 1e-6
+        )
+        self.assertTrue((jnp.abs(grads[0][1:]) < 1e-6).all())
+
 
 if __name__ == "__main__":
     absltest.main()
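
For reference, the expected values in the `Cache` tests follow from `ecache(x) = x * x[0]`: pulling back a seed `dx` gives `grad[j] = dx[j] * x[0]` plus, at `j == 0`, the extra term `sum_k dx[k] * x[k]`. With `x = dx = range(dim)` the first part vanishes since `x[0] == 0`, leaving only `grad[0] = 0^2 + 1^2 + ... + (dim-1)^2 = (dim - 1) * dim * (2 * (dim - 1) + 1) / 6`. A quick numeric check of that closed form in plain JAX (no Enzyme needed):

```python
import jax
import jax.numpy as jnp

dim = 288
x = jnp.arange(dim, dtype=jnp.float32)
dx = jnp.arange(dim, dtype=jnp.float32)

# vjp of f(v) = v * v[0] at x, pulled back through the seed dx.
_, f_vjp = jax.vjp(lambda v: v * v[0], x)
(grad,) = f_vjp(dx)

closed_form = (dim - 1) * dim * (2 * (dim - 1) + 1) / 6  # sum of squares below dim
assert abs(float(grad[0]) - closed_form) < 1e-3
assert (jnp.abs(grad[1:]) < 1e-6).all()
```
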