
Consider switching to jax for the core functions #115

Open
kjappelbaum opened this issue Nov 23, 2020 · 2 comments

Comments

kjappelbaum (Collaborator) commented Nov 23, 2020

Feature description

We currently use numba to jit-compile some of the core/utils functions. We could replace this dependency with jax. In addition, we would get vectorization via vmap and potential parallelization via pmap, and on a machine with a GPU the matrix operations would run on the GPU automatically.
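A minimal sketch of what this would buy us (the function and array names here are illustrative, not from this repo): jit compiles the function with XLA, and vmap maps it over a batch axis without an explicit Python loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-pair helper; jit compiles it via XLA
# (it would run on the GPU automatically if one is available).
@jax.jit
def sq_dist(a, b):
    return jnp.sum((a - b) ** 2)

# vmap maps sq_dist over the leading axis of both arguments,
# replacing a Python-level for loop over rows.
batched_sq_dist = jax.vmap(sq_dist, in_axes=(0, 0))

a = jnp.arange(6.0).reshape(3, 2)
b = jnp.ones((3, 2))
out = batched_sq_dist(a, b)  # one squared distance per row
```

pmap has the same calling convention as vmap but splits the batch across devices, so the two can be swapped depending on the hardware.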

Implementation idea

  1. Replace `import numpy as np` with `import jax.numpy as np`, and replace the numba jit decorator with jax's jit compilation
  2. Look into vectorizing some of the for loops
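Both steps, sketched on a made-up helper (`row_norms` is not an actual function in this codebase; the numba original in the comment is equally hypothetical):

```python
import jax
import jax.numpy as jnp  # step 1: drop-in replacement for `import numpy as np`

# A numba version of such a helper might have looped over rows:
#   @numba.jit(nopython=True)
#   def row_norms(x):
#       out = np.empty(x.shape[0])
#       for i in range(x.shape[0]):
#           out[i] = np.sqrt(np.sum(x[i] ** 2))
#       return out

# jax version: jit-compile the whole function, and let vmap
# replace the explicit per-row loop (step 2).
@jax.jit
def row_norms(x):
    return jax.vmap(lambda row: jnp.sqrt(jnp.sum(row ** 2)))(x)

x = jnp.array([[3.0, 4.0], [6.0, 8.0]])
norms = row_norms(x)
```

Note that jit traces the function, so loops whose trip count depends on runtime values (rather than on array shapes) would need jax.lax control-flow primitives instead.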

Alternatives

  1. Stick to numba
  2. Use cython for the localized for loops (pareto_classify)
kjappelbaum (Collaborator, Author) commented:
I played with it a bit, and it is more effort than expected.

For example, we cannot simply swap the numpy import, because we use in-place operations in several places in the code. The jit decorator also threw a lot of errors.
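For reference, this is the in-place-operation problem: jax arrays are immutable, so numpy-style item assignment raises a TypeError, and each such site would have to be rewritten with jax's functional `.at[...].set(...)` update API.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    x[0] = 1.0  # numpy-style in-place write
except TypeError:
    pass  # jax arrays are immutable; item assignment raises TypeError

# Functional update instead: returns a new array, leaves x unchanged.
y = x.at[0].set(1.0)
```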

github-actions (Contributor) commented:

Stale issue message
