Skip to content

Commit

Permalink
initial pass
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwest-uw committed Jan 31, 2025
1 parent 3fc159d commit 44bf8ef
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Changelog = "https://epyc.astro.washington.edu/~kbmod/project_details/release_no

[project.optional-dependencies]
analysis = [
"tensorflow>=2.9",
"tensorflow<=2.15",
"matplotlib>=3.6.1",
"ipywidgets>=8.0",
"ephem>=4.1"
Expand Down
63 changes: 63 additions & 0 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,66 @@ def append_all_stamps(result_data, im_stack, stamp_radius):
# columns between tables.
result_data.table["all_stamps"] = np.array(all_stamps)
stamp_timer.stop()

def _normalize_stamps(stamps):
"""Normalize a list of stamps. Used for `filter_stamps_by_cnn`."""
normed_stamps = []
sigma_g_coeff = 0.7413
for stamp in stamps:
stamp = np.copy(stamp)
stamp[np.isnan(stamp)] = 0

per25, per50, per75 = np.percentile(stamp, [25,50,75])
sigmaG = sigma_g_coeff * (per75 - per25)
stamp[stamp<(per50-2*sigmaG)] = per50-2*sigmaG

stamp -= np.min(stamp)
stamp /= np.sum(stamp)
stamp[np.isnan(stamp)] = 0
normed_stamps.append(stamp.reshape(21,21))
return np.array(normed_stamps)

def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radius=10, verbose=False):
"""Given a set of results data, run the the requested coadded stamps through a
provided convolutional neural network and assign a new column that contains the
stamp classification, i.e. whehter or not the result passed the CNN filter.
Parameters
----------
result_data : `Result`
The current set of results. Modified directly.
model_path : `str`
Path to the the tensorflow model and weights file.
coadd_type : `str`
Which coadd type to use in the filtering. Depends on how the model was trained.
stamp_radius : `int`
The radius used to generate the stamps. The dimmension of the stamps should be
(stamp_radius * 2) + 1.
verbose : `bool`
Verbosity option the CNN predicition. Off by default.
"""
import tensorflow as tf

coadd_column = f"coadd_{coadd_type}"
if coadd_column not in result_data.colnames:
raise ValueError("result_data does not have provided coadd type as a column.")

cnn = tf.keras.models.load_model(model_path)

stamps = result_data.table[coadd_column].data
normalized_stamps = _normalize_stamps(stamps)

# resize to match the tensorflow input
# will probably not be needed when we switch to PyTorch
stamp_dimm = (stamp_radius * 2) + 1
resized_stamps = normalized_stamps.reshape(-1, stamp_dimm, stamp_dimm, 1)

predictions = cnn.predict(resized_stamps, verbose=verbose)

classsifications = []
for p in predictions:
classsifications.append(np.argmax(p))

# TODO: maybe cast as a bool?
result_data.table["cnn_class"] = np.array(classsifications)

0 comments on commit 44bf8ef

Please sign in to comment.