Skip to content

Commit

Permalink
Merge branch 'release/1.3.x' into bugs/1258-_Bug_Lasso_does_not_work_…
Browse files Browse the repository at this point in the history
…on_GPU
  • Loading branch information
ClaudiaComito committed Nov 22, 2023
2 parents 5b37e05 + ee39c63 commit 3d3e8e8
Showing 3 changed files with 4 additions and 13 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/pytorch-latest-main.yml
Original file line number Diff line number Diff line change
@@ -11,9 +11,8 @@ jobs:
runs-on: ubuntu-latest
if: ${{ github.repository }} == 'hemlholtz-analytics/heat'
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
token: ${{ secrets.GHACTIONS }}
ref: '${{ env.base_branch }}'
- name: Fetch PyTorch release version
run: |
3 changes: 1 addition & 2 deletions .github/workflows/pytorch-latest-release.yml
Original file line number Diff line number Diff line change
@@ -11,9 +11,8 @@ jobs:
runs-on: ubuntu-latest
if: ${{ github.repository }} == 'hemlholtz-analytics/heat'
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
token: ${{ secrets.GHACTIONS }}
ref: '${{ env.base_branch }}'
- name: Fetch PyTorch release version
run: |
11 changes: 2 additions & 9 deletions heat/core/linalg/tests/test_solver.py
Original file line number Diff line number Diff line change
@@ -65,15 +65,8 @@ def test_lanczos(self):
lanczos_B = V_out @ T_out @ V_inv
self.assertTrue(ht.allclose(lanczos_B, B))

# single precision tolerance
if (
int(torch.__version__.split(".")[0]) == 1
and int(torch.__version__.split(".")[1]) >= 13
or int(torch.__version__.split(".")[0]) > 1
):
tolerance = 1e-3
else:
tolerance = 1e-4
# single precision tolerance for torch.inv() is pretty bad
tolerance = 1e-3

# float32, pre_defined v0, split mismatch
A = ht.random.randn(n, n, dtype=ht.float32, split=0)

0 comments on commit 3d3e8e8

Please sign in to comment.