Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/jax wrapper autogalaxy fixes #156

Open
wants to merge 1 commit into
base: feature/jax_wrapper
Choose a base branch
from

Conversation

Jammy2211
Copy link
Owner

Quick fixes and debugging for getting JAX examples to run agian.

@Jammy2211 Jammy2211 changed the base branch from main to feature/jax_wrapper December 16, 2024 11:42
@Jammy2211
Copy link
Owner Author

I will document the fix here, as it will hopefully illustrate one example of the JAX implementaiton.

I ran the following script:

https://github.com/Jammy2211/autogalaxy_workspace_test/blob/master/jax_examples/task_2_simple_conversions/func_grad_manual.py

I received the following exception:

File "/mnt/c/Users/Jammy/Code/PyAutoJAX/autogalaxy_workspace_test/jax_examples/task_2_simple_conversions/func_grad_manual.py", line 52, in <module>
    dataset = ag.Imaging(
              ^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoArray/autoarray/dataset/imaging/dataset.py", line 137, in __init__
    psf = [Kernel2D.no](http://kernel2d.no/)_mask(
          ^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoArray/autoarray/structures/arrays/kernel_2d.py", line 86, in no_mask
    values = [Array2D.no](http://array2d.no/)_mask(
             ^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoArray/autoarray/structures/arrays/uniform_2d.py", line 612, in no_mask
    values = array_2d_util.convert_array(array=values)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoArray/autoarray/structures/arrays/array_2d_util.py", line 32, in convert_array
    array = jax.lax.cond(
            ^^^^^^^^^^^^^
  File "/mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoArray/autoarray/structures/arrays/array_2d_util.py", line 34, in <lambda>
    lambda _: np.asarray(array),
              ^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2683, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jammy/venvs/PyAuto311JAX/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2606, in array
    raise TypeError(f"Unexpected input type for array: {type(object)}")
TypeError: Unexpected input type for array: <class 'autoarray.structures.arrays.uniform_2d.Array2D'>
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The error TypeError: Unexpected input type for array: <class 'autoarray.structures.arrays.uniform_2d.Array2D'> typically means that a function call is not passing the the pure ndarray rerpresentaiton in the codde.

I hunted around and found the following line in autoarray/structures/arrays/kernel_2d.py:

        values = Array2D.no_mask(
            values=values,
            shape_native=shape_native,
            pixel_scales=pixel_scales,
            origin=origin,
        )

Here, values was the Array2D and not its ndarray rerpesentation used to make JAX work, so I updated the code to read:

        values = Array2D.no_mask(
            values=values.array,
            shape_native=shape_native,
            pixel_scales=pixel_scales,
            origin=origin,
        )

I also had to update the following:

self._array[:] = np.divide(self._array, np.sum(self._array))

To:

self._array = np.divide(self._array, np.sum(self._array))

This was more just standard JAX practise that you cant manipulate in place memory.

This illustrate is to help us explain in the meeting the .array API.

Copy link
Collaborator

@rhayes777 rhayes777 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you run some automated linter on this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants