Skip to content

Commit

Permalink
Add weighted sum of exponential unary function
Browse files Browse the repository at this point in the history
This unary function permits the calculation of the
weighted sum of negative exponential of multiple
unary functions, useful when the output of the functions
is a score, -log(p). Added to generalize
`IMP.pmi.restraints.basic.BiStableDistanceRestraint`.
Relates salilab/pmi#211
  • Loading branch information
sethaxen committed Oct 16, 2016
1 parent 100f387 commit b47d68e
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 0 deletions.
110 changes: 110 additions & 0 deletions modules/core/include/WeightedSumOfExponential.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/**
* \file IMP/core/WeightedSumOfExponential.h
* \brief Negative logarithm of weighted sum of negative exponential of unary functions.
*
* Copyright 2007-2016 IMP Inventors. All rights reserved.
*/

#ifndef IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H
#define IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H

#include <IMP/core/core_config.h>
#include <IMP/UnaryFunction.h>
#include <cmath>

IMPCORE_BEGIN_NAMESPACE

//! Negative logarithm of weighted sum of negative exponential of unary functions.
/** Given unary functions \f$ f_i(x) \f$ and weights \f$ w_i \f$, compute the function
\f[ F(x) = -d \log\left[ \sum_i{ \left( w_i e^{-f_i(x) / d} \right) } \right] ,\f]
where \f$ d \f$ is the denominator of the exponential.
This is used when the functions \f$ f_i(x) \f$ are scores (\f$ -\log(p) \f$), and the
desired score \f$ F(x) \f$ is the score resulting from the weighted convolution of
their probability distributions.
\see WeightedSum
*/
class WeightedSumOfExponential : public UnaryFunction {
public:
/** Create with the functions and their respective weights */
WeightedSumOfExponential(UnaryFunctions funcs,
Floats weights,
Float denom = 1.0) : funcs_(funcs), weights_(weights), denom_(denom) {
IMP_USAGE_CHECK(weights.size() == funcs.size(),
"Number of functions and weights must match.");
IMP_USAGE_CHECK(funcs.size() > 1,
"More than one function and weight must be provided.");
IMP_USAGE_CHECK(denom != 0.,
"Exponential denominator must be nonzero.");

}

virtual DerivativePair evaluate_with_derivative(double feature) const {
double exp_sum = 0;
double derv_num = 0;
double weight_exp;
DerivativePair fout;
for (unsigned int i = 0; i < funcs_.size(); ++i) {
fout = funcs_[i]->evaluate_with_derivative(feature);
weight_exp = weights_[i] * std::exp(-fout.first / denom_);
exp_sum += weight_exp;
derv_num += weight_exp * fout.second;
}
return DerivativePair(-std::log(exp_sum) * denom_, derv_num / exp_sum);
}

virtual double evaluate(double feature) const {
double exp_sum = 0;
for (unsigned int i = 0; i < funcs_.size(); ++i) {
exp_sum += weights_[i] * std::exp(-funcs_[i]->evaluate(feature) / denom_);
}
return -std::log(exp_sum) * denom_;
}

//! Get the number of functions
unsigned int get_function_number() {
return funcs_.size();
}

//! Set the function weights
void set_weights(Floats weights) {
IMP_USAGE_CHECK(weights.size() == get_function_number(),
"Number of weights and functions must match.");
weights_ = weights;
}

//! Get the function weights
Floats get_weights() { return weights_; }

//! Get function weight at index
double get_weight(unsigned int i) const {
IMP_USAGE_CHECK(i < weights_.size(), "Invalid weight index");
return weights_[i];
}

//! Get function at index
UnaryFunction* get_function(unsigned int i) {
IMP_USAGE_CHECK(i < get_function_number(), "Invalid function index");
return funcs_[i];
}

//! Set the denominator of the exponential
void set_denominator(double denom) {
IMP_USAGE_CHECK(denom != 0.,
"Exponential denominator must be nonzero.");
denom_ = denom;
}

//! Get the denominator of the exponential
double get_denominator() { return denom_; }

IMP_OBJECT_METHODS(WeightedSumOfExponential);

private:
UnaryFunctions funcs_;
Floats weights_;
Float denom_;
};

IMPCORE_END_NAMESPACE

#endif /* IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H */
2 changes: 2 additions & 0 deletions modules/core/pyext/swig.i-in
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ IMP_SWIG_OBJECT( IMP::core, HarmonicSphereDistancePairScore, HarmonicSphereDista
IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDistancePairScore, HarmonicUpperBoundSphereDistancePairScores);
IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDiameterPairScore, HarmonicUpperBoundSphereDiameterPairScores);
IMP_SWIG_OBJECT( IMP::core, WeightedSum, WeightedSums);
IMP_SWIG_OBJECT( IMP::core, WeightedSumOfExponential, WeightedSumOfExponentials);
IMP_SWIG_OBJECT( IMP::core, IncrementalScoringFunction, IncrementalScoringFunctions);
IMP_SWIG_OBJECT( IMP::core, KClosePairsPairScore, KClosePairsPairScores);
IMP_SWIG_OBJECT( IMP::core, LeavesRefiner, LeavesRefiners);
Expand Down Expand Up @@ -212,6 +213,7 @@ IMP_SWIG_OBJECT(IMP::core, MultipleBinormalRestraint, MultipleBinormalRestraints
%include "IMP/core/HarmonicLowerBound.h"
%include "IMP/core/HarmonicUpperBound.h"
%include "IMP/core/WeightedSum.h"
%include "IMP/core/WeightedSumOfExponential.h"
%include "IMP/core/PeriodicOptimizerState.h"
%include "IMP/core/MSConnectivityRestraint.h"
%inline %{
Expand Down
88 changes: 88 additions & 0 deletions modules/core/test/test_weighted_sum_of_exponential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import math
import IMP
import IMP.test
import IMP.core


def _sum_of_exponent(fs, weights, x, d=1.):
exp_sum = 0.
for w, f in zip(weights, fs):
exp_sum += w * math.exp(-f.evaluate(x) / d)
return -math.log(exp_sum) * d


def _derv_sum_of_exponent(fs, weights, x, d=1.):
derv_num = 0.
exp_sum = 0.
for w, f in zip(weights, fs):
val, derv = f.evaluate_with_derivative(x)
exp_sum += w * math.exp(-val / d)
derv_num += w * math.exp(-val / d) * derv
return derv_num / exp_sum


class Tests(IMP.test.TestCase):

def test_values(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7])
for i in range(-10, 10):
i = float(i)
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i)
self.assertAlmostEqual(sf.evaluate(i), exp_score, delta=1e-6)
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i)
exp_deriv = _derv_sum_of_exponent([f1, f2], [.3, .7], i)
score, deriv = sf.evaluate_with_derivative(i)
self.assertAlmostEqual(score, exp_score, delta=1e-6)
self.assertAlmostEqual(deriv, exp_deriv, delta=1e-6)

sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7], 20.)
for i in range(-10, 10):
i = float(i)
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i, d=20.)
self.assertAlmostEqual(sf.evaluate(i), exp_score, delta=1e-4)
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i, d=20.)
exp_deriv = _derv_sum_of_exponent([f1, f2], [.3, .7], i, d=20.)
score, deriv = sf.evaluate_with_derivative(i)
self.assertAlmostEqual(score, exp_score, delta=1e-4)
self.assertAlmostEqual(deriv, exp_deriv, delta=1e-4)

def test_update_functions(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 2.)
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.5, .5])
self.assertAlmostEqual(sf.evaluate(0), .674997, delta=1e-6)
f2.set_k(1.)
self.assertAlmostEqual(sf.evaluate(0), .566219, delta=1e-6)

def test_accessors(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7], 2.)
self.assertAlmostEqual(sf.get_weight(0), .3)
self.assertAlmostEqual(sf.get_weight(1), .7)
self.assertAlmostEqual(sf.get_weights()[0], .3)
sf.set_weights([.4, .6])
self.assertAlmostEqual(sf.get_weight(0), .4)
self.assertAlmostEqual(sf.get_weight(1), .6)
self.assertAlmostEqual(sf.get_denominator(), 2.)
sf.set_denominator(3.)
self.assertAlmostEqual(sf.get_denominator(), 3.)

def test_errors(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential,
[f1], [1.])
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential,
[f1, f2], [1.])
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential,
[f1, f2], [1.], .0)
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7])
self.assertRaisesUsageException(sf.set_weights, [1.])
self.assertRaisesUsageException(sf.set_denominator, 0.)


if __name__ == '__main__':
IMP.test.main()

0 comments on commit b47d68e

Please sign in to comment.