From 44bf8ef37ac6990090ab1e38c944b8474ce05b8f Mon Sep 17 00:00:00 2001
From: Max West
Date: Thu, 30 Jan 2025 17:08:54 -0800
Subject: [PATCH] initial pass

---
 pyproject.toml                     |  2 +-
 src/kbmod/filters/stamp_filters.py | 95 ++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 002661bc..60f829bd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,7 @@ Changelog = "https://epyc.astro.washington.edu/~kbmod/project_details/release_no
 
 [project.optional-dependencies]
 analysis = [
-    "tensorflow>=2.9",
+    "tensorflow<=2.15",
     "matplotlib>=3.6.1",
     "ipywidgets>=8.0",
     "ephem>=4.1"
diff --git a/src/kbmod/filters/stamp_filters.py b/src/kbmod/filters/stamp_filters.py
index 0fd7d08d..0e144ca9 100644
--- a/src/kbmod/filters/stamp_filters.py
+++ b/src/kbmod/filters/stamp_filters.py
@@ -248,3 +248,98 @@ def append_all_stamps(result_data, im_stack, stamp_radius):
     # columns between tables.
     result_data.table["all_stamps"] = np.array(all_stamps)
     stamp_timer.stop()
+
+
+def _normalize_stamps(stamps, stamp_dimm=21):
+    """Normalize a list of stamps. Used for `filter_stamps_by_cnn`.
+
+    Each stamp is copied, has its NaNs zeroed, is clipped from below at
+    (median - 2 * sigmaG), shifted so its minimum is zero, and scaled to sum to one.
+
+    Parameters
+    ----------
+    stamps : iterable of `numpy.ndarray`
+        The stamps to normalize. Each must hold stamp_dimm * stamp_dimm pixels.
+    stamp_dimm : `int`
+        The width and height of a stamp in pixels. Defaults to 21, which
+        matches the default stamp radius of 10.
+
+    Returns
+    -------
+    normed_stamps : `numpy.ndarray`
+        One normalized (stamp_dimm, stamp_dimm) array per input stamp.
+    """
+    normed_stamps = []
+    sigma_g_coeff = 0.7413
+    for stamp in stamps:
+        stamp = np.copy(stamp)
+        stamp[np.isnan(stamp)] = 0
+
+        per25, per50, per75 = np.percentile(stamp, [25, 50, 75])
+        sigmaG = sigma_g_coeff * (per75 - per25)
+        floor = per50 - 2 * sigmaG
+        stamp[stamp < floor] = floor
+
+        stamp -= np.min(stamp)
+        # A flat stamp sums to zero here; leave it as all zeros instead of
+        # dividing 0 by 0 (same result, without the NumPy runtime warning).
+        total = np.sum(stamp)
+        if total != 0:
+            stamp /= total
+        stamp[np.isnan(stamp)] = 0
+        normed_stamps.append(stamp.reshape(stamp_dimm, stamp_dimm))
+    return np.array(normed_stamps)
+
+
+def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radius=10, verbose=False):
+    """Given a set of results data, run the requested coadded stamps through a
+    provided convolutional neural network and assign a new column that contains the
+    stamp classification, i.e. whether or not the result passed the CNN filter.
+
+    Parameters
+    ----------
+    result_data : `Result`
+        The current set of results. Modified directly.
+    model_path : `str`
+        Path to the tensorflow model and weights file.
+    coadd_type : `str`
+        Which coadd type to use in the filtering. Depends on how the model was trained.
+    stamp_radius : `int`
+        The radius used to generate the stamps. The dimension of the stamps should be
+        (stamp_radius * 2) + 1.
+    verbose : `bool`
+        Verbosity option for the CNN prediction. Off by default.
+
+    Raises
+    ------
+    ValueError
+        If `result_data` has no "coadd_<coadd_type>" column.
+    """
+    import tensorflow as tf
+
+    coadd_column = f"coadd_{coadd_type}"
+    if coadd_column not in result_data.colnames:
+        raise ValueError("result_data does not have provided coadd type as a column.")
+
+    cnn = tf.keras.models.load_model(model_path)
+
+    stamps = result_data.table[coadd_column].data
+    if len(stamps) == 0:
+        # Nothing to classify, so do not hand the model an empty batch.
+        result_data.table["cnn_class"] = np.array([], dtype=int)
+        return
+
+    stamp_dimm = (stamp_radius * 2) + 1
+    normalized_stamps = _normalize_stamps(stamps, stamp_dimm)
+
+    # resize to match the tensorflow input
+    # will probably not be needed when we switch to PyTorch
+    resized_stamps = normalized_stamps.reshape(-1, stamp_dimm, stamp_dimm, 1)
+
+    predictions = cnn.predict(resized_stamps, verbose=verbose)
+
+    # The class of each stamp is the index of its largest model output.
+    classifications = [np.argmax(p) for p in predictions]
+
+    # TODO: maybe cast as a bool?
+    result_data.table["cnn_class"] = np.array(classifications)