Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 23, 2025
1 parent 1b94f42 commit 442a4d5
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 50 deletions.
20 changes: 10 additions & 10 deletions src/py21cmsense/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from astropy import units as un
from astropy.coordinates import EarthLocation, SkyCoord
from astropy.time import Time
from pyuvdata import utils as uvutils
from lunarsky import MoonLocation
from lunarsky import SkyCoord as LunarSkyCoord
from lunarsky import Time as LTime
from pyuvdata import utils as uvutils


def between(xmin, xmax):
"""Return an attrs validation function that checks a number is within bounds."""
Expand Down Expand Up @@ -87,7 +88,7 @@ def phase_past_zenith(

# JD is arbitrary
jd = 2454600

if world == "earth":
tm = Time(jd, format="jd")

Expand All @@ -99,14 +100,14 @@ def phase_past_zenith(
location=telescope_location,
)
else:
tm = LTime(jd, format='jd')
tm = LTime(jd, format="jd")
phase_center_coord = LunarSkyCoord(
alt=90 * un.deg,
az=0 * un.deg,
obstime=tm,
frame="lunartopo",
location=telescope_location,
)
alt=90 * un.deg,
az=0 * un.deg,
obstime=tm,
frame="lunartopo",
location=telescope_location,
)

phase_center_coord = phase_center_coord.transform_to("icrs")

Expand All @@ -120,7 +121,6 @@ def phase_past_zenith(
)

phase_center_coord.obstime.location = telescope_location


obstimes = phase_center_coord.obstime + time_past_zenith
lsts = obstimes.sidereal_time("apparent", longitude=0.0).rad
Expand Down
20 changes: 8 additions & 12 deletions src/py21cmsense/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ class Observation:
)
track: tp.Time | None = attr.ib(
None,
validator=attr.validators.optional(
[tp.vld_physical_type("time")]
),
validator=attr.validators.optional([tp.vld_physical_type("time")]),
)
lst_bin_size: tp.Time = attr.ib(
validator=(tp.vld_physical_type("time")),
Expand Down Expand Up @@ -171,28 +169,26 @@ def __sethstate__(self, d: dict[str, Any]) -> None:

@time_per_day.validator
def _time_per_day_vld(self, att, val):
day_length = 24*un.hour if self.observatory.world == 'earth' else 655.2*un.hour
day_length = 24 * un.hour if self.observatory.world == "earth" else 655.2 * un.hour

if not 0*un.hour <= val <= day_length:
if not 0 * un.hour <= val <= day_length:
raise ValueError(f"time_per_day should be between 0 and {day_length}")

@track.validator
def _track_vld(self, att, val):
if val != None:
day_length = 24*un.hour if self.observatory.world == 'earth' else 655.2*un.hour
day_length = 24 * un.hour if self.observatory.world == "earth" else 655.2 * un.hour

if not 0*un.hour <= val <= day_length:
if not 0 * un.hour <= val <= day_length:
raise ValueError(f"track should be between 0 and {day_length}")


@lst_bin_size.validator
def _lst_bin_size_vld(self, att, val):
day_length = 24*un.hour if self.observatory.world == 'earth' else 655.2*un.hour
day_length = 24 * un.hour if self.observatory.world == "earth" else 655.2 * un.hour

if not 0*un.hour <= val <= day_length:
if not 0 * un.hour <= val <= day_length:
raise ValueError(f"lst_bin_size should be between 0 and {day_length}")


if val > self.time_per_day:
raise ValueError("lst_bin_size must be <= time_per_day")

Expand All @@ -203,7 +199,7 @@ def _integration_time_vld(self, att, val):

@time_per_day.default
def _time_per_day_default(self):
if self.observatory.world == 'earth':
if self.observatory.world == "earth":
return 6 * un.hour
else:
return 163.8 * un.hour
Expand Down
5 changes: 2 additions & 3 deletions src/py21cmsense/observatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ class Observatory:
default=0.0 * un.m, validator=(tp.vld_physical_type("length"), ut.nonnegative)
)
beam_crossing_time_incl_latitude: bool = attr.ib(default=True, converter=bool)
world: str = attr.ib(default = "earth", validator=vld.in_(["earth", "moon"])
)
world: str = attr.ib(default="earth", validator=vld.in_(["earth", "moon"]))

@_antpos.validator
def _antpos_validator(self, att, val):
Expand Down Expand Up @@ -364,7 +363,7 @@ def observation_duration(self) -> un.Quantity[un.day]:
if self.world == "earth":
return un.day * self.beam.fwhm / (2 * np.pi * un.rad * latfac)
else:
return 27.3 * un.day * self.beam.fwhm/(2 * np.pi * un.rad * latfac)
return 27.3 * un.day * self.beam.fwhm / (2 * np.pi * un.rad * latfac)

def get_redundant_baselines(
self,
Expand Down
15 changes: 5 additions & 10 deletions src/py21cmsense/theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,27 +197,22 @@ def delta_squared(self, z: float, k: np.ndarray) -> un.Quantity[un.mK**2]:

return self.spline(k) << un.mK**2


class FarViewModel(TheoryModel):
"""21cmFAST-based theory model explicitly for z=30, [Insert paper link here later]"""

use_littleh: bool = False

def __init__(self) -> None:
k_pth = (
Path(__file__).parent
/ "data/farview/kmag_ultimate.npy"
)
k_pth = Path(__file__).parent / "data/farview/kmag_ultimate.npy"

delta_pth = (
Path(__file__).parent
/ "data/farview/P_Tb_ultimate.npy"
)
#Should at some point reorganize the data so these steps aren't necessary
delta_pth = Path(__file__).parent / "data/farview/P_Tb_ultimate.npy"
# Should at some point reorganize the data so these steps aren't necessary
k_fixed = np.load(k_pth)
power = np.load(delta_pth)
k_fixed = k_fixed[~np.isnan(power)]
power = power[~np.isnan(power)]
delta = (k_fixed**3 * power) / (2*np.pi**2)
delta = (k_fixed**3 * power) / (2 * np.pi**2)

self.k = k_fixed
self.delta_squared_raw = delta
Expand Down
18 changes: 11 additions & 7 deletions tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@
def bm():
return GaussianBeam(150.0 * units.MHz, dish_size=14 * units.m)

@pytest.fixture(scope="module", params=["earth","moon"])

@pytest.fixture(scope="module", params=["earth", "moon"])
def wd(request):
return request.param


@pytest.fixture(scope="module")
def observatory(bm,wd):
def observatory(bm, wd):
return Observatory(
antpos=np.array([[0, 0, 0], [14, 0, 0], [28, 0, 0], [70, 0, 0], [0, 14, 0], [23, -45, 0]])
* units.m,
latitude=-32 * units.deg,
beam=bm,
world=wd
world=wd,
)


Expand Down Expand Up @@ -94,23 +96,25 @@ def test_from_yaml(observatory):
with pytest.raises(ValueError, match="yaml_file must be a string filepath"):
Observation.from_yaml(3)


def test_huge_time_per_day_size(observatory: Observatory, wd):
tpd = 25 * units.hour if wd=="earth" else 682.5 * units.hour
tpd = 25 * units.hour if wd == "earth" else 682.5 * units.hour
with pytest.raises(ValueError, match="time_per_day should be between 0 and"):
Observation(observatory=observatory, time_per_day=tpd)


def test_huge_track_size(observatory: Observatory, wd):
tck = 25 * units.hour if wd=="earth" else 682.5 * units.hour
tck = 25 * units.hour if wd == "earth" else 682.5 * units.hour
with pytest.raises(ValueError, match="track should be between 0 and"):
Observation(observatory=observatory, track=tck)


def test_huge_lst_bin_size(observatory: Observatory, wd):
lst = 23 * units.hour if wd=="earth" else 627.9 * units.hour
lst = 23 * units.hour if wd == "earth" else 627.9 * units.hour
with pytest.raises(ValueError, match="lst_bin_size must be <= time_per_day"):
Observation(observatory=observatory, lst_bin_size=lst)

lst2 = 25 * units.hour if wd=="earth" else 682.5 * units.hour
lst2 = 25 * units.hour if wd == "earth" else 682.5 * units.hour
with pytest.raises(ValueError, match="lst_bin_size should be between 0 and"):
Observation(observatory=observatory, lst_bin_size=lst2)

Expand Down
8 changes: 6 additions & 2 deletions tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@
def bm():
return GaussianBeam(150.0 * units.MHz, dish_size=14 * units.m)

@pytest.fixture(scope="module", params=["earth","moon"])

@pytest.fixture(scope="module", params=["earth", "moon"])
def wd(request):
return request.param


@pytest.fixture(scope="module")
def observatory(bm, wd):
return Observatory(
antpos=np.array([[0, 0, 0], [14, 0, 0], [28, 0, 0], [70, 0, 0]]) * units.m,
beam=bm,
world=wd
world=wd,
)


Expand Down Expand Up @@ -82,10 +84,12 @@ def test_sensitivity_optimistic(observation):
ps = PowerSpectrum(observation=observation, foreground_model="optimistic")
assert ps.horizon_limit(10.0) > ps.horizon_limit(5.0)


def test_sensitivity_foreground_free(observation):
ps = PowerSpectrum(observation=observation, foreground_model="foreground_free")
assert ps.horizon_limit(10.0) == 0


def test_infs_in_trms(observation):
# default dumb layout should have lots of infs..
assert np.any(np.isinf(observation.Trms))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from py21cmsense.theory import EOS2021, EOS2016Bright, EOS2016Faint, Legacy21cmFAST, FarViewModel
from py21cmsense.theory import EOS2021, EOS2016Bright, EOS2016Faint, FarViewModel, Legacy21cmFAST


def test_eos_extrapolation():
Expand Down Expand Up @@ -39,6 +39,7 @@ def test_eos_2016():

assert faint.delta_squared(9.1, 1.0) != bright.delta_squared(9.1, 1.0)


def test_FarView():
theory = FarViewModel()
assert theory.delta_squared(29.6, 1.0) == theory.delta_squared(30.4, 1.0)
Expand Down
14 changes: 9 additions & 5 deletions tests/test_uvw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_phase_at_zenith(lat, use_apparent):
time_past_zenith=0.0 * un.day,
bls_enu=bls_enu,
latitude=lat * un.rad,
world = 'earth',
world="earth",
use_apparent=use_apparent,
)

Expand All @@ -46,7 +46,7 @@ def test_phase_past_zenith(use_apparent):
time_past_zenith=0.2 * un.day,
bls_enu=bls_enu,
latitude=0 * un.rad,
world = 'earth',
world="earth",
use_apparent=use_apparent,
)
)
Expand All @@ -69,7 +69,9 @@ def test_phase_past_zenith_shape():
times = np.array([0, 0.1, 0, 0.1]) * un.day

# Almost rotated to the horizon.
uvws = phase_past_zenith(time_past_zenith=times, bls_enu=bls_enu, latitude=0 * un.rad, world='earth')
uvws = phase_past_zenith(
time_past_zenith=times, bls_enu=bls_enu, latitude=0 * un.rad, world="earth"
)

assert uvws.shape == (5, 4, 3)
assert np.allclose(uvws[0], uvws[2]) # Same baselines
Expand All @@ -89,12 +91,14 @@ def test_use_apparent(lat):
times = np.linspace(-1, 1, 3) * un.hour

# Almost rotated to the horizon.
uvws = phase_past_zenith(time_past_zenith=times, bls_enu=bls_enu, latitude=lat * un.rad, world='earth')
uvws = phase_past_zenith(
time_past_zenith=times, bls_enu=bls_enu, latitude=lat * un.rad, world="earth"
)
uvws0 = phase_past_zenith(
time_past_zenith=times,
bls_enu=bls_enu,
latitude=lat * un.rad,
world = 'earth',
world="earth",
use_apparent=True,
)

Expand Down

0 comments on commit 442a4d5

Please sign in to comment.