add platt scale
jiong-zhang committed Nov 29, 2023
1 parent d2708ec commit 55c3d7b
Showing 4 changed files with 175 additions and 0 deletions.
34 changes: 34 additions & 0 deletions pecos/core/base.py
@@ -534,6 +534,7 @@ def __init__(self, dirname, soname, forced_rebuild=False):
        self.link_ann_hnsw_methods()
        self.link_mmap_hashmap_methods()
        self.link_mmap_valstore_methods()
        self.link_calibrator_methods()

    def link_xlinear_methods(self):
        """
@@ -1939,5 +1940,38 @@ def mmap_valstore_init(self, store_type):
            raise NotImplementedError(f"store_type={store_type} is not implemented.")
        return self.mmap_valstore_fn_dict[store_type]

    def link_calibrator_methods(self):
        """
        Specify C-lib's score calibration methods' arguments and return types.
        """
        corelib.fillprototype(
            self.clib_float32.c_fit_platt_transform,
            None,
            [c_int, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
        )

    def fit_platt_transform(self, orig, tgt):
        """Python to C/C++ interface for fitting a Platt transform.

        Args:
            orig (ndarray): original scores with length N
            tgt (ndarray): target probability scores with length N

        Returns:
            A, B (float): coefficients of the fitted Platt sigmoid 1 / (1 + exp(A * x + B))
        """
        assert len(orig) == len(tgt)
        assert orig.dtype == np.float32
        assert tgt.dtype == np.float32

        AB = np.array([0, 0], dtype=np.float64)

        clib.clib_float32.c_fit_platt_transform(
            len(orig),
            orig.ctypes.data_as(POINTER(c_float)),
            tgt.ctypes.data_as(POINTER(c_float)),
            AB.ctypes.data_as(POINTER(c_double)),
        )
        return AB[0], AB[1]


clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")
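
For context, a minimal usage sketch of the new Python entry point. The score and target arrays below are illustrative only; the sigmoid convention 1 / (1 + exp(A * x + B)) is the one exercised by the unit test added in this commit.

import numpy as np
from pecos.core import clib

# Raw decision scores and the probabilities we would like them mapped to.
scores = np.linspace(-5.0, 5.0, 100, dtype=np.float32)
targets = (1.0 / (1.0 + np.exp(0.5 * scores + 1.0))).astype(np.float32)

# Fit the two Platt coefficients on the (score, target) pairs.
A, B = clib.fit_platt_transform(scores, targets)

# Apply the fitted sigmoid to calibrate new raw scores into probabilities.
new_scores = np.array([-2.0, 0.0, 2.0], dtype=np.float32)
probs = 1.0 / (1.0 + np.exp(A * new_scores + B))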
11 changes: 11 additions & 0 deletions pecos/core/libpecos.cpp
@@ -651,4 +651,15 @@ extern "C" {
        static_cast<mmap_valstore_bytes *>(map_ptr)->batch_get(
            n_sub_row, n_sub_col, sub_rows, sub_cols, trunc_val_len, ret, ret_lens, threads);
    }

    // ==== C Interface of Score Calibrator ====

    void c_fit_platt_transform(
        int num_samples,
        const float* dec_values,
        const float* tgt_values,
        double* AB
    ) {
        pecos::fit_platt_transform(num_samples, dec_values, tgt_values, AB[0], AB[1]);
    }
}
115 changes: 115 additions & 0 deletions pecos/core/utils/newton.hpp
@@ -272,5 +272,120 @@ namespace pecos {
            return cg_iter;
        };
    };


    // Platt scaling with a given target curve.
    // Reference Implementation:
    //   https://github.com/cjlin1/libsvm/blob/master/svm.cpp

#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
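
    // The routine below runs Newton's method on the two coefficients (A, B),
    // minimizing the cross-entropy between the targets t_i and the sigmoid
    // p_i = 1 / (1 + exp(A * x_i + B)), i.e.
    //     F(A, B) = sum_i [ (t_i - 1) * (A * x_i + B) + log(1 + exp(A * x_i + B)) ],
    // with a backtracking (halving) line search on each Newton step. The branches
    // on the sign of fApB below are numerically stable evaluations of the same value.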

    static void fit_platt_transform(int l, const float *dec_values, const float *tgt_values, double& A, double& B) {
        // hyper parameters
        int max_iter = 100;       // Maximal number of iterations
        double min_step = 1e-10;  // Minimal step taken in line search
        double sigma = 1e-12;     // For numerically strict PD of Hessian
        double eps = 1e-6;

        double *t = Malloc(double, l);
        double fApB, p, q;
        double h11, h22, h21, g1, g2, det, dA, dB, gd, stepsize;
        double newA, newB, newf, d1, d2;
        int iter;

        // Initial Point and Initial Fun Value
        A = 0.0; B = 1.0;
        double fval = 0.0;

        // check for out of bound in tgt_values
        for (int i = 0; i < l; i++) {
            if (tgt_values[i] > 1.0 || tgt_values[i] < 0) {
                throw std::runtime_error("fit_platt_transform: target value out of bound\n");
            }
        }

        for (int i = 0; i < l; i++) {
            fApB = dec_values[i] * A + B;
            if (fApB >= 0) {
                fval += tgt_values[i] * fApB + log(1 + exp(-fApB));
            } else {
                fval += (tgt_values[i] - 1) * fApB + log(1 + exp(fApB));
            }
        }
        for (iter = 0; iter < max_iter; iter++) {
            // Update Gradient and Hessian (use H' = H + sigma I)
            h11 = sigma;
            h22 = sigma;  // numerically ensures strict PD
            h21 = 0.0;
            g1 = 0.0;
            g2 = 0.0;

            for (int i = 0; i < l; i++) {
                fApB = dec_values[i] * A + B;
                if (fApB >= 0) {
                    p = exp(-fApB) / (1.0 + exp(-fApB));
                    q = 1.0 / (1.0 + exp(-fApB));
                } else {
                    p = 1.0 / (1.0 + exp(fApB));
                    q = exp(fApB) / (1.0 + exp(fApB));
                }
                d2 = p * q;
                h11 += dec_values[i] * dec_values[i] * d2;
                // Code scanning / CodeQL check failure (High): "Multiplication result
                // converted to larger type" -- the multiplication result may overflow
                // 'float' before it is converted to 'double'.
                h22 += d2;
                h21 += dec_values[i] * d2;
                d1 = tgt_values[i] - p;
                g1 += dec_values[i] * d1;
                g2 += d1;
            }

            // Stopping Criteria
            if (fabs(g1) < eps && fabs(g2) < eps)
                break;

            // Finding Newton direction: -inv(H') * g
            det = h11 * h22 - h21 * h21;
            dA = -(h22 * g1 - h21 * g2) / det;
            dB = -(-h21 * g1 + h11 * g2) / det;
            gd = g1 * dA + g2 * dB;

            // Line Search
            stepsize = 1;
            while (stepsize >= min_step) {
                newA = A + stepsize * dA;
                newB = B + stepsize * dB;

                // New function value
                newf = 0.0;
                for (int i = 0; i < l; i++) {
                    fApB = dec_values[i] * newA + newB;
                    if (fApB >= 0) {
                        newf += tgt_values[i] * fApB + log(1 + exp(-fApB));
                    } else {
                        newf += (tgt_values[i] - 1) * fApB + log(1 + exp(fApB));
                    }
                }
                // Check sufficient decrease
                if (newf < fval + 0.0001 * stepsize * gd) {
                    A = newA;
                    B = newB;
                    fval = newf;
                    break;
                } else {
                    stepsize = stepsize / 2.0;
                }
            }

            if (stepsize < min_step) {
                throw std::runtime_error("fit_platt_transform: Line search fails\n");
            }
        }

        if (iter >= max_iter) {
            throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
        }
        free(t);
    }
} // namespace pecos
#endif
15 changes: 15 additions & 0 deletions test/pecos/core/test_clib.py
@@ -67,3 +67,18 @@ def test_sparse_inner_products():
    assert true_vals == approx(
        pred_vals, abs=1e-9
    ), f"true_vals != pred_vals, where X/Y are drm/dcm"


def test_platt_scale():
    import numpy as np
    from pecos.core import clib

    A = 0.25
    B = 3.14

    orig = np.arange(-15, 15, 1, dtype=np.float32)
    tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32)

    At, Bt = clib.fit_platt_transform(orig, tgt)
    assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
    assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"
