From fe6902c90a7a272640476a79003be32c3b83c13b Mon Sep 17 00:00:00 2001 From: Josh Dillon Date: Mon, 6 Jan 2025 11:22:44 -0800 Subject: [PATCH] add unit tests thanks ChatGPT o1 --- hera_cal/tests/test_datacontainer.py | 158 +++++++++++++++++++++++---- 1 file changed, 138 insertions(+), 20 deletions(-) diff --git a/hera_cal/tests/test_datacontainer.py b/hera_cal/tests/test_datacontainer.py index ae80ebd2b..660a6f7b0 100644 --- a/hera_cal/tests/test_datacontainer.py +++ b/hera_cal/tests/test_datacontainer.py @@ -20,7 +20,7 @@ def setup_method(self): self.pols = ['xx', 'yy'] self.blpol = {} self.lsts = np.array([1.1]) - self.freqs = np.array([101.1])#the edge + self.freqs = np.array([101.1]) # the edge for bl in self.antpairs: self.blpol[bl] = {} for pol in self.pols: @@ -35,7 +35,7 @@ def setup_method(self): for bl in self.antpairs: self.both[bl + (pol,)] = 1j self.bools = {} - for pol in self.pols: + for pol in self.pols: for bl in self.antpairs: self.bools[bl + (pol,)] = np.array([True]) self.blpolarr = {} @@ -340,6 +340,124 @@ def keys(self): dc = datacontainer.DataContainer(blpol) assert dc.ants == blpol.ants + def test_deinterleave_basic(self): + from hera_cal.datacontainer import DataContainer + import numpy as np + + # Create a small DataContainer with times and lsts set. + # Suppose we have 2 baselines, each with shape (10, 5): 10 times, 5 freq channels. + # We'll fill them with ascending integers for easy checking. + data_dict = { + (0, 1, 'ee'): np.arange(50).reshape(10, 5), + (1, 2, 'ee'): np.arange(50, 100).reshape(10, 5) + } + dc = DataContainer(data_dict) + dc.times = np.arange(10) + dc.lsts = np.arange(10) * 2 * np.pi / 10 # dummy + + # 1) Basic test: deinterleave into 2 + dcs = dc.deinterleave(2) + assert len(dcs) == 2, "Expected 2 DataContainers from deinterleave(2)." + + # Each output DC should have times=[0,2,4,6,8] and [1,3,5,7,9] respectively. + # Also drop any leftover times if not divisible (in this case 10 is divisible by 2). + np.testing.assert_array_equal(dcs[0].times, [0, 2, 4, 6, 8]) + np.testing.assert_array_equal(dcs[1].times, [1, 3, 5, 7, 9]) + + # Check the shapes and contents of data + # The original data for (0,1,'ee') was a 10x5 array of 0..49. + # After taking slice [::2], shape should be 5x5, and it should contain + # rows 0,2,4,6,8 from the original. + for i in range(2): + assert (0, 1, 'ee') in dcs[i] + dat = dcs[i][(0, 1, 'ee')] + assert dat.shape == (5, 5) + # Check actual numeric contents + # The original is 0..49 in row-major order, so row 0 in the original is [0,1,2,3,4], + # row 1 is [5,6,7,8,9], row 2 is [10,11,12,13,14], etc. + # For i=0, we expect rows [0,2,4,6,8], i.e. 0, 10, 20, 30, 40 + # For i=1, we expect rows [1,3,5,7,9], i.e. 5, 15, 25, 35, 45 + expected_rows = [row for row in range(i, 10, 2)] + expected_data = np.arange(50).reshape(10, 5)[expected_rows, :] + np.testing.assert_array_equal(dat, expected_data) + + # Check the same for the second baseline + for i in range(2): + assert (1, 2, 'ee') in dcs[i] + dat = dcs[i][(1, 2, 'ee')] + assert dat.shape == (5, 5) + expected_rows = [row for row in range(i, 10, 2)] + expected_data = np.arange(50, 100).reshape(10, 5)[expected_rows, :] + np.testing.assert_array_equal(dat, expected_data) + + def test_deinterleave_uneven(self): + """Test that leftover integrations at the end are dropped so each DC is the same length.""" + from hera_cal.datacontainer import DataContainer + import numpy as np + + data_dict = { + (0, 1, 'ee'): np.arange(45).reshape(9, 5), + } + dc = DataContainer(data_dict) + dc.times = np.arange(9) # 9 times + dc.lsts = np.arange(9) * 2 # dummy + + # deinterleave with ninterleaves=4 means each subset takes steps of size 4 + # but the length is truncated so each subset has the same shape. + # total times = 9, each subset would skip 3 frames after the last full cycle + # so effectively we drop the last leftover frame if it doesn't fill all 4 subsets. + dcs = dc.deinterleave(4) + # The largest multiple of 4 that fits into 9 is 8. So each subset should have shape [2,5]. + assert len(dcs) == 4 + for i, subdc in enumerate(dcs): + assert subdc[(0, 1, 'ee')].shape == (2, 5) + np.testing.assert_array_equal( + subdc.times, + [t for t in range(9) if (t - i) % 4 == 0 and t < 8] # i, i+4, but < 8 + ) + + def test_deinterleave_missing_times(self): + """Test that if .times or .lsts is missing, we get an error.""" + from hera_cal.datacontainer import DataContainer + import numpy as np + import pytest + + data_dict = { + (0, 1, 'ee'): np.arange(50).reshape(10, 5) + } + dc = DataContainer(data_dict) + + # times not set + with pytest.raises(ValueError, match='Cannot deinterleave if self.times is not set'): + dc.deinterleave(2) + + # times set but lsts not set + dc.times = np.arange(10) + with pytest.raises(ValueError, match='Cannot deinterleave if self.lsts is not set'): + dc.deinterleave(2) + + def test_deinterleave_tslice(self): + """Test providing a tslice that only picks some portion of times before splitting.""" + from hera_cal.datacontainer import DataContainer + import numpy as np + + data_dict = { + (0, 1, 'ee'): np.arange(50).reshape(10, 5) + } + dc = DataContainer(data_dict) + dc.times = np.arange(10) + dc.lsts = np.arange(10) * 2.0 + + # Suppose we only want the middle times [2..7] before splitting, + # then deinterleave into n=2. That means each subset gets times [2,4,6] and [3,5,7]. + dcs = dc.deinterleave(2, tslice=slice(2, 8)) + assert len(dcs) == 2 + np.testing.assert_array_equal(dcs[0].times, [2, 4, 6]) + np.testing.assert_array_equal(dcs[1].times, [3, 5, 7]) + for subdc in dcs: + # each subset's data shape is 3 (times) x 5 (freq) + assert subdc[(0, 1, 'ee')].shape == (3, 5) + @pytest.mark.filterwarnings("ignore:The default for the `center` keyword has changed") class TestDataContainerWithRealData: @@ -348,10 +466,10 @@ def test_adder(self): test_file = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") d, f = io.load_vis(test_file, pop_autos=True) d2 = d + d - assert type(d2.freqs)==type(d.freqs) - assert type(d2.lsts)==type(d.lsts) - assert np.allclose(d2.freqs,d.freqs) - assert np.allclose(d2.lsts,d.lsts) + assert type(d2.freqs) is type(d.freqs) + assert type(d2.lsts) is type(d.lsts) + assert np.allclose(d2.freqs, d.freqs) + assert np.allclose(d2.lsts, d.lsts) assert np.allclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] * 2) # test exception d2, f2 = io.load_vis(test_file, pop_autos=True) @@ -361,18 +479,18 @@ def test_adder(self): pytest.raises(ValueError, d.__add__, d2) d2 = d + 1 assert np.isclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] + 1) - assert np.allclose(d2.freqs,d.freqs) - assert np.allclose(d2.lsts,d.lsts) + assert np.allclose(d2.freqs, d.freqs) + assert np.allclose(d2.lsts, d.lsts) def test_sub(self): test_file = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") d, f = io.load_vis(test_file, pop_autos=True) d2 = d - d assert np.allclose(d2[(24, 25, 'ee')][30, 30], 0.0) - assert type(d2.freqs)==type(d.freqs) - assert type(d2.lsts)==type(d.lsts) - assert np.allclose(d2.freqs,d.freqs) - assert np.allclose(d2.lsts,d.lsts) + assert type(d2.freqs) is type(d.freqs) + assert type(d2.lsts) is type(d.lsts) + assert np.allclose(d2.freqs, d.freqs) + assert np.allclose(d2.lsts, d.lsts) # test exception d2, f2 = io.load_vis(test_file, pop_autos=True) d2[list(d2.keys())[0]] = d2[list(d2.keys())[0]][:, :10] @@ -388,10 +506,10 @@ def test_mul(self): f[(24, 25, 'ee')][:, 0] = False f2 = f * f assert not np.any(f2[(24, 25, 'ee')][0, 0]) - assert type(f2.freqs)==type(f.freqs) - assert type(f2.lsts)==type(f.lsts) - assert np.allclose(f2.freqs,f.freqs) - assert np.allclose(f2.lsts,f.lsts) + assert type(f2.freqs) is type(f.freqs) + assert type(f2.lsts) is type(f.lsts) + assert np.allclose(f2.freqs, f.freqs) + assert np.allclose(f2.lsts, f.lsts) # test exception d2, f2 = io.load_vis(test_file, pop_autos=True) d2[list(d2.keys())[0]] = d2[list(d2.keys())[0]][:, :10] @@ -410,10 +528,10 @@ def test_div(self): d, f = io.load_vis(test_file, pop_autos=True) d2 = d / d assert np.allclose(d2[(24, 25, 'ee')][30, 30], 1.0) - assert type(d2.freqs)==type(d.freqs) - assert type(d2.lsts)==type(d.lsts) - assert np.allclose(d2.freqs,d.freqs) - assert np.allclose(d2.lsts,d.lsts) + assert type(d2.freqs) is type(d.freqs) + assert type(d2.lsts) is type(d.lsts) + assert np.allclose(d2.freqs, d.freqs) + assert np.allclose(d2.lsts, d.lsts) d2 = d / 2.0 assert np.allclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] / 2.0) d2 = d // d