Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Dec 24, 2023
1 parent 11aac26 commit ae6fe21
Showing 1 changed file with 2 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def test_loss_gradient(art_warning, get_pytorch_detr):

print("expected_gradients1")
print(grads[0, 0, 10, :32])
print("expected_gradients2")
print(grads[1, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=4)

Expand Down Expand Up @@ -185,9 +187,6 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
]
)

print("expected_gradients2")
print(grads[1, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=4)

except ARTTestException as e:
Expand Down

0 comments on commit ae6fe21

Please sign in to comment.