Skip to content

Commit

Permalink
Fix: Jensen Shannon square root (#233)
Browse files Browse the repository at this point in the history
Closes #207 

The Jensen-Shannon measure implementation in SimSIMD is different 
from the implementation in SciPy. SciPy's `jensenshannon` returns the Jensen-Shannon 
*distance* — the square root of the divergence — so a final square root was missing 
in SimSIMD (the 0.5 scaling of the divergence was already applied). This fix applies 
the square root and matches the example provided in SciPy.

---------

Co-authored-by: Ash Vardanian <[email protected]>
  • Loading branch information
GoWind and ashvardanian authored Nov 17, 2024
1 parent c124410 commit 5fed772
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
11 changes: 6 additions & 5 deletions include/simsimd/probability.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_
d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \
d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \
} \
*result = (simsimd_distance_t)d / 2; \
*result = SIMSIMD_SQRT(((simsimd_distance_t)d / 2)); \
}

SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial
Expand Down Expand Up @@ -219,12 +219,13 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const *a, simsimd_f32_t co
float32x4_t log_ratio_b_vec = _simsimd_log2_f32_neon(ratio_b_vec);
float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec);
float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec);

sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec));
if (n != 0) goto simsimd_js_f32_neon_cycle;

simsimd_f32_t log2_normalizer = 0.693147181f;
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
*result = sum / 2;
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
*result = SIMSIMD_SQRT(sum);
}

#pragma clang attribute pop
Expand Down Expand Up @@ -296,8 +297,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const *a, simsimd_f16_t co
if (n) goto simsimd_js_f16_neon_cycle;

simsimd_f32_t log2_normalizer = 0.693147181f;
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
*result = sum / 2;
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
*result = SIMSIMD_SQRT(sum);
}

#pragma clang attribute pop
Expand Down
15 changes: 11 additions & 4 deletions scripts/test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,15 @@ test("Kullback-Leibler C vs JS", () => {
});

test("Jensen-Shannon C vs JS", () => {
const f32sDistribution = new Float32Array([1.0 / 6, 2.0 / 6, 3.0 / 6]);
const result = simsimd.jensenshannon(f32sDistribution, f32sDistribution);
const resultjs = fallback.jensenshannon(f32sDistribution, f32sDistribution);
assertAlmostEqual(resultjs, result, 0.01);
const f32sDistribution = new Float32Array([1.0, 0.0]);
const f32sDistribution2 = new Float32Array([0.5, 0.5]);
const result = simsimd.jensenshannon(f32sDistribution, f32sDistribution2);
const resultjs = fallback.jensenshannon(f32sDistribution, f32sDistribution2);
assertAlmostEqual(result, resultjs, 0.01);

const orthogonalVec1 = new Float32Array([1.0, 0.0, 0.0]);
const orthogonalVec2 = new Float32Array([0.0, 1.0, 0.0]);
const orthoResult = simsimd.jensenshannon(orthogonalVec1, orthogonalVec2);
const orthoResultJs = fallback.jensenshannon(orthogonalVec1, orthogonalVec2);
assertAlmostEqual(orthoResult, orthoResultJs, 0.01);
});
10 changes: 6 additions & 4 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
python test.py
"""

import os
import math
import time
Expand Down Expand Up @@ -124,7 +123,7 @@ def baseline_wsum(x, y, alpha, beta):
baseline_euclidean = lambda x, y: np.array(spd.euclidean(x, y)) #! SciPy returns a scalar
baseline_sqeuclidean = spd.sqeuclidean
baseline_cosine = spd.cosine
baseline_jensenshannon = lambda x, y: spd.jensenshannon(x, y) ** 2
baseline_jensenshannon = lambda x, y: spd.jensenshannon(x, y)
baseline_hamming = lambda x, y: spd.hamming(x, y) * len(x)
baseline_jaccard = spd.jaccard

Expand Down Expand Up @@ -453,6 +452,8 @@ def name_to_kernels(name: str):
return baseline_fma, simd.fma
elif name == "wsum":
return baseline_wsum, simd.wsum
elif name == "jensenshannon":
return baseline_jensenshannon, simd.jensenshannon
else:
raise ValueError(f"Unknown kernel name: {name}")

Expand Down Expand Up @@ -839,12 +840,13 @@ def test_dense_bits(ndim, metric, capability, stats_fixture):
collect_errors(metric, ndim, "bin8", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture)


@pytest.mark.skip(reason="Problems inferring the tolerance bounds for numerical errors")
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
@pytest.mark.parametrize("capability", possible_capabilities)
def test_jensen_shannon(ndim, dtype, capability):
def test_jensen_shannon(ndim, dtype, capability, stats_fixture):
"""Compares the simd.jensenshannon() function with scipy.spatial.distance.jensenshannon(), measuring the accuracy error for f16, and f32 types."""
np.random.seed()
a = np.abs(np.random.randn(ndim)).astype(dtype)
Expand Down

0 comments on commit 5fed772

Please sign in to comment.