Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 24, 2025
1 parent 8045585 commit 1f0527b
Showing 1 changed file with 154 additions and 69 deletions.
223 changes: 154 additions & 69 deletions docs/tutorials/SKA_forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import py21cmsense as p21sense\n",
"from astropy import units as un\n",
"from astropy.cosmology import Planck15\n",
"from matplotlib import rcParams, colors\n",
"from matplotlib import colors, rcParams\n",
"from scipy.interpolate import RegularGridInterpolator\n",
"\n",
"import py21cmsense as p21sense\n",
"from py21cmsense.observatory import Observatory\n",
"from py21cmsense.sensitivity import PowerSpectrum\n",
"from py21cmsense.theory import EOS2021, TheoryModel\n",
"from scipy.interpolate import RegularGridInterpolator\n",
"\n",
"rcParams.update({\"font.size\": 20})\n",
"\n",
Expand All @@ -65,14 +66,25 @@
"outputs": [],
"source": [
"def freq2z(f):\n",
" return 1420.4 / f - 1.\n",
" return 1420.4 / f - 1.0\n",
"\n",
"\n",
"def z2freq(z):\n",
" return 1420.4 / (z + 1.)\n",
" \n",
"def get_senses(observation, freq_bands, kperp_edges_hMpc, kpar_edges_hMpc, calc_2D=False, foreground_model=\"moderate\",theory_model=EOS2021,**kwargs):\n",
" return 1420.4 / (z + 1.0)\n",
"\n",
"\n",
"def get_senses(\n",
" observation,\n",
" freq_bands,\n",
" kperp_edges_hMpc,\n",
" kpar_edges_hMpc,\n",
" calc_2D=False,\n",
" foreground_model=\"moderate\",\n",
" theory_model=EOS2021,\n",
" **kwargs,\n",
"):\n",
" h = observation.cosmo.H0.value / 100.0\n",
" redshifts = kwargs['redshifts']\n",
" redshifts = kwargs[\"redshifts\"]\n",
" mock_senses = {}\n",
" mock_senses[\"kperp_edges_hMpc\"] = kperp_edges_hMpc\n",
" mock_senses[\"kpar_edges_hMpc\"] = kpar_edges_hMpc\n",
Expand All @@ -84,9 +96,9 @@
" this_z = {}\n",
" band_name = str(np.round(band, 1)) + \" MHz\"\n",
"\n",
" sense = PowerSpectrum(observation=observation, \n",
" foreground_model=foreground_model,\n",
" theory_model=theory_model()).at_frequency(band * un.MHz)\n",
" sense = PowerSpectrum(\n",
" observation=observation, foreground_model=foreground_model, theory_model=theory_model()\n",
" ).at_frequency(band * un.MHz)\n",
" if calc_2D:\n",
" sense2d_sample = sense.calculate_sensitivity_2d_grid(\n",
" kperp_edges=kperp_edges_hMpc, kpar_edges=kpar_edges_hMpc, thermal=False, sample=True\n",
Expand All @@ -103,7 +115,9 @@
"\n",
" sense1d_sample = sense.calculate_sensitivity_1d(thermal=False, sample=True)\n",
" this_z[\"sample_1D_mK2\"] = sense1d_sample.value\n",
" this_z[\"theory_1D_mK2\"] = sense.theory_model.delta_squared(zval, sense.k1d.value * mock_senses[\"h\"]).value\n",
" this_z[\"theory_1D_mK2\"] = sense.theory_model.delta_squared(\n",
" zval, sense.k1d.value * mock_senses[\"h\"]\n",
" ).value\n",
"\n",
" this_z[\"k_1D_Mpc\"] = sense.k1d.value * mock_senses[\"h\"]\n",
" sense1d_thermal = sense.calculate_sensitivity_1d(thermal=True, sample=False)\n",
Expand Down Expand Up @@ -217,10 +231,24 @@
"# aa4.baselines_metres has shape (Nants, Nants, 3): it gives us the basline vectors between all pairs of antennas in x,y,z.\n",
"# We ignore the z coordinate and plot the x and y baselines for all antenna pairs.\n",
"fig, ax = plt.subplots(1, 2, figsize=(14, 6), sharey=True, gridspec_kw={\"wspace\": 0.1})\n",
"im = ax[0].hexbin(aaast.baselines_metres[:,:,0].ravel(), aaast.baselines_metres[:,:,1].ravel(), label=\"AA*\",gridsize=20, bins='log', vmin = 100)\n",
"plt.colorbar(im, ax=ax[0],label = 'Number of baselines')\n",
"im = ax[1].hexbin(aa4.baselines_metres[:,:,0].ravel(), aa4.baselines_metres[:,:,1].ravel(), label=\"AA4\",gridsize=20, bins='log', vmin = 100)\n",
"plt.colorbar(im, ax=ax[1],label = 'Number of baselines')\n",
"im = ax[0].hexbin(\n",
" aaast.baselines_metres[:, :, 0].ravel(),\n",
" aaast.baselines_metres[:, :, 1].ravel(),\n",
" label=\"AA*\",\n",
" gridsize=20,\n",
" bins=\"log\",\n",
" vmin=100,\n",
")\n",
"plt.colorbar(im, ax=ax[0], label=\"Number of baselines\")\n",
"im = ax[1].hexbin(\n",
" aa4.baselines_metres[:, :, 0].ravel(),\n",
" aa4.baselines_metres[:, :, 1].ravel(),\n",
" label=\"AA4\",\n",
" gridsize=20,\n",
" bins=\"log\",\n",
" vmin=100,\n",
")\n",
"plt.colorbar(im, ax=ax[1], label=\"Number of baselines\")\n",
"ax[0].set_xlabel(\"X baseline [m]\")\n",
"ax[1].set_xlabel(\"X baseline [m]\")\n",
"ax[0].set_ylabel(\"Y baseline [m]\")\n",
Expand Down Expand Up @@ -254,17 +282,39 @@
],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(14, 6), sharey=True, gridspec_kw={\"wspace\": 0.1})\n",
"m = np.logical_and(abs(aaast.baselines_metres[:,:,0].ravel().value) < 2000, abs(aaast.baselines_metres[:,:,1].ravel().value) < 2000)\n",
"im = ax[0].hexbin(aaast.baselines_metres[:,:,0].ravel()[m], aaast.baselines_metres[:,:,1].ravel()[m], label=\"AA*\",gridsize=20, bins='log', vmin = 100)\n",
"plt.colorbar(im, ax=ax[0],label = 'Number of baselines')\n",
"m = np.logical_and(abs(aa4.baselines_metres[:,:,0].ravel().value) < 2000, abs(aa4.baselines_metres[:,:,1].ravel().value) < 2000)\n",
"im = ax[1].hexbin(aa4.baselines_metres[:,:,0].ravel()[m], aa4.baselines_metres[:,:,1].ravel()[m], label=\"AA4\",gridsize=20, bins='log', vmin = 100)\n",
"plt.colorbar(im, ax=ax[1],label = 'Number of baselines')\n",
"m = np.logical_and(\n",
" abs(aaast.baselines_metres[:, :, 0].ravel().value) < 2000,\n",
" abs(aaast.baselines_metres[:, :, 1].ravel().value) < 2000,\n",
")\n",
"im = ax[0].hexbin(\n",
" aaast.baselines_metres[:, :, 0].ravel()[m],\n",
" aaast.baselines_metres[:, :, 1].ravel()[m],\n",
" label=\"AA*\",\n",
" gridsize=20,\n",
" bins=\"log\",\n",
" vmin=100,\n",
")\n",
"plt.colorbar(im, ax=ax[0], label=\"Number of baselines\")\n",
"m = np.logical_and(\n",
" abs(aa4.baselines_metres[:, :, 0].ravel().value) < 2000,\n",
" abs(aa4.baselines_metres[:, :, 1].ravel().value) < 2000,\n",
")\n",
"im = ax[1].hexbin(\n",
" aa4.baselines_metres[:, :, 0].ravel()[m],\n",
" aa4.baselines_metres[:, :, 1].ravel()[m],\n",
" label=\"AA4\",\n",
" gridsize=20,\n",
" bins=\"log\",\n",
" vmin=100,\n",
")\n",
"plt.colorbar(im, ax=ax[1], label=\"Number of baselines\")\n",
"ax[0].set_xlabel(\"X baseline [m]\")\n",
"ax[1].set_xlabel(\"X baseline [m]\")\n",
"ax[0].set_ylabel(\"Y baseline [m]\")\n",
"ax[0].text(0.83, 0.95, \"AA*\", transform=ax[0].transAxes, fontsize=20, verticalalignment=\"top\")\n",
"ax[1].text(0.8, 0.95, \"AA4\", transform=ax[1].transAxes, fontsize=20, verticalalignment=\"top\", color=\"w\")\n",
"ax[1].text(\n",
" 0.8, 0.95, \"AA4\", transform=ax[1].transAxes, fontsize=20, verticalalignment=\"top\", color=\"w\"\n",
")\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -502,7 +552,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs = obs.clone(observatory=aa4, lst_bin_size=3. * un.hour)"
"obs = obs.clone(observatory=aa4, lst_bin_size=3.0 * un.hour)"
]
},
{
Expand Down Expand Up @@ -559,9 +609,10 @@
"metadata": {},
"outputs": [],
"source": [
"obs = obs.clone(observatory=aa4, \n",
"lst_bin_size=aa4.observation_duration,# beam-crossing time\n",
") "
"obs = obs.clone(\n",
" observatory=aa4,\n",
" lst_bin_size=aa4.observation_duration, # beam-crossing time\n",
")"
]
},
{
Expand Down Expand Up @@ -589,8 +640,9 @@
"metadata": {},
"outputs": [],
"source": [
"obs = obs.clone(observatory=aaast,\n",
"lst_bin_size=aaast.observation_duration,# beam-crossing time\n",
"obs = obs.clone(\n",
" observatory=aaast,\n",
" lst_bin_size=aaast.observation_duration, # beam-crossing time\n",
")"
]
},
Expand Down Expand Up @@ -630,7 +682,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs = obs.clone(observatory=aa4,lst_bin_size=observation_params[\"time_per_day_hrs\"] * un.hour)"
"obs = obs.clone(observatory=aa4, lst_bin_size=observation_params[\"time_per_day_hrs\"] * un.hour)"
]
},
{
Expand Down Expand Up @@ -696,7 +748,15 @@
"outputs": [],
"source": [
"def compare_senses(\n",
" senses1, senses2, redshift, kperp_Mpc, kpar_Mpc, label1=\"AA*\", label2=\"AA4\", plot_1d=True, **kwargs\n",
" senses1,\n",
" senses2,\n",
" redshift,\n",
" kperp_Mpc,\n",
" kpar_Mpc,\n",
" label1=\"AA*\",\n",
" label2=\"AA4\",\n",
" plot_1d=True,\n",
" **kwargs,\n",
"):\n",
" # We assume both senses have the same redshifts / freq bands\n",
" if np.all(senses1[\"redshifts\"] != senses2[\"redshifts\"]):\n",
Expand All @@ -718,12 +778,8 @@
" )\n",
" vmin = np.min(\n",
" [\n",
" np.nanpercentile(\n",
" senses1[band_name][\"sample_and_thermal_2D_mK2\"][mask1].ravel(), 5\n",
" ),\n",
" np.nanpercentile(\n",
" senses2[band_name][\"sample_and_thermal_2D_mK2\"][mask2].ravel(), 5\n",
" ),\n",
" np.nanpercentile(senses1[band_name][\"sample_and_thermal_2D_mK2\"][mask1].ravel(), 5),\n",
" np.nanpercentile(senses2[band_name][\"sample_and_thermal_2D_mK2\"][mask2].ravel(), 5),\n",
" ]\n",
" )\n",
" vmax = np.nanmin(\n",
Expand All @@ -740,7 +796,7 @@
" kperp_Mpc,\n",
" kpar_Mpc,\n",
" senses1[band_name][\"sample_and_thermal_2D_mK2\"].T,\n",
" norm=colors.LogNorm(vmin=vmin, vmax=vmax)\n",
" norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
" )\n",
" ax[0].set_title(label1)\n",
" ax[0].loglog()\n",
Expand All @@ -749,7 +805,7 @@
" kperp_Mpc,\n",
" kpar_Mpc,\n",
" senses2[band_name][\"sample_and_thermal_2D_mK2\"].T,\n",
" norm=colors.LogNorm(vmin=vmin, vmax=vmax)\n",
" norm=colors.LogNorm(vmin=vmin, vmax=vmax),\n",
" )\n",
" ax[1].set_title(label2)\n",
" ax[1].loglog()\n",
Expand Down Expand Up @@ -785,14 +841,14 @@
" all_band_names = [str(np.round(z2freq(i), 1)) + \" MHz\" for i in senses[0][\"redshifts\"]]\n",
" closest_k = np.argmin(np.abs(senses[0][band_name][\"k_1D_Mpc\"] - k))\n",
" # plot sens vs z at fixed k\n",
" fig, ax = plt.subplots(2, 1, figsize=(12, 10), sharex=True,layout=\"constrained\", gridspec_kw={\"hspace\": 0.05})\n",
" fig, ax = plt.subplots(\n",
" 2, 1, figsize=(12, 10), sharex=True, layout=\"constrained\", gridspec_kw={\"hspace\": 0.05}\n",
" )\n",
" for i, sense in enumerate(senses):\n",
" sensitivity_at_k = np.array([\n",
" sense[band][\"sample_and_thermal_1D_mK2\"][closest_k] for band in all_band_names\n",
" ])\n",
" theory_at_k = np.array([\n",
" sense[band][\"theory_1D_mK2\"][closest_k] for band in all_band_names\n",
" ])\n",
" sensitivity_at_k = np.array(\n",
" [sense[band][\"sample_and_thermal_1D_mK2\"][closest_k] for band in all_band_names]\n",
" )\n",
" theory_at_k = np.array([sense[band][\"theory_1D_mK2\"][closest_k] for band in all_band_names])\n",
" m = np.isinf(sensitivity_at_k)\n",
" ax[1].plot(\n",
" sense[\"redshifts\"][~m],\n",
Expand Down Expand Up @@ -821,15 +877,15 @@
" plt.text(\n",
" xlims[0] * 1.2,\n",
" ylims[1] * 0.5,\n",
" \"k = \"\n",
" + str(np.round(senses[0][band_name][\"k_1D_Mpc\"][closest_k], 2))\n",
" + \" Mpc$^{-1}$\",\n",
" \"k = \" + str(np.round(senses[0][band_name][\"k_1D_Mpc\"][closest_k], 2)) + \" Mpc$^{-1}$\",\n",
" fontsize=20,\n",
" )\n",
" plt.show()\n",
"\n",
" # plot sens vs k at fixed z\n",
" fig, ax = plt.subplots(2, 1, figsize=(12, 10), sharex=True,layout=\"constrained\", gridspec_kw={\"hspace\": 0.05})\n",
" fig, ax = plt.subplots(\n",
" 2, 1, figsize=(12, 10), sharex=True, layout=\"constrained\", gridspec_kw={\"hspace\": 0.05}\n",
" )\n",
" for i, sense in enumerate(senses):\n",
" sensitivity_at_z = sense[band_name][\"sample_and_thermal_1D_mK2\"]\n",
" m = np.isinf(sensitivity_at_z)\n",
Expand Down Expand Up @@ -927,7 +983,7 @@
" \"Deep + optimistic FG - AA4\",\n",
" ],\n",
" colors=[\"r\", \"r\", \"b\", \"b\", \"k\", \"k\", \"lime\", \"lime\"],\n",
" lss=[\"-\", \"--\", \"-\", \"--\", \"-\", \"--\",'-','--'],\n",
" lss=[\"-\", \"--\", \"-\", \"--\", \"-\", \"--\", \"-\", \"--\"],\n",
")"
]
},
Expand All @@ -954,11 +1010,13 @@
"outputs": [],
"source": [
"# Suppose I have a power-law 21-cm PS defined over some k and redshifts\n",
"mock={}\n",
"mock['k'] = np.logspace(-2, 1.5, 100)\n",
"mock['redshifts'] = np.linspace(6, 30, 100)\n",
"z,k = np.meshgrid(mock['redshifts'], mock['k'])\n",
"mock['PS'] = ((1 + z)**2 * (k / 100.0)**-2) << un.mK**2 # can be replaced with values read from a file."
"mock = {}\n",
"mock[\"k\"] = np.logspace(-2, 1.5, 100)\n",
"mock[\"redshifts\"] = np.linspace(6, 30, 100)\n",
"z, k = np.meshgrid(mock[\"redshifts\"], mock[\"k\"])\n",
"mock[\"PS\"] = (\n",
" (1 + z) ** 2 * (k / 100.0) ** -2\n",
") << un.mK**2 # can be replaced with values read from a file."
]
},
{
Expand All @@ -968,13 +1026,20 @@
"outputs": [],
"source": [
"import warnings\n",
"\n",
"\n",
"class MyPSmodel(TheoryModel):\n",
" \"\"\"Base class for theory models that are defined by a spline over (z,k).\"\"\"\n",
"\n",
" use_littleh = False\n",
"\n",
" def __init__(self):\n",
" self.k = mock['k']\n",
" self.z = mock['redshifts']\n",
" self.spline = RegularGridInterpolator((mock['redshifts'], mock['k']), mock['PS'],bounds_error=False)\n",
" self.k = mock[\"k\"]\n",
" self.z = mock[\"redshifts\"]\n",
" self.spline = RegularGridInterpolator(\n",
" (mock[\"redshifts\"], mock[\"k\"]), mock[\"PS\"], bounds_error=False\n",
" )\n",
"\n",
" def delta_squared(self, z: float, k: np.ndarray) -> un.Quantity[un.mK**2]:\n",
" \"\"\"Compute Delta^2(k, z) for the theory model.\n",
"\n",
Expand Down Expand Up @@ -1048,11 +1113,21 @@
}
],
"source": [
"plt.plot(ska_aaast_senses4_myps['60.0 MHz']['k_1D_Mpc'], ska_aaast_senses4_myps['60.0 MHz']['theory_1D_mK2'], color = 'cyan', label = 'PL')\n",
"plt.plot(ska_aaast_senses4['60.0 MHz']['k_1D_Mpc'], ska_aaast_senses4['60.0 MHz']['theory_1D_mK2'], color = 'green', label = 'EOS2021')\n",
"plt.yscale('log')\n",
"plt.ylabel(r'$\\Delta^2_{21}$ [mK$^2$]')\n",
"plt.xlabel(r'$k$ [Mpc$^{-1}$]')\n",
"plt.plot(\n",
" ska_aaast_senses4_myps[\"60.0 MHz\"][\"k_1D_Mpc\"],\n",
" ska_aaast_senses4_myps[\"60.0 MHz\"][\"theory_1D_mK2\"],\n",
" color=\"cyan\",\n",
" label=\"PL\",\n",
")\n",
"plt.plot(\n",
" ska_aaast_senses4[\"60.0 MHz\"][\"k_1D_Mpc\"],\n",
" ska_aaast_senses4[\"60.0 MHz\"][\"theory_1D_mK2\"],\n",
" color=\"green\",\n",
" label=\"EOS2021\",\n",
")\n",
"plt.yscale(\"log\")\n",
"plt.ylabel(r\"$\\Delta^2_{21}$ [mK$^2$]\")\n",
"plt.xlabel(r\"$k$ [Mpc$^{-1}$]\")\n",
"plt.legend(frameon=False)\n",
"plt.show()"
]
Expand Down Expand Up @@ -1081,11 +1156,21 @@
}
],
"source": [
"plt.plot(ska_aaast_senses4['60.0 MHz']['k_1D_Mpc'], ska_aaast_senses4['60.0 MHz']['sample_and_thermal_1D_mK2'], color = 'green', label = 'EOS2021')\n",
"plt.plot(ska_aaast_senses4_myps['60.0 MHz']['k_1D_Mpc'], ska_aaast_senses4_myps['60.0 MHz']['sample_and_thermal_1D_mK2'], color = 'cyan', label = 'PL')\n",
"plt.yscale('log')\n",
"plt.ylabel(r'$\\Delta^2_{21}$ [mK$^2$]')\n",
"plt.xlabel(r'$k$ [Mpc$^{-1}$]')\n",
"plt.plot(\n",
" ska_aaast_senses4[\"60.0 MHz\"][\"k_1D_Mpc\"],\n",
" ska_aaast_senses4[\"60.0 MHz\"][\"sample_and_thermal_1D_mK2\"],\n",
" color=\"green\",\n",
" label=\"EOS2021\",\n",
")\n",
"plt.plot(\n",
" ska_aaast_senses4_myps[\"60.0 MHz\"][\"k_1D_Mpc\"],\n",
" ska_aaast_senses4_myps[\"60.0 MHz\"][\"sample_and_thermal_1D_mK2\"],\n",
" color=\"cyan\",\n",
" label=\"PL\",\n",
")\n",
"plt.yscale(\"log\")\n",
"plt.ylabel(r\"$\\Delta^2_{21}$ [mK$^2$]\")\n",
"plt.xlabel(r\"$k$ [Mpc$^{-1}$]\")\n",
"plt.legend(frameon=False)\n",
"plt.show()"
]
Expand Down

0 comments on commit 1f0527b

Please sign in to comment.