Skip to content

Commit

Permalink
feat: add numba.jit support
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored and spflueger committed Feb 18, 2021
1 parent 81bf402 commit 6da0e0b
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
"ncalls",
"ndarray",
"noqa",
"numba",
"pandoc",
"phasespace",
"phsp",
Expand Down
1 change: 1 addition & 0 deletions reqs/3.6/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ nbstripout==0.3.9
nest-asyncio==1.5.1
nodeenv==1.5.0
notebook==6.2.0
numba==0.52.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
Expand Down
1 change: 1 addition & 0 deletions reqs/3.7/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ nbstripout==0.3.9
nest-asyncio==1.5.1
nodeenv==1.5.0
notebook==6.2.0
numba==0.52.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
Expand Down
1 change: 1 addition & 0 deletions reqs/3.8/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ nbstripout==0.3.9
nest-asyncio==1.5.1
nodeenv==1.5.0
notebook==6.2.0
numba==0.52.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ ignore_errors = True
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-numba.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-pandas.*]
Expand Down
11 changes: 11 additions & 0 deletions src/tensorwaves/physics/amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def jax_lambdify() -> Callable:
if isinstance(backend, str):
if backend == "jax":
return jax_lambdify()
if backend == "numba":
from numba import jit

return jit(
sympy.lambdify(
variables,
expression,
modules="numpy",
),
parallel=True,
)
if isinstance(backend, tuple):
if any("jax" in x.__name__ for x in backend):
return jax_lambdify()
Expand Down

0 comments on commit 6da0e0b

Please sign in to comment.