Skip to content

Commit

Permalink
Merge pull request #130 from rasg-affiliates/fix_1d_bin
Browse files Browse the repository at this point in the history
Fix 1d bin
  • Loading branch information
steven-murray authored Aug 9, 2024
2 parents dba5c7d + 12bf5c5 commit 8d6819a
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testsuite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: [3.9, "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@master
with:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
"pyyaml",
"astropy>=5",
"methodtools",
"pyuvdata",
"pyuvdata>=3.0.0",
"cached_property",
"rich",
"attrs",
Expand Down
8 changes: 4 additions & 4 deletions src/py21cmsense/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def phase_past_zenith(
lsts = np.array([lsts])

if use_apparent:
app_ra, app_dec = uvutils.calc_app_coords(
zenith_coord.ra.to_value("rad"),
zenith_coord.dec.to_value("rad"),
app_ra, app_dec = uvutils.phasing.calc_app_coords(
lon_coord=zenith_coord.ra.to_value("rad"),
lat_coord=zenith_coord.dec.to_value("rad"),
time_array=obstimes.utc.jd,
telescope_loc=telescope_location,
)
Expand All @@ -103,7 +103,7 @@ def phase_past_zenith(
_lsts = np.tile(lsts, len(bls_enu))
uvws = np.repeat(bls_enu, len(lsts), axis=0)

out = uvutils.calc_uvw(
out = uvutils.phasing.calc_uvw(
app_ra=app_ra,
app_dec=app_dec,
lst_array=_lsts,
Expand Down
4 changes: 2 additions & 2 deletions src/py21cmsense/observatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def clone(self, **kwargs) -> Observatory:
def from_uvdata(cls, uvdata, beam: beam.PrimaryBeam, **kwargs) -> Observatory:
"""Instantiate an Observatory from a :class:`pyuvdata.UVData` object."""
return cls(
antpos=uvdata.antenna_positions,
antpos=uvdata.telescope.antenna_positions,
beam=beam,
latitude=uvdata.telescope_location_lat_lon_alt[0],
latitude=uvdata.telescope.location.lat,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion src/py21cmsense/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ def _average_sense_to_1d(
self, sense: dict[tp.Wavenumber, tp.Delta], k1d: tp.Wavenumber | None = None
) -> tp.Delta:
"""Bin 2D sensitivity down to 1D."""
sense1d_inv = np.zeros(len(self.k1d)) / un.mK**4
if k1d is None:
k1d = self.k1d
sense1d_inv = np.zeros(len(k1d)) / un.mK**4

for k_perp in tqdm.tqdm(
sense.keys(),
Expand Down
12 changes: 9 additions & 3 deletions tests/test_observatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,17 @@ def test_min_max_antpos(bm):

def test_from_uvdata(bm):
uv = pyuvdata.UVData()
uv.antenna_positions = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [40, 0, 40]]) * units.m
uv.telescope_location = [x.value for x in EarthLocation.from_geodetic(0, 0).to_geocentric()]
uv.telescope.antenna_positions = (
np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [40, 0, 40]]) * units.m
)
uv.telescope.location = [x.value for x in EarthLocation.from_geodetic(0, 0).to_geocentric()]
uv.telescope.antenna_positions = (
np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0], [40, 0, 40]]) * units.m
)
uv.telescope.location = EarthLocation.from_geodetic(0, 0)

a = Observatory.from_uvdata(uvdata=uv, beam=bm)
assert np.all(a.antpos == uv.antenna_positions)
assert np.all(a.antpos == uv.telescope.antenna_positions)


def test_different_antpos_loaders(tmp_path: Path):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def test_sensitivity_2d_grid(observation, caplog):
def test_sensitivity_1d_binned(observation):
ps = PowerSpectrum(observation=observation)
assert np.all(ps.calculate_sensitivity_1d() == ps.calculate_sensitivity_1d_binned(ps.k1d))
kbins = np.linspace(0.1, 0.5, 10) * littleh / units.Mpc
sense1d_sample = ps.calculate_sensitivity_1d_binned(k=kbins)
assert len(sense1d_sample) == len(kbins)


def test_plots(observation):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_uvw.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def test_calc_app_coords(lat, time_past_zenith):

ra = zenith_coord.ra.to_value("rad")
dec = zenith_coord.dec.to_value("rad")
app_ra, app_dec = uvutils.calc_app_coords(
ra, dec, time_array=obstime.utc.jd, telescope_loc=telescope_location
app_ra, app_dec = uvutils.phasing.calc_app_coords(
lon_coord=ra, lat_coord=dec, time_array=obstime.utc.jd, telescope_loc=telescope_location
)

assert np.isclose(app_ra, ra, atol=0.02) # give it 1 degree wiggle room.
Expand Down

0 comments on commit 8d6819a

Please sign in to comment.