Skip to content

Commit

Permalink
Replace jax.linear_util -> jax.extend.linear_util to suppress `De…
Browse files Browse the repository at this point in the history
…precationWarning`s

Also suppressed a pytype warning about conditional `tree` import.

PiperOrigin-RevId: 571299183
  • Loading branch information
alimuldal authored and copybara-github committed Oct 6, 2023
1 parent 94d1362 commit 2dedb42
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ hk_py_library(
name = "dot",
srcs = ["dot.py"],
deps = [
":config",
":config", # build_cleaner: keep
":data_structures",
":module",
":utils",
# pip: jax
# pip: tree
# pip: jax:extend
],
)

Expand Down
9 changes: 5 additions & 4 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
import jax
import jax.core
from jax.experimental import pjit
from jax.extend import linear_util


# Import tree if available, but only throw error at runtime.
# Permits us to drop dm-tree from deps.
try:
import tree # pylint: disable=g-import-not-at-top
import tree # pylint: disable=g-import-not-at-top # pytype: disable=import-error
except ImportError:
tree = None

Expand Down Expand Up @@ -134,7 +135,7 @@ def to_graph(fun):
@functools.wraps(fun)
def wrapped_fun(*args):
"""See `fun`."""
f = jax.linear_util.wrap_init(fun)
f = linear_util.wrap_init(fun)
args_flat, in_tree = jax.tree_util.tree_flatten((args, {}))
flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree)
graph = Graph.create(title=name_or_str(fun))
Expand All @@ -160,7 +161,7 @@ def method_hook(mod: module.Module, method_name: str):
return wrapped_fun


@jax.linear_util.transformation
@linear_util.transformation
def _interpret_subtrace(main, *in_vals):
trace = DotTrace(main, jax.core.cur_sublevel())
in_tracers = [DotTracer(trace, val) for val in in_vals]
Expand Down Expand Up @@ -202,7 +203,7 @@ def process_primitive(self, primitive, tracers, params):
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = jax.linear_util.wrap_init(f)
fun = linear_util.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)

inputs = [t.val for t in tracers]
Expand Down

0 comments on commit 2dedb42

Please sign in to comment.