diff --git a/src/tdastro/astro_utils/opsim.py b/src/tdastro/astro_utils/opsim.py index 850b63ff..ca1c969b 100644 --- a/src/tdastro/astro_utils/opsim.py +++ b/src/tdastro/astro_utils/opsim.py @@ -418,3 +418,107 @@ def opsim_add_random_data(opsim_data, colname, min_val=0.0, max_val=1.0): """ values = np.random.uniform(low=min_val, high=max_val, size=len(opsim_data)) opsim_data.add_column(colname, values) + + +def oversample_opsim( + opsim: OpSim, + *, + pointing: tuple[float, float] = (200, -50), + search_radius: float = 1.75, + delta_t: float = 0.01, + time_range: tuple[float | None, float | None] = (None, None), + bands: list[str] | None = None, + strategy: str = "darkest_sky", +): + """Single-pointing oversampled OpSim table. + + It includes observations for a single pointing only, + but with very high time resolution. The observations + would alternate between the bands. + + Parameters + ---------- + opsim : OpSim + The OpSim table to oversample. + pointing : tuple of RA and Dec in degrees + The pointing to use for the oversampled table. + search_radius : float, optional + The search radius for the oversampled table in degrees. + The default is the half of the LSST's field of view. + delta_t : float, optional + The time between observations in days. + time_range : tuple or floats or Nones, optional + The start and end times of the observations in MJD. + `None` means to use the minimum (maximum) time in + all the observations found for the given pointing. + Time is being samples as np.arange(*time_range, delta_t). + bands : list of str or None, optional + The list of bands to include in the oversampled table. + The default is to include all bands found for the given pointing. + strategy : str, optional + The strategy to select prototype observations. + - "darkest_sky" selects the observations with the minimal sky brightness + (maximum "skyBrightness" value) in each band. This is the default. + - "random" selects the observations randomly. Fixed seed is used. + + """ + ra, dec = pointing + observations = opsim.table.iloc[opsim.range_search(ra, dec, search_radius)] + if len(observations) == 0: + raise ValueError("No observations found for the given pointing.") + + time_min, time_max = time_range + if time_min is None: + time_min = np.min(observations["observationStartMJD"]) + if time_max is None: + time_max = np.max(observations["observationStartMJD"]) + if time_min >= time_max: + raise ValueError(f"Invalid time_range: start > end: {time_min} > {time_max}") + + uniq_bands = np.unique(observations["filter"]) + if bands is None: + bands = uniq_bands + elif not set(bands).issubset(uniq_bands): + raise ValueError(f"Invalid bands: {bands}") + + new_times = np.arange(time_min, time_max, delta_t) + n = len(new_times) + if n < len(bands): + raise ValueError("Not enough time points to cover all bands.") + + new_table = pd.DataFrame( + { + # Just in case, to not have confusion with the original table + "observationId": opsim.table["observationId"].max() + 1 + np.arange(n), + "observationStartMJD": new_times, + "fieldRA": ra, + "fieldDec": dec, + "filter": np.tile(bands, n // len(bands)), + } + ) + other_columns = [column for column in observations.columns if column not in new_table.columns] + + if strategy == "darkest_sky": + for band in bands: + # MAXimum magnitude is MINimum brightness (darkest sky) + idxmax = observations["skyBrightness"][observations["filter"] == band].idxmax() + idx = new_table.index[new_table["filter"] == band] + darkest_sky_obs = pd.DataFrame.from_records([observations.loc[idxmax]] * idx.size, index=idx) + new_table.loc[idx, other_columns] = darkest_sky_obs[other_columns] + elif strategy == "random": + rng = np.random.default_rng(0) + for band in bands: + single_band_obs = observations[observations["filter"] == band] + idx = new_table.index[new_table["filter"] == band] + random_obs = single_band_obs.sample(idx.size, replace=True, random_state=rng).set_index(idx) + new_table.loc[idx, other_columns] = random_obs[other_columns] + else: + raise ValueError(f"Invalid strategy: {strategy}") + + return OpSim( + new_table, + colmap=opsim.colmap, + pixel_scale=opsim.pixel_scale, + read_noise=opsim.read_noise, + dark_current=opsim.dark_current, + ) diff --git a/tests/tdastro/astro_utils/test_opsim.py b/tests/tdastro/astro_utils/test_opsim.py index ab0d53ff..372d44ae 100644 --- a/tests/tdastro/astro_utils/test_opsim.py +++ b/tests/tdastro/astro_utils/test_opsim.py @@ -9,6 +9,7 @@ from tdastro.astro_utils.opsim import ( OpSim, opsim_add_random_data, + oversample_opsim, ) @@ -251,3 +252,36 @@ def test_opsim_flux_err_point_source(opsim_shorten): # Tolerance is very high, we should investigate why the values are so different. np.testing.assert_allclose(flux_err, expected_flux_err, rtol=0.2) + + +def test_oversample_opsim(opsim_shorten): + """Test that we can oversample an OpSim file.""" + opsim = OpSim.from_db(opsim_shorten) + + bands = ["g", "r"] + ra, dec = 205.0, -57.0 + time_range = 60_000.0, 60_010.0 + delta_t = 0.01 + + for strategy in ["darkest_sky", "random"]: + oversampled = oversample_opsim( + opsim, + pointing=(ra, dec), + time_range=time_range, + delta_t=delta_t, + bands=bands, + strategy=strategy, + ) + assert set(opsim.table.columns) == set(oversampled.table.columns), "columns are not the same" + np.testing.assert_allclose( + np.diff(oversampled["observationStartMJD"]), delta_t, err_msg="delta_t is not correct" + ) + np.testing.assert_allclose(oversampled["fieldRA"], ra, err_msg="RA is not correct") + np.testing.assert_allclose(oversampled["fieldDec"], dec, err_msg="Dec is not correct") + assert np.all(oversampled["observationStartMJD"] >= time_range[0]), "time range is not correct" + assert np.all(oversampled["observationStartMJD"] <= time_range[1]), "time range is not correct" + assert set(oversampled["filter"]) == set(bands), "oversampled table has the wrong bands" + assert ( + oversampled["skyBrightness"].unique().size >= oversampled["filter"].unique().size + ), "there should be at least as many skyBrightness values as bands" + assert oversampled["skyBrightness"].isna().sum() == 0, "skyBrightness has NaN values"