Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
thanks ChatGPT o1
  • Loading branch information
jsdillon committed Jan 7, 2025
1 parent 5caac08 commit fe6902c
Showing 1 changed file with 138 additions and 20 deletions.
158 changes: 138 additions & 20 deletions hera_cal/tests/test_datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setup_method(self):
self.pols = ['xx', 'yy']
self.blpol = {}
self.lsts = np.array([1.1])
self.freqs = np.array([101.1])#the edge
self.freqs = np.array([101.1]) # the edge
for bl in self.antpairs:
self.blpol[bl] = {}
for pol in self.pols:
Expand All @@ -35,7 +35,7 @@ def setup_method(self):
for bl in self.antpairs:
self.both[bl + (pol,)] = 1j
self.bools = {}
for pol in self.pols:
for pol in self.pols:
for bl in self.antpairs:
self.bools[bl + (pol,)] = np.array([True])
self.blpolarr = {}
Expand Down Expand Up @@ -340,6 +340,124 @@ def keys(self):
dc = datacontainer.DataContainer(blpol)
assert dc.ants == blpol.ants

def test_deinterleave_basic(self):
    """Check that deinterleave(2) on ten integrations yields even- and odd-indexed halves."""
    from hera_cal.datacontainer import DataContainer
    import numpy as np

    # Each baseline holds a (10 times, 5 freqs) block of consecutive integers,
    # starting at the listed offset, so every row can be identified by value.
    first_value = {(0, 1, 'ee'): 0, (1, 2, 'ee'): 50}

    def fresh_block(start):
        # Rebuilt on demand so expectations never alias the container's arrays.
        return np.arange(start, start + 50).reshape(10, 5)

    dc = DataContainer({bl: fresh_block(start) for bl, start in first_value.items()})
    dc.times = np.arange(10)
    dc.lsts = np.arange(10) * 2 * np.pi / 10  # placeholder values

    dcs = dc.deinterleave(2)
    assert len(dcs) == 2, "Expected 2 DataContainers from deinterleave(2)."

    # 10 is divisible by 2, so no integration is left over: the outputs are
    # expected to carry times 0,2,4,6,8 and 1,3,5,7,9 respectively.
    for offset in (0, 1):
        np.testing.assert_array_equal(dcs[offset].times, list(range(offset, 10, 2)))

    # Output number `offset` should hold rows offset, offset+2, ... of the
    # original array for every baseline, i.e. a (5, 5) block.
    for bl, start in first_value.items():
        for offset in (0, 1):
            assert bl in dcs[offset]
            subset = dcs[offset][bl]
            assert subset.shape == (5, 5)
            np.testing.assert_array_equal(subset, fresh_block(start)[offset::2])

def test_deinterleave_uneven(self):
    """Check that trailing integrations are dropped so all outputs have equal length.

    Nine integrations are split four ways. Only the first eight fill complete
    cycles, so each output is expected to keep two integrations (i and i + 4)
    and integration 8 is expected to be discarded.
    """
    from hera_cal.datacontainer import DataContainer
    import numpy as np

    bl = (0, 1, 'ee')
    dc = DataContainer({bl: np.arange(45).reshape(9, 5)})
    dc.times = np.arange(9)  # 9 integrations
    dc.lsts = np.arange(9) * 2  # placeholder values

    dcs = dc.deinterleave(4)
    assert len(dcs) == 4

    n_used = 8  # largest multiple of 4 that does not exceed 9
    for i, subdc in enumerate(dcs):
        assert subdc[bl].shape == (2, 5)
        np.testing.assert_array_equal(subdc.times, list(range(i, n_used, 4)))

def test_deinterleave_missing_times(self):
    """Check that deinterleave raises ValueError while .times or .lsts is unset."""
    from hera_cal.datacontainer import DataContainer
    import numpy as np
    import pytest

    no_times_msg = 'Cannot deinterleave if self.times is not set'
    no_lsts_msg = 'Cannot deinterleave if self.lsts is not set'

    dc = DataContainer({(0, 1, 'ee'): np.arange(50).reshape(10, 5)})

    # Neither attribute assigned yet: the missing-times error is expected first.
    with pytest.raises(ValueError, match=no_times_msg):
        dc.deinterleave(2)

    # Once times is assigned, the complaint should move on to lsts.
    dc.times = np.arange(10)
    with pytest.raises(ValueError, match=no_lsts_msg):
        dc.deinterleave(2)

def test_deinterleave_tslice(self):
    """Check that tslice restricts the integrations before they are split."""
    from hera_cal.datacontainer import DataContainer
    import numpy as np

    bl = (0, 1, 'ee')
    dc = DataContainer({bl: np.arange(50).reshape(10, 5)})
    dc.times = np.arange(10)
    dc.lsts = np.arange(10) * 2.0

    # slice(2, 8) keeps integrations 2..7; splitting those in two is expected
    # to give times 2,4,6 in the first output and 3,5,7 in the second.
    dcs = dc.deinterleave(2, tslice=slice(2, 8))
    assert len(dcs) == 2

    expected_times = ([2, 4, 6], [3, 5, 7])
    for idx, wanted in enumerate(expected_times):
        np.testing.assert_array_equal(dcs[idx].times, wanted)

    # Three integrations by five frequency channels in every output.
    for subdc in dcs:
        assert subdc[bl].shape == (3, 5)


@pytest.mark.filterwarnings("ignore:The default for the `center` keyword has changed")
class TestDataContainerWithRealData:
Expand All @@ -348,10 +466,10 @@ def test_adder(self):
test_file = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA")
d, f = io.load_vis(test_file, pop_autos=True)
d2 = d + d
assert type(d2.freqs)==type(d.freqs)
assert type(d2.lsts)==type(d.lsts)
assert np.allclose(d2.freqs,d.freqs)
assert np.allclose(d2.lsts,d.lsts)
assert type(d2.freqs) is type(d.freqs)
assert type(d2.lsts) is type(d.lsts)
assert np.allclose(d2.freqs, d.freqs)
assert np.allclose(d2.lsts, d.lsts)
assert np.allclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] * 2)
# test exception
d2, f2 = io.load_vis(test_file, pop_autos=True)
Expand All @@ -361,18 +479,18 @@ def test_adder(self):
pytest.raises(ValueError, d.__add__, d2)
d2 = d + 1
assert np.isclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] + 1)
assert np.allclose(d2.freqs,d.freqs)
assert np.allclose(d2.lsts,d.lsts)
assert np.allclose(d2.freqs, d.freqs)
assert np.allclose(d2.lsts, d.lsts)

def test_sub(self):
test_file = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA")
d, f = io.load_vis(test_file, pop_autos=True)
d2 = d - d
assert np.allclose(d2[(24, 25, 'ee')][30, 30], 0.0)
assert type(d2.freqs)==type(d.freqs)
assert type(d2.lsts)==type(d.lsts)
assert np.allclose(d2.freqs,d.freqs)
assert np.allclose(d2.lsts,d.lsts)
assert type(d2.freqs) is type(d.freqs)
assert type(d2.lsts) is type(d.lsts)
assert np.allclose(d2.freqs, d.freqs)
assert np.allclose(d2.lsts, d.lsts)
# test exception
d2, f2 = io.load_vis(test_file, pop_autos=True)
d2[list(d2.keys())[0]] = d2[list(d2.keys())[0]][:, :10]
Expand All @@ -388,10 +506,10 @@ def test_mul(self):
f[(24, 25, 'ee')][:, 0] = False
f2 = f * f
assert not np.any(f2[(24, 25, 'ee')][0, 0])
assert type(f2.freqs)==type(f.freqs)
assert type(f2.lsts)==type(f.lsts)
assert np.allclose(f2.freqs,f.freqs)
assert np.allclose(f2.lsts,f.lsts)
assert type(f2.freqs) is type(f.freqs)
assert type(f2.lsts) is type(f.lsts)
assert np.allclose(f2.freqs, f.freqs)
assert np.allclose(f2.lsts, f.lsts)
# test exception
d2, f2 = io.load_vis(test_file, pop_autos=True)
d2[list(d2.keys())[0]] = d2[list(d2.keys())[0]][:, :10]
Expand All @@ -410,10 +528,10 @@ def test_div(self):
d, f = io.load_vis(test_file, pop_autos=True)
d2 = d / d
assert np.allclose(d2[(24, 25, 'ee')][30, 30], 1.0)
assert type(d2.freqs)==type(d.freqs)
assert type(d2.lsts)==type(d.lsts)
assert np.allclose(d2.freqs,d.freqs)
assert np.allclose(d2.lsts,d.lsts)
assert type(d2.freqs) is type(d.freqs)
assert type(d2.lsts) is type(d.lsts)
assert np.allclose(d2.freqs, d.freqs)
assert np.allclose(d2.lsts, d.lsts)
d2 = d / 2.0
assert np.allclose(d2[(24, 25, 'ee')][30, 30], d[(24, 25, 'ee')][30, 30] / 2.0)
d2 = d // d
Expand Down

0 comments on commit fe6902c

Please sign in to comment.