From dcd188f138fa742b084d27a5727aeaf2f5e4280d Mon Sep 17 00:00:00 2001 From: Muhammad Anas Raza Date: Wed, 28 Feb 2024 10:53:18 -0500 Subject: [PATCH] fixes for JAX backend --- k3_addons/metrics/f_scores_test.py | 13 ++++++++++--- k3_addons/metrics/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/k3_addons/metrics/f_scores_test.py b/k3_addons/metrics/f_scores_test.py index 304b31f..286d284 100644 --- a/k3_addons/metrics/f_scores_test.py +++ b/k3_addons/metrics/f_scores_test.py @@ -35,6 +35,12 @@ def _test_tf(avg, beta, act, pred, sample_weights, threshold): def _test_fbeta_score(actuals, preds, sample_weights, avg, beta_val, result, threshold): + actuals = ops.convert_to_tensor(actuals, "float32") + preds = ops.convert_to_tensor(preds, "float32") + if sample_weights is not None: + sample_weights = ops.convert_to_tensor(sample_weights, "float32") + + tf_score = _test_tf(avg, beta_val, actuals, preds, sample_weights, threshold) np.testing.assert_allclose(tf_score, result, atol=1e-7, rtol=1e-6) @@ -186,7 +192,7 @@ def test_eq(): actuals = ops.convert_to_tensor(actuals, "float32") fbeta.update_state(actuals, preds) f1.update_state(actuals, preds) - np.testing.assert_allclose(fbeta.result().numpy(), f1.result().numpy()) + np.testing.assert_allclose(ops.convert_to_numpy(fbeta.result()), ops.convert_to_numpy(f1.result())) def test_sample_eq(): @@ -202,11 +208,11 @@ def test_sample_eq(): [0, 0, 1], ]) actuals = ops.convert_to_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]]) - sample_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + sample_weights = ops.convert_to_tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) f1.update_state(actuals, preds) f1_weighted(actuals, preds, sample_weights) - np.testing.assert_allclose(f1.result().numpy(), f1_weighted.result().numpy()) + np.testing.assert_allclose(ops.convert_to_numpy(f1.result()), ops.convert_to_numpy(f1_weighted.result())) def test_keras_model_f1(): @@ -215,6 +221,7 @@ def test_keras_model_f1(): def test_config_f1(): + f1 = F1Score(3) config = f1.get_config() assert "beta" not in config diff --git a/k3_addons/metrics/utils.py b/k3_addons/metrics/utils.py index 9688156..97e951c 100644 --- a/k3_addons/metrics/utils.py +++ b/k3_addons/metrics/utils.py @@ -11,6 +11,6 @@ def _get_model(metric, num_output): optimizer="adam", loss="categorical_crossentropy", metrics=["acc", metric] ) - data = np.random.random((10, 3)) - labels = np.random.random((10, num_output)) + data = keras.ops.convert_to_tensor(np.random.random((10, 3))) + labels = keras.ops.convert_to_tensor(np.random.random((10, num_output))) model.fit(data, labels, epochs=1, batch_size=5, verbose=0)