Skip to content

JAX-mode compiled function with unused RNGs keeps them in the fgraph inputs #1428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
lucianopaz opened this issue May 28, 2025 · 0 comments

Comments

@lucianopaz
Copy link
Member

Description

When we compile a function from a computational graph that initially included an RNG shared variable, and rewrite simplifications are able to remove that RNG from the FunctionGraph, the RNG still shows up in the fgraph.inputs list even though the shared variable is no longer in the fgraph.variables list. If the function is compiled using the JAX mode, this raises a ValueError when JAXLinker.fgraph_convert tries to replace the original RNGs with copies, as shown in the example below:

Minimal reproducible example

import pytensor
import pymc as pm


pytensor.function(
    [],
    pm.Normal.dist().shape,
    mode="JAX"
)

Raises

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1820     m = Maker(
   1821         inputs,
   1822         outputs,
   (...)   1830         trust_input=trust_input,
   1831     )
   1832     with config.change_flags(compute_test_value="off"):
-> 1833         fn = m.create(defaults)
   1834 finally:
   1835     if profile and fn:

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
   1714 start_import_time = pytensor.link.c.cmodule.import_time
   1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717     _fn, _i, _o = self.linker.make_thunk(
   1718         input_storage=input_storage_lists, storage_map=storage_map
   1719     )
   1721 end_linker = time.perf_counter()
   1723 linker_time = end_linker - start_linker

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/miniforge3/envs/pymc/lib/python3.12/site-packages/pytensor/link/jax/linker.py:64, in JAXLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
     59         old_inp_fgrap_index = fgraph.inputs.index(old_inp)
     60         fgraph.remove_input(
     61             old_inp_fgrap_index,
     62             reason="JAXLinker.fgraph_convert",
     63         )
---> 64         fgraph.inputs.remove(new_inp)
     65         fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
     67 return jax_funcify(
     68     fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
     69 )

ValueError: list.remove(x): x not in list
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant