Skip to content

Commit

Permalink
fixes for JAX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 28, 2024
1 parent 0cb0f41 commit dcd188f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
13 changes: 10 additions & 3 deletions k3_addons/metrics/f_scores_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def _test_tf(avg, beta, act, pred, sample_weights, threshold):


def _test_fbeta_score(actuals, preds, sample_weights, avg, beta_val, result, threshold):
actuals = ops.convert_to_tensor(actuals, "float32")
preds = ops.convert_to_tensor(preds, "float32")
if sample_weights is not None:
sample_weights = ops.convert_to_tensor(sample_weights, "float32")


tf_score = _test_tf(avg, beta_val, actuals, preds, sample_weights, threshold)
np.testing.assert_allclose(tf_score, result, atol=1e-7, rtol=1e-6)

Expand Down Expand Up @@ -186,7 +192,7 @@ def test_eq():
actuals = ops.convert_to_tensor(actuals, "float32")
fbeta.update_state(actuals, preds)
f1.update_state(actuals, preds)
np.testing.assert_allclose(fbeta.result().numpy(), f1.result().numpy())
np.testing.assert_allclose(ops.convert_to_numpy(fbeta.result()), ops.convert_to_numpy(f1.result()))


def test_sample_eq():
Expand All @@ -202,11 +208,11 @@ def test_sample_eq():
[0, 0, 1],
])
actuals = ops.convert_to_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]])
sample_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
sample_weights = ops.convert_to_tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

f1.update_state(actuals, preds)
f1_weighted(actuals, preds, sample_weights)
np.testing.assert_allclose(f1.result().numpy(), f1_weighted.result().numpy())
np.testing.assert_allclose(ops.convert_to_numpy(f1.result()), ops.convert_to_numpy(f1_weighted.result()))


def test_keras_model_f1():
Expand All @@ -215,6 +221,7 @@ def test_keras_model_f1():


def test_config_f1():

f1 = F1Score(3)
config = f1.get_config()
assert "beta" not in config
Expand Down
4 changes: 2 additions & 2 deletions k3_addons/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ def _get_model(metric, num_output):
optimizer="adam", loss="categorical_crossentropy", metrics=["acc", metric]
)

data = np.random.random((10, 3))
labels = np.random.random((10, num_output))
data = keras.ops.convert_to_tensor(np.random.random((10, 3)))
labels = keras.ops.convert_to_tensor(np.random.random((10, num_output)))
model.fit(data, labels, epochs=1, batch_size=5, verbose=0)

0 comments on commit dcd188f

Please sign in to comment.