Skip to content

Commit

Permalink
other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwest-uw committed Jan 31, 2025
1 parent 44bf8ef commit 0e0f27b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,13 @@ def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radiu
verbose : `bool`
Verbosity option the CNN predicition. Off by default.
"""
import tensorflow as tf
from tensorflow.keras.models import load_model

coadd_column = f"coadd_{coadd_type}"
if coadd_column not in result_data.colnames:
raise ValueError("result_data does not have provided coadd type as a column.")

cnn = tf.keras.models.load_model(model_path)
cnn = load_model(model_path)

stamps = result_data.table[coadd_column].data
normalized_stamps = _normalize_stamps(stamps)
Expand All @@ -308,6 +308,6 @@ def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radiu
for p in predictions:
classsifications.append(np.argmax(p))

# TODO: maybe cast as a bool?
result_data.table["cnn_class"] = np.array(classsifications)
bool_arr = np.array(classsifications) != 0
result_data.table["cnn_class"] = bool_arr

0 comments on commit 0e0f27b

Please sign in to comment.