Skip to content

Commit

Permalink
Use numpy.testing to test arrays with different length
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Jan 23, 2024
1 parent aa833da commit 0d3fb95
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_deterministic_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import fuse
import tempfile
import numpy as np
from numpy.testing import assert_array_equal, assert_raises

class TestDeterministicSeed(unittest.TestCase):

Expand Down Expand Up @@ -41,7 +42,7 @@ def test_MicroPhysics_SameSeed(self):
output_0 = self.test_context_0.get_array(self.run_number_0, "microphysics_summary")
output_1 = self.test_context_1.get_array(self.run_number_0, "microphysics_summary")

self.assertTrue(np.all(output_0 == output_1))
assert_array_equal(output_0, output_1)

def test_MicroPhysics_DifferentSeed(self):
"""Test that a different run_number produce a different random seed and thus different output"""
Expand All @@ -52,7 +53,7 @@ def test_MicroPhysics_DifferentSeed(self):
output_0 = self.test_context_0.get_array(self.run_number_0, "microphysics_summary")
output_1 = self.test_context_1.get_array(self.run_number_1, "microphysics_summary")

self.assertFalse(np.all(output_0 == output_1))
assert_raises(AssertionError, assert_array_equal, output_0, output_1)

def test_FullChain_SameSeed(self):
"""Test that the same run_number and lineage produce the same random seed and thus the same output"""
Expand All @@ -63,7 +64,7 @@ def test_FullChain_SameSeed(self):
output_0 = self.test_context_0.get_array(self.run_number_0, "raw_records")
output_1 = self.test_context_1.get_array(self.run_number_0, "raw_records")

self.assertTrue(np.all(output_0 == output_1))
assert_array_equal(output_0, output_1)

def test_FullChain_DifferentSeed(self):
"""Test that a different run_number produce a different random seed and thus different output"""
Expand All @@ -74,7 +75,7 @@ def test_FullChain_DifferentSeed(self):
output_0 = self.test_context_0.get_array(self.run_number_0, "raw_records")
output_1 = self.test_context_1.get_array(self.run_number_1, "raw_records")

self.assertFalse(np.all(output_0 == output_1))
assert_raises(AssertionError, assert_array_equal, output_0, output_1)

if __name__ == '__main__':
unittest.main()

0 comments on commit 0d3fb95

Please sign in to comment.