Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 19, 2024
1 parent 64befcf commit d6eda25
Show file tree
Hide file tree
Showing 18 changed files with 180 additions and 136 deletions.
79 changes: 43 additions & 36 deletions docs/tutorials/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from astropy import units as un\n",
"\n",
"%matplotlib inline\n",
"\n",
"from py21cmsense import GaussianBeam, Observatory, Observation, PowerSpectrum, hera"
"from py21cmsense import GaussianBeam, Observation, Observatory, PowerSpectrum, hera"
]
},
{
Expand Down Expand Up @@ -61,11 +62,11 @@
"outputs": [],
"source": [
"sensitivity = PowerSpectrum(\n",
" observation = Observation(\n",
" observatory = Observatory(\n",
" antpos = hera(hex_num=7, separation=14*un.m),\n",
" beam = GaussianBeam(frequency=135.0*un.MHz, dish_size=14*un.m),\n",
" latitude=38*un.deg\n",
" observation=Observation(\n",
" observatory=Observatory(\n",
" antpos=hera(hex_num=7, separation=14 * un.m),\n",
" beam=GaussianBeam(frequency=135.0 * un.MHz, dish_size=14 * un.m),\n",
" latitude=38 * un.deg,\n",
" )\n",
" )\n",
")"
Expand Down Expand Up @@ -151,9 +152,9 @@
"source": [
"plt.plot(sensitivity.k1d, power_std)\n",
"plt.xlabel(\"k [h/Mpc]\")\n",
"plt.ylabel(r'$\\delta \\Delta^2_{21}$')\n",
"plt.yscale('log')\n",
"plt.xscale('log')"
"plt.ylabel(r\"$\\delta \\Delta^2_{21}$\")\n",
"plt.yscale(\"log\")\n",
"plt.xscale(\"log\")"
]
},
{
Expand Down Expand Up @@ -211,14 +212,14 @@
}
],
"source": [
"plt.plot(sensitivity.k1d, power_std, label='Total')\n",
"plt.plot(sensitivity.k1d, power_std_thermal, label='Thermal')\n",
"plt.plot(sensitivity.k1d, power_std_sample, label='Sample')\n",
"plt.plot(sensitivity.k1d, power_std, label=\"Total\")\n",
"plt.plot(sensitivity.k1d, power_std_thermal, label=\"Thermal\")\n",
"plt.plot(sensitivity.k1d, power_std_sample, label=\"Sample\")\n",
"\n",
"plt.xlabel(\"k [h/Mpc]\")\n",
"plt.ylabel(r'$\\delta \\Delta^2_{21}$')\n",
"plt.yscale('log')\n",
"plt.xscale('log')\n",
"plt.ylabel(r\"$\\delta \\Delta^2_{21}$\")\n",
"plt.yscale(\"log\")\n",
"plt.xscale(\"log\")\n",
"plt.legend();"
]
},
Expand Down Expand Up @@ -412,7 +413,7 @@
}
],
"source": [
"beam.at(160*un.MHz).b_eff/beam.b_eff"
"beam.at(160 * un.MHz).b_eff / beam.b_eff"
]
},
{
Expand Down Expand Up @@ -476,8 +477,8 @@
}
],
"source": [
"plt.figure(figsize=(5,4.5))\n",
"plt.scatter(observatory.baselines_metres[:,:, 0], observatory.baselines_metres[:,:,1], alpha=0.1)\n",
"plt.figure(figsize=(5, 4.5))\n",
"plt.scatter(observatory.baselines_metres[:, :, 0], observatory.baselines_metres[:, :, 1], alpha=0.1)\n",
"plt.xlabel(\"Baseline Length [x, m]\")\n",
"plt.ylabel(\"Baseline Length [y, m]\");"
]
Expand Down Expand Up @@ -621,9 +622,9 @@
"baseline_group_counts = observatory.baseline_weights_from_groups(red_bl)\n",
"\n",
"\n",
"plt.figure(figsize=(7,5))\n",
"plt.scatter(baseline_group_coords[:,0], baseline_group_coords[:,1], c=baseline_group_counts)\n",
"cbar = plt.colorbar();\n",
"plt.figure(figsize=(7, 5))\n",
"plt.scatter(baseline_group_coords[:, 0], baseline_group_coords[:, 1], c=baseline_group_counts)\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Number of baselines in group\", fontsize=15)\n",
"plt.tight_layout();"
]
Expand Down Expand Up @@ -656,9 +657,7 @@
],
"source": [
"coherent_grid = observatory.grid_baselines(\n",
" coherent=True,\n",
" baselines=baseline_group_coords,\n",
" weights=baseline_group_counts\n",
" coherent=True, baselines=baseline_group_coords, weights=baseline_group_counts\n",
")"
]
},
Expand Down Expand Up @@ -693,7 +692,7 @@
}
],
"source": [
"plt.imshow(coherent_grid, extent=(observatory.ugrid().min(), observatory.ugrid().max())*2)\n",
"plt.imshow(coherent_grid, extent=(observatory.ugrid().min(), observatory.ugrid().max()) * 2)\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Effective # of Samples\")"
]
Expand Down Expand Up @@ -758,13 +757,13 @@
}
],
"source": [
"plt.figure(figsize=(7,5))\n",
"plt.figure(figsize=(7, 5))\n",
"x = [bl_group[0] for bl_group in observation.baseline_groups]\n",
"y = [bl_group[1] for bl_group in observation.baseline_groups]\n",
"c = [len(bls) for bls in observation.baseline_groups.values()]\n",
"\n",
"plt.scatter(x, y, c=c)\n",
"cbar = plt.colorbar();\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Number of baselines in group\", fontsize=15)\n",
"plt.tight_layout();"
]
Expand Down Expand Up @@ -800,7 +799,10 @@
}
],
"source": [
"plt.imshow(observation.total_integration_time.to(\"hour\").value, extent=(observation.ugrid.min(), observation.ugrid.max())*2)\n",
"plt.imshow(\n",
" observation.total_integration_time.to(\"hour\").value,\n",
" extent=(observation.ugrid.min(), observation.ugrid.max()) * 2,\n",
")\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Total Integration Time [hours]\")"
]
Expand Down Expand Up @@ -859,8 +861,10 @@
}
],
"source": [
"plt.imshow(observation_2.total_integration_time.to(\"hour\").value, \n",
" extent=(observation_2.ugrid.min(), observation_2.ugrid.max())*2)\n",
"plt.imshow(\n",
" observation_2.total_integration_time.to(\"hour\").value,\n",
" extent=(observation_2.ugrid.min(), observation_2.ugrid.max()) * 2,\n",
")\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Total Integration Time [hours]\")"
]
Expand Down Expand Up @@ -994,8 +998,8 @@
}
],
"source": [
"kperp = un.Quantity(np.linspace(0.01, 0.07, 15), 'littleh/Mpc')\n",
"kpar = un.Quantity( np.linspace(0.1, 2, 15), 'littleh/Mpc')\n",
"kperp = un.Quantity(np.linspace(0.01, 0.07, 15), \"littleh/Mpc\")\n",
"kpar = un.Quantity(np.linspace(0.1, 2, 15), \"littleh/Mpc\")\n",
"sense_gridded = sensitivity.calculate_sensitivity_2d_grid(kperp_edges=kperp, kpar_edges=kpar)"
]
},
Expand Down Expand Up @@ -1028,7 +1032,7 @@
}
],
"source": [
"plt.imshow(np.log10(sense_gridded.value.T), origin='lower')\n",
"plt.imshow(np.log10(sense_gridded.value.T), origin=\"lower\")\n",
"plt.colorbar()"
]
},
Expand Down Expand Up @@ -1059,8 +1063,11 @@
"metadata": {},
"outputs": [],
"source": [
"cable_refl_range = (un.Quantity(1.2, 'littleh/Mpc'), un.Quantity(1.4, 'littleh/Mpc'))\n",
"new_sense = sensitivity.clone(systematics_mask=lambda kperp, kpar: (np.abs(kpar) < cable_refl_range[0]) | (np.abs(kpar) > cable_refl_range[1]))"
"cable_refl_range = (un.Quantity(1.2, \"littleh/Mpc\"), un.Quantity(1.4, \"littleh/Mpc\"))\n",
"new_sense = sensitivity.clone(\n",
" systematics_mask=lambda kperp, kpar: (np.abs(kpar) < cable_refl_range[0])\n",
" | (np.abs(kpar) > cable_refl_range[1])\n",
")"
]
},
{
Expand Down
92 changes: 58 additions & 34 deletions docs/tutorials/reproducing_pober_2015.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@
"metadata": {},
"outputs": [],
"source": [
"import py21cmsense as p21c\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from astropy import constants\n",
"from astropy import units as un\n",
"from astropy.cosmology.units import littleh\n",
"from astropy.cosmology import Planck15\n",
"from astropy import constants\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
"from astropy.cosmology.units import littleh\n",
"\n",
"import py21cmsense as p21c"
]
},
{
Expand All @@ -78,7 +79,7 @@
"metadata": {},
"outputs": [],
"source": [
"memo_data = np.load(p21c.data.PATH / 'hera331drift_mod_0.135.npz')"
"memo_data = np.load(p21c.data.PATH / \"hera331drift_mod_0.135.npz\")"
]
},
{
Expand All @@ -98,8 +99,8 @@
}
],
"source": [
"plt.plot(memo_data['ks'], memo_data['T_errs'], label='Thermal Variance')\n",
"plt.plot(memo_data['ks'], memo_data['errs'], label='Sample+Thermal Variance')\n",
"plt.plot(memo_data[\"ks\"], memo_data[\"T_errs\"], label=\"Thermal Variance\")\n",
"plt.plot(memo_data[\"ks\"], memo_data[\"errs\"], label=\"Sample+Thermal Variance\")\n",
"plt.yscale(\"log\")\n",
"plt.title(\"Original Variance from Pober15\")\n",
"plt.ylabel(r\"$\\Delta^2_{21}$ [mK$^2$]\")\n",
Expand Down Expand Up @@ -129,7 +130,7 @@
"metadata": {},
"outputs": [],
"source": [
"hera_ants = p21c.antpos.hera(hex_num=11, row_separation=12.12*un.m)"
"hera_ants = p21c.antpos.hera(hex_num=11, row_separation=12.12 * un.m)"
]
},
{
Expand All @@ -156,7 +157,9 @@
"metadata": {},
"outputs": [],
"source": [
"beam = p21c.GaussianBeam(frequency=135*un.MHz, dish_size=7 * (constants.c / (150 * un.MHz)).to(\"m\"))"
"beam = p21c.GaussianBeam(\n",
" frequency=135 * un.MHz, dish_size=7 * (constants.c / (150 * un.MHz)).to(\"m\")\n",
")"
]
},
{
Expand Down Expand Up @@ -200,8 +203,8 @@
"hera = p21c.Observatory(\n",
" antpos=hera_ants,\n",
" beam=beam,\n",
" latitude=0.6707845*un.rad,\n",
" Trcv=100*un.K,\n",
" latitude=0.6707845 * un.rad,\n",
" Trcv=100 * un.K,\n",
" beam_crossing_time_incl_latitude=False,\n",
")"
]
Expand Down Expand Up @@ -246,17 +249,17 @@
"source": [
"obs = p21c.Observation(\n",
" observatory=hera,\n",
" tsky_amplitude=60*un.K,\n",
" tsky_ref_freq=300*un.MHz,\n",
" tsky_amplitude=60 * un.K,\n",
" tsky_ref_freq=300 * un.MHz,\n",
" spectral_index=2.6,\n",
" n_days=180,\n",
" time_per_day=6*un.hour,\n",
" bandwidth=8*un.MHz,\n",
" time_per_day=6 * un.hour,\n",
" bandwidth=8 * un.MHz,\n",
" n_channels=82,\n",
" integration_time=60*un.s,\n",
" lst_bin_size=beam.at(150*un.MHz).fwhm.value * 12/np.pi * 3600*un.s,\n",
" integration_time=60 * un.s,\n",
" lst_bin_size=beam.at(150 * un.MHz).fwhm.value * 12 / np.pi * 3600 * un.s,\n",
" use_approximate_cosmo=True,\n",
" cosmo=Planck15.clone(H0=70.0, Om0=0.266)\n",
" cosmo=Planck15.clone(H0=70.0, Om0=0.266),\n",
")"
]
},
Expand Down Expand Up @@ -328,8 +331,8 @@
"source": [
"sense_moderate = p21c.PowerSpectrum(\n",
" observation=obs,\n",
" foreground_model='moderate',\n",
" horizon_buffer=0.1 *littleh/un.Mpc,\n",
" foreground_model=\"moderate\",\n",
" horizon_buffer=0.1 * littleh / un.Mpc,\n",
" theory_model=p21c.theory.Legacy21cmFAST(),\n",
")"
]
Expand All @@ -351,7 +354,7 @@
],
"source": [
"sense1d = sense_moderate.calculate_sensitivity_1d(thermal=True, sample=True)\n",
"sense1d_t = sense_moderate.calculate_sensitivity_1d(thermal=True, sample=False)\n"
"sense1d_t = sense_moderate.calculate_sensitivity_1d(thermal=True, sample=False)"
]
},
{
Expand Down Expand Up @@ -388,20 +391,37 @@
}
],
"source": [
"fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 6), constrained_layout=True, gridspec_kw={\"height_ratios\": [0.65, 0.35]})\n",
"ax[0].scatter(sense_moderate.k1d, sense1d, label='Modern 21cmSense', marker='x', color='C0')\n",
"ax[0].scatter(memo_data['ks'], memo_data['errs'], label='memo', color='C0', lw=1, facecolor='none')\n",
"fig, ax = plt.subplots(\n",
" 2,\n",
" 1,\n",
" sharex=True,\n",
" figsize=(12, 6),\n",
" constrained_layout=True,\n",
" gridspec_kw={\"height_ratios\": [0.65, 0.35]},\n",
")\n",
"ax[0].scatter(sense_moderate.k1d, sense1d, label=\"Modern 21cmSense\", marker=\"x\", color=\"C0\")\n",
"ax[0].scatter(memo_data[\"ks\"], memo_data[\"errs\"], label=\"memo\", color=\"C0\", lw=1, facecolor=\"none\")\n",
"\n",
"ax[0].scatter(sense_moderate.k1d, sense1d_t, label='Modern 21cmSense (Thermal Only)', marker='x', color='C1')\n",
"ax[0].scatter(memo_data['ks'], memo_data['T_errs'], label='memo (thermal Only)', color='C1', facecolor='none')\n",
"ax[0].scatter(\n",
" sense_moderate.k1d, sense1d_t, label=\"Modern 21cmSense (Thermal Only)\", marker=\"x\", color=\"C1\"\n",
")\n",
"ax[0].scatter(\n",
" memo_data[\"ks\"], memo_data[\"T_errs\"], label=\"memo (thermal Only)\", color=\"C1\", facecolor=\"none\"\n",
")\n",
"\n",
"ax[0].set_yscale(\"log\")\n",
"ax[0].set_xscale('log')\n",
"ax[0].set_xscale(\"log\")\n",
"ax[0].set_ylabel(r\"$\\Delta^2$ [mK$^2$]\")\n",
"ax[1].set_ylabel(\"Fractional Difference (%)\")\n",
"ax[1].plot(memo_data['ks'], 100*(sense1d[:len(memo_data['ks'])].value- memo_data['errs'])/memo_data['errs'])\n",
"ax[1].plot(memo_data['ks'], 100*(sense1d_t[:len(memo_data['ks'])].value- memo_data['T_errs'])/memo_data['T_errs'])\n",
"ax[0].legend();\n"
"ax[1].plot(\n",
" memo_data[\"ks\"],\n",
" 100 * (sense1d[: len(memo_data[\"ks\"])].value - memo_data[\"errs\"]) / memo_data[\"errs\"],\n",
")\n",
"ax[1].plot(\n",
" memo_data[\"ks\"],\n",
" 100 * (sense1d_t[: len(memo_data[\"ks\"])].value - memo_data[\"T_errs\"]) / memo_data[\"T_errs\"],\n",
")\n",
"ax[0].legend();"
]
},
{
Expand All @@ -426,10 +446,14 @@
"metadata": {},
"outputs": [],
"source": [
"mask = ~np.isinf(memo_data['errs'])\n",
"mask = ~np.isinf(memo_data[\"errs\"])\n",
"\n",
"assert np.allclose(sense1d[:len(memo_data['ks'])][mask][:-1].value, memo_data['errs'][mask][:-1], rtol=1e-1)\n",
"assert np.allclose(sense1d_t[:len(memo_data['ks'])][mask][:-1].value, memo_data['T_errs'][mask][:-1], rtol=1e-2)\n"
"assert np.allclose(\n",
" sense1d[: len(memo_data[\"ks\"])][mask][:-1].value, memo_data[\"errs\"][mask][:-1], rtol=1e-1\n",
")\n",
"assert np.allclose(\n",
" sense1d_t[: len(memo_data[\"ks\"])][mask][:-1].value, memo_data[\"T_errs\"][mask][:-1], rtol=1e-2\n",
")"
]
}
],
Expand Down
Loading

0 comments on commit d6eda25

Please sign in to comment.