diff --git a/pecos/core/base.py b/pecos/core/base.py index 1c4cdfdf..a72ebf7e 100644 --- a/pecos/core/base.py +++ b/pecos/core/base.py @@ -534,6 +534,7 @@ def __init__(self, dirname, soname, forced_rebuild=False): self.link_ann_hnsw_methods() self.link_mmap_hashmap_methods() self.link_mmap_valstore_methods() + self.link_calibrator_methods() def link_xlinear_methods(self): """ @@ -1939,5 +1940,60 @@ def mmap_valstore_init(self, store_type): raise NotImplementedError(f"store_type={store_type} is not implemented.") return self.mmap_valstore_fn_dict[store_type] + def link_calibrator_methods(self): + """ + Specify C-lib's score calibration methods arguments and return types. + """ + corelib.fillprototype( + self.clib_float32.c_fit_platt_transform_f32, + None, + [c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)], + ) + corelib.fillprototype( + self.clib_float32.c_fit_platt_transform_f64, + None, + [c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)], + ) + + def fit_platt_transform(self, logits, tgt_prob): + """Python to C/C++ interface for platt transfrom fit. + + Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf + + Args: + logits (ndarray): 1-d array of logit with length N. + tgt_prob (ndarray): 1-d array of target probability scores within [0, 1] with length N. + Returns: + A, B: coefficients for Platt's scale. + """ + assert isinstance(logits, np.ndarray) + assert isinstance(tgt_prob, np.ndarray) + assert len(logits) == len(tgt_prob) + assert logits.dtype == tgt_prob.dtype + + if tgt_prob.min() < 0 or tgt_prob.max() > 1.0: + raise ValueError("Target probability out of bound!") + + AB = np.array([0, 0], dtype=np.float64) + + if tgt_prob.dtype == np.float32: + clib.clib_float32.c_fit_platt_transform_f32( + len(logits), + logits.ctypes.data_as(POINTER(c_float)), + tgt_prob.ctypes.data_as(POINTER(c_float)), + AB.ctypes.data_as(POINTER(c_double)), + ) + elif tgt_prob.dtype == np.float64: + clib.clib_float32.c_fit_platt_transform_f64( + len(logits), + logits.ctypes.data_as(POINTER(c_double)), + tgt_prob.ctypes.data_as(POINTER(c_double)), + AB.ctypes.data_as(POINTER(c_double)), + ) + else: + raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}") + + return AB[0], AB[1] + clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos") diff --git a/pecos/core/libpecos.cpp b/pecos/core/libpecos.cpp index 2f001bbc..d1dcb1bf 100644 --- a/pecos/core/libpecos.cpp +++ b/pecos/core/libpecos.cpp @@ -651,4 +651,18 @@ extern "C" { static_cast(map_ptr)->batch_get( n_sub_row, n_sub_col, sub_rows, sub_cols, trunc_val_len, ret, ret_lens, threads); } + + // ==== C Interface of Score Calibrator ==== + + #define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \ + void c_fit_platt_transform ## SUFFIX( \ + size_t num_samples, \ + const VAL_TYPE* logits, \ + const VAL_TYPE* tgt_probs, \ + double* AB \ + ) { \ + pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \ + } + C_FIT_PLATT_TRANSFORM(_f32, float32_t) + C_FIT_PLATT_TRANSFORM(_f64, float64_t) } diff --git a/pecos/core/utils/newton.hpp b/pecos/core/utils/newton.hpp index 19d0a848..0dcd68da 100644 --- a/pecos/core/utils/newton.hpp +++ b/pecos/core/utils/newton.hpp @@ -272,5 +272,117 @@ namespace pecos { return cg_iter; }; }; + + + // Platt scale with given target curve. + // Reference Implementation: + // https://github.com/cjlin1/libsvm/blob/master/svm.cpp + + template + static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) { + // hyper parameters + int max_iter = 100; // Maximal number of iterations + double min_step = 1e-10; // Minimal step taken in line search + double sigma = 1e-12; // For numerically strict PD of Hessian + double eps = 1e-6; + + int iter; + + // Initial Point and Initial Fun Value + A = 0.0; B = 1.0; + double fval = 0.0; + + // check for out of bound in tgt_probs + for (size_t i = 0; i < num_samples; i++) { + if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) { + throw std::runtime_error("fit_platt_transform: target probability out of bound\n"); + } + } + + + for (size_t i = 0; i < num_samples; i++) { + double fApB = logits[i] * A + B; + if (fApB >= 0) { + fval += tgt_probs[i] * fApB + log(1 + exp(-fApB)); + } else { + fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB)); + } + } + for (iter = 0; iter < max_iter; iter++) { + // Update Gradient and Hessian (use H' = H + sigma I) + double h11 = sigma; + double h22 = sigma; // numerically ensures strict PD + double h21 = 0.0; + double g1 = 0.0; + double g2 = 0.0; + + for (size_t i = 0; i < num_samples; i++) { + double fApB = logits[i] * A + B; + double p = 0, q = 0; + if (fApB >= 0) { + p = exp(-fApB) / (1.0 + exp(-fApB)); + q = 1.0 / (1.0 + exp(-fApB)); + } else { + p = 1.0 / (1.0 + exp(fApB)); + q = exp(fApB) / (1.0 + exp(fApB)); + } + double d1 = tgt_probs[i] - p; + double d2 = p * q; + + h11 += d2 * logits[i] * logits[i]; + h22 += d2; + h21 += logits[i] * d2; + g1 += logits[i] * d1; + g2 += d1; + } + + // Stopping Criteria + if (fabs(g1) < eps && fabs(g2) < eps) + break; + + // Finding Newton direction: -inv(H') * g + double det = h11 * h22 - h21 * h21; + double dA = -(h22 * g1 - h21 * g2) / det; + double dB = -(-h21 * g1 + h11 * g2) / det; + double gd = g1 * dA + g2 * dB; + + // Line Search + double stepsize = 1.0; + + while (stepsize >= min_step) { + double newA = A + stepsize * dA; + double newB = B + stepsize * dB; + + // New function value + double newf = 0.0; + for (size_t i = 0; i < num_samples; i++) { + double fApB = logits[i] * newA + newB; + if (fApB >= 0) { + newf += tgt_probs[i] * fApB + log(1 + exp(-fApB)); + } else { + newf += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB)); + } + } + // Check sufficient decrease + if (newf < fval + 0.0001 * stepsize * gd) + { + A = newA; + B = newB; + fval = newf; + break; + } else { + stepsize = stepsize / 2.0; + } + } + + if (stepsize < min_step) { + throw std::runtime_error("fit_platt_transform: Line search fails\n"); + } + } + + if (iter >= max_iter) { + throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n"); + } + } } // namespace pecos #endif diff --git a/test/pecos/core/test_clib.py b/test/pecos/core/test_clib.py index 6448ea74..a6ee0724 100644 --- a/test/pecos/core/test_clib.py +++ b/test/pecos/core/test_clib.py @@ -67,3 +67,23 @@ def test_sparse_inner_products(): assert true_vals == approx( pred_vals, abs=1e-9 ), f"true_vals != pred_vals, where X/Y are drm/dcm" + + +def test_platt_scale(): + import numpy as np + from pecos.core import clib + + A = 0.25 + B = 3.14 + + orig = np.arange(-15, 15, 1, dtype=np.float32) + tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32) + At, Bt = clib.fit_platt_transform(orig, tgt) + assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}" + assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}" + + orig = np.arange(-15, 15, 1, dtype=np.float64) + tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float64) + At, Bt = clib.fit_platt_transform(orig, tgt) + assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}" + assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"