Skip to content

Commit

Permalink
Merge branch 'master' into current_verification
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis authored Jun 15, 2023
2 parents c6161cf + c4d61e3 commit c8257ab
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 28 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: [0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17, 0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24, 0.3.25, 0.4.1]
jax-version: [0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17, 0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24, 0.3.25, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11]
group: [1, 2]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
Expand All @@ -35,4 +36,4 @@ jobs:
run: |
pwd
lscpu
python -m pytest -m unit --durations=0 --mpl --maxfail=1
python -m pytest -m unit --durations=0 --mpl --maxfail=1 --splits 2 --group ${{ matrix.group }} --splitting-algorithm least_duration
4 changes: 2 additions & 2 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
matrix:
python-version: ['3.9']
group: [1, 2, 3, 4]
group: [1, 2, 3, 4, 5]

steps:
- uses: actions/checkout@v3
Expand All @@ -41,7 +41,7 @@ jobs:
run: |
pwd
lscpu
python -m pytest -v -m regression --durations=0 --cov-report xml:cov.xml --cov-config=setup.cfg --cov=desc/ --mpl --mpl-results-path=mpl_results.html --mpl-generate-summary=html --splits 4 --group ${{ matrix.group }} --splitting-algorithm least_duration --db ./prof.db
python -m pytest -v -m regression --durations=0 --cov-report xml:cov.xml --cov-config=setup.cfg --cov=desc/ --mpl --mpl-results-path=mpl_results.html --mpl-generate-summary=html --splits 5 --group ${{ matrix.group }} --splitting-algorithm least_duration --db ./prof.db
- name: save coverage file and plot comparison results
if: always()
uses: actions/upload-artifact@v3
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
matrix:
python-version: ['3.8', '3.10']
group: [1, 2, 3]
group: [1, 2, 3, 4, 5]

steps:
- uses: actions/checkout@v3
Expand All @@ -41,7 +41,7 @@ jobs:
run: |
pwd
lscpu
python -m pytest -v -m unit --durations=0 --cov-report xml:cov.xml --cov-config=setup.cfg --cov=desc/ --mpl --mpl-results-path=mpl_results.html --mpl-generate-summary=html --splits 3 --group ${{ matrix.group }} --splitting-algorithm least_duration --db ./prof.db
python -m pytest -v -m unit --durations=0 --cov-report xml:cov.xml --cov-config=setup.cfg --cov=desc/ --mpl --mpl-results-path=mpl_results.html --mpl-generate-summary=html --splits 5 --group ${{ matrix.group }} --splitting-algorithm least_duration --db ./prof.db
- name: save coverage file and plot comparison results
if: always()
uses: actions/upload-artifact@v3
Expand Down
83 changes: 64 additions & 19 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _set_derivatives(self):
self._jac_scaled = Derivative(self.compute_scaled, mode="fwd")
self._jac_unscaled = Derivative(self.compute_unscaled, mode="fwd")

def jit(self):
def jit(self): # noqa: C901
"""Apply JIT to compute methods, or re-apply after updating self."""
# can't loop here because del doesn't work on getattr
# main idea is that when jitting a method, jax replaces that method
Expand All @@ -152,39 +152,72 @@ def jit(self):
# CompiledFunction object, which will then leave the raw method in its place,
# and then jit the raw method with the new self

# doing str name type checking to avoid importing weird jax private stuff
# for proper isinstance check
if "CompiledFunction" in str(type(self.compute_scaled)):
self._use_jit = True

try:
del self.compute_scaled
except AttributeError:
pass
self.compute_scaled = jit(self.compute_scaled)
if "CompiledFunction" in str(type(self.compute_scaled_error)):

try:
del self.compute_scaled_error
except AttributeError:
pass
self.compute_scaled_error = jit(self.compute_scaled_error)
if "CompiledFunction" in str(type(self.compute_unscaled)):

try:
del self.compute_unscaled
except AttributeError:
pass
self.compute_unscaled = jit(self.compute_unscaled)
if "CompiledFunction" in str(type(self.compute_scalar)):

try:
del self.compute_scalar
except AttributeError:
pass
self.compute_scalar = jit(self.compute_scalar)
if "CompiledFunction" in str(type(self.jac_scaled)):

try:
del self.jac_scaled
except AttributeError:
pass
self.jac_scaled = jit(self.jac_scaled)
if "CompiledFunction" in str(type(self.jac_unscaled)):

try:
del self.jac_unscaled
except AttributeError:
pass
self.jac_unscaled = jit(self.jac_unscaled)
if "CompiledFunction" in str(type(self.hess)):

try:
del self.hess
except AttributeError:
pass
self.hess = jit(self.hess)
if "CompiledFunction" in str(type(self.grad)):

try:
del self.grad
except AttributeError:
pass
self.grad = jit(self.grad)
if "CompiledFunction" in str(type(self.jvp_scaled)):

try:
del self.jvp_scaled
except AttributeError:
pass
self.jvp_scaled = jit(self.jvp_scaled)
if "CompiledFunction" in str(type(self.jvp_unscaled)):

try:
del self.jvp_unscaled
except AttributeError:
pass
self.jvp_unscaled = jit(self.jvp_unscaled)

for obj in self._objectives:
if obj._use_jit:
obj.jit()

def build(self, eq, use_jit=None, verbose=1):
"""Build the objective.
Expand Down Expand Up @@ -742,20 +775,32 @@ def _set_derivatives(self):

def jit(self):
"""Apply JIT to compute methods, or re-apply after updating self."""
# doing str name type checking to avoid importing weird jax private stuff
# for proper isinstance check
if "CompiledFunction" in str(type(self.compute_scaled)):
self._use_jit = True

try:
del self.compute_scaled
except AttributeError:
pass
self.compute_scaled = jit(self.compute_scaled)
if "CompiledFunction" in str(type(self.compute_scaled_error)):

try:
del self.compute_scaled_error
except AttributeError:
pass
self.compute_scaled_error = jit(self.compute_scaled_error)
if "CompiledFunction" in str(type(self.compute_unscaled)):

try:
del self.compute_unscaled
except AttributeError:
pass
self.compute_unscaled = jit(self.compute_unscaled)
if "CompiledFunction" in str(type(self.compute_scalar)):

try:
del self.compute_scalar
except AttributeError:
pass
self.compute_scalar = jit(self.compute_scalar)

del self._derivatives
self._set_derivatives()
for mode, val in self._derivatives.items():
Expand Down
2 changes: 1 addition & 1 deletion devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- termcolor
- pip
- pip:
- jax[cpu] >= 0.2.11, <= 0.4.1
- jax[cpu] >= 0.3.2, <= 0.4.11
- nvgpu
# testing and benchmarking
- qsc
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
colorama
h5py >= 3.0.0
jax[cpu] >= 0.2.11, <= 0.4.1
jax[cpu] >= 0.3.2, <= 0.4.11
matplotlib >= 3.3.0, <= 3.6.0, != 3.4.3
mpmath >= 1.0.0
netcdf4 >= 1.5.4
Expand Down
2 changes: 1 addition & 1 deletion requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ dependencies:
- termcolor
- pip
- pip:
- jax[cpu] >= 0.2.11, <= 0.4.1
- jax[cpu] >= 0.3.2, <= 0.4.11
- nvgpu

0 comments on commit c8257ab

Please sign in to comment.