Feature description
We use numba to jit some of the core/utils functions right now, but we could replace this dependency with JAX. Additionally, we would get vectorization with vmap and potential parallelization across devices with pmap. If a GPU is available, JAX will run the array operations on it.
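A minimal sketch of how jit and vmap fit together in JAX. The function sq_dist and its inputs are hypothetical and only illustrate the mechanics; they are not taken from core/utils.

```python
import jax
import jax.numpy as jnp


# Hypothetical example: squared Euclidean distance between two vectors.
@jax.jit
def sq_dist(a, b):
    diff = a - b
    return jnp.dot(diff, diff)


# vmap maps sq_dist over the leading axis of both inputs, giving a
# batched version without a Python for loop; jit compiles the result.
batched_sq_dist = jax.jit(jax.vmap(sq_dist))

xs = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
ys = jnp.zeros((3, 2))
print(batched_sq_dist(xs, ys))  # values 0.0, 5.0, 25.0

# Lists the devices JAX will use (CPU, or GPU with a CUDA-enabled jaxlib).
print(jax.devices())
```

With a CUDA-enabled jaxlib installed, the same code runs on the GPU without changes, since JAX places arrays on the default device.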
Implementation idea
Replace import numpy as np with import jax.numpy as np, and replace the numba jit decorator with JAX's jit compilation
Look into vectorizing some of the for loops
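To illustrate what vectorizing a for loop could look like, here is a hypothetical Pareto-dominance check written once as a nested NumPy loop (the style numba compiles) and once with JAX. It assumes every objective is minimized and is not the project's pareto_classify. It imports jax.numpy as jnp rather than np only so both versions can live side by side.

```python
import numpy as onp
import jax
import jax.numpy as jnp


# Loop version in plain NumPy: point i is dominated if some point j is
# <= in every objective and < in at least one (minimization assumed).
def is_dominated_loop(points):
    n = points.shape[0]
    dominated = onp.zeros(n, dtype=bool)
    for i in range(n):
        for j in range(n):
            if onp.all(points[j] <= points[i]) and onp.any(points[j] < points[i]):
                dominated[i] = True
                break
    return dominated


# JAX version: the inner loop becomes one comparison against all rows...
def _dominated_one(p, points):
    weakly_better = jnp.all(points <= p, axis=1)
    strictly_better = jnp.any(points < p, axis=1)
    return jnp.any(weakly_better & strictly_better)


# ...and vmap replaces the outer loop over i (points is shared, not mapped).
@jax.jit
def is_dominated(points):
    return jax.vmap(_dominated_one, in_axes=(0, None))(points, points)
```

The vmap version always compares every pair of points, so it does O(n^2) work and memory per call instead of breaking out early like the loop; that trade-off may matter for large populations.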
Alternatives
Stick to numba
Use cython for the localized for loops (pareto_classify)
I played around with it a bit, and it is more effort than expected.
For example, we cannot just change the numpy import, because we use in-place operations in several places in the code.
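JAX arrays are immutable, so item assignment such as a[1] = 5.0 raises a TypeError. A minimal sketch of how such in-place updates would translate to JAX's .at[...] update syntax, using toy arrays rather than code from this project:

```python
import numpy as onp
import jax.numpy as jnp

# NumPy: in-place item assignment mutates the array.
a = onp.zeros(4)
a[1] = 5.0
a[2:] += 1.0

# JAX: `b[1] = 5.0` would raise a TypeError because arrays are immutable.
# The .at[...] property returns a new, updated array instead.
b = jnp.zeros(4)
b = b.at[1].set(5.0)
b = b.at[2:].add(1.0)

print(a)  # values 0.0, 5.0, 1.0, 1.0
print(b)  # same values, produced functionally
```

Under jit, XLA can usually perform these updates in place, so the functional style does not necessarily mean an extra copy.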
Also, the jit decorator threw a lot of errors.
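We don't know exactly which errors the decorator threw here, but two common causes when moving from numba to jax.jit are Python control flow on traced values and shapes that depend on argument values. The functions below are hypothetical examples of each, with a jit-compatible rewrite:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Fails under jit: `if` needs a concrete Python bool, but x is a tracer.
# @jax.jit
# def clip_negative(x):
#     if x < 0:
#         return 0.0
#     return x


# Works: jnp.where keeps the branch inside the traced computation.
@jax.jit
def clip_negative(x):
    return jnp.where(x < 0, 0.0, x)


# Fails under jit: the slice length k would be a tracer, but array shapes
# must be known at compile time.
# @jax.jit
# def smallest_k(x, k):
#     return jnp.sort(x)[:k]


# Works: marking k as static makes it a compile-time constant;
# jit retraces and recompiles once per distinct value of k.
@partial(jax.jit, static_argnames="k")
def smallest_k(x, k):
    return jnp.sort(x)[:k]
```

Static arguments trade flexibility for recompilation, so they suit values like sizes or flags that take only a few distinct values.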