From 0e0f27b376f81543ca5ae8d85a0ccb74e4c1a722 Mon Sep 17 00:00:00 2001 From: Max West Date: Fri, 31 Jan 2025 10:30:06 -0800 Subject: [PATCH] other fixes --- src/kbmod/filters/stamp_filters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kbmod/filters/stamp_filters.py b/src/kbmod/filters/stamp_filters.py index 0e144ca9..e1266312 100644 --- a/src/kbmod/filters/stamp_filters.py +++ b/src/kbmod/filters/stamp_filters.py @@ -286,13 +286,13 @@ def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radiu verbose : `bool` Verbosity option the CNN predicition. Off by default. """ - import tensorflow as tf + from tensorflow.keras.models import load_model coadd_column = f"coadd_{coadd_type}" if coadd_column not in result_data.colnames: raise ValueError("result_data does not have provided coadd type as a column.") - cnn = tf.keras.models.load_model(model_path) + cnn = load_model(model_path) stamps = result_data.table[coadd_column].data normalized_stamps = _normalize_stamps(stamps) @@ -308,6 +308,6 @@ def filter_stamps_by_cnn(result_data, model_path, coadd_type="mean", stamp_radiu for p in predictions: classsifications.append(np.argmax(p)) - # TODO: maybe cast as a bool? - result_data.table["cnn_class"] = np.array(classsifications) + bool_arr = np.array(classsifications) != 0 + result_data.table["cnn_class"] = bool_arr