-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add weighted sum of exponential unary function
This unary function permits the calculation of the weighted sum of negative exponential of multiple unary functions, useful when the output of the functions is a score, -log(p). Added to generalize `IMP.pmi.restraints.basic.BiStableDistanceRestraint`. Relates salilab/pmi#211
- Loading branch information
Showing
3 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/** | ||
* \file IMP/core/WeightedSumOfExponential.h | ||
* \brief Negative logarithm of weighted sum of negative exponential of unary functions. | ||
* | ||
* Copyright 2007-2016 IMP Inventors. All rights reserved. | ||
*/ | ||
|
||
#ifndef IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H | ||
#define IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H | ||
|
||
#include <IMP/core/core_config.h> | ||
#include <IMP/UnaryFunction.h> | ||
#include <cmath> | ||
|
||
IMPCORE_BEGIN_NAMESPACE | ||
|
||
//! Negative logarithm of weighted sum of negative exponential of unary functions. | ||
/** Given unary functions \f$ f_i(x) \f$ and weights \f$ w_i \f$, compute the function | ||
\f[ F(x) = -d \log\left[ \sum_i{ \left( w_i e^{-f_i(x) / d} \right) } \right] ,\f] | ||
where \f$ d \f$ is the denominator of the exponential. | ||
This is used when the functions \f$ f_i(x) \f$ are scores (\f$ -\log(p) \f$), and the | ||
desired score \f$ F(x) \f$ is the score resulting from the weighted convolution of | ||
their probability distributions. | ||
\see WeightedSum | ||
*/ | ||
class WeightedSumOfExponential : public UnaryFunction { | ||
public: | ||
/** Create with the functions and their respective weights */ | ||
WeightedSumOfExponential(UnaryFunctions funcs, | ||
Floats weights, | ||
Float denom = 1.0) : funcs_(funcs), weights_(weights), denom_(denom) { | ||
IMP_USAGE_CHECK(weights.size() == funcs.size(), | ||
"Number of functions and weights must match."); | ||
IMP_USAGE_CHECK(funcs.size() > 1, | ||
"More than one function and weight must be provided."); | ||
IMP_USAGE_CHECK(denom != 0., | ||
"Exponential denominator must be nonzero."); | ||
|
||
} | ||
|
||
virtual DerivativePair evaluate_with_derivative(double feature) const { | ||
double exp_sum = 0; | ||
double derv_num = 0; | ||
double weight_exp; | ||
DerivativePair fout; | ||
for (unsigned int i = 0; i < funcs_.size(); ++i) { | ||
fout = funcs_[i]->evaluate_with_derivative(feature); | ||
weight_exp = weights_[i] * std::exp(-fout.first / denom_); | ||
exp_sum += weight_exp; | ||
derv_num += weight_exp * fout.second; | ||
} | ||
return DerivativePair(-std::log(exp_sum) * denom_, derv_num / exp_sum); | ||
} | ||
|
||
virtual double evaluate(double feature) const { | ||
double exp_sum = 0; | ||
for (unsigned int i = 0; i < funcs_.size(); ++i) { | ||
exp_sum += weights_[i] * std::exp(-funcs_[i]->evaluate(feature) / denom_); | ||
} | ||
return -std::log(exp_sum) * denom_; | ||
} | ||
|
||
//! Get the number of functions | ||
unsigned int get_function_number() { | ||
return funcs_.size(); | ||
} | ||
|
||
//! Set the function weights | ||
void set_weights(Floats weights) { | ||
IMP_USAGE_CHECK(weights.size() == get_function_number(), | ||
"Number of weights and functions must match."); | ||
weights_ = weights; | ||
} | ||
|
||
//! Get the function weights | ||
Floats get_weights() { return weights_; } | ||
|
||
//! Get function weight at index | ||
double get_weight(unsigned int i) const { | ||
IMP_USAGE_CHECK(i < weights_.size(), "Invalid weight index"); | ||
return weights_[i]; | ||
} | ||
|
||
//! Get function at index | ||
UnaryFunction* get_function(unsigned int i) { | ||
IMP_USAGE_CHECK(i < get_function_number(), "Invalid function index"); | ||
return funcs_[i]; | ||
} | ||
|
||
//! Set the denominator of the exponential | ||
void set_denominator(double denom) { | ||
IMP_USAGE_CHECK(denom != 0., | ||
"Exponential denominator must be nonzero."); | ||
denom_ = denom; | ||
} | ||
|
||
//! Get the denominator of the exponential | ||
double get_denominator() { return denom_; } | ||
|
||
IMP_OBJECT_METHODS(WeightedSumOfExponential); | ||
|
||
private: | ||
UnaryFunctions funcs_; | ||
Floats weights_; | ||
Float denom_; | ||
}; | ||
|
||
IMPCORE_END_NAMESPACE | ||
|
||
#endif /* IMPCORE_WEIGHTED_SUM_OF_EXPONENTIAL_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import math | ||
import IMP | ||
import IMP.test | ||
import IMP.core | ||
|
||
|
||
def _sum_of_exponent(fs, weights, x, d=1.): | ||
exp_sum = 0. | ||
for w, f in zip(weights, fs): | ||
exp_sum += w * math.exp(-f.evaluate(x) / d) | ||
return -math.log(exp_sum) * d | ||
|
||
|
||
def _derv_sum_of_exponent(fs, weights, x, d=1.): | ||
derv_num = 0. | ||
exp_sum = 0. | ||
for w, f in zip(weights, fs): | ||
val, derv = f.evaluate_with_derivative(x) | ||
exp_sum += w * math.exp(-val / d) | ||
derv_num += w * math.exp(-val / d) * derv | ||
return derv_num / exp_sum | ||
|
||
|
||
class Tests(IMP.test.TestCase): | ||
|
||
def test_values(self): | ||
f1 = IMP.core.Harmonic(0., 1.) | ||
f2 = IMP.core.Harmonic(2., 3.) | ||
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7]) | ||
for i in range(-10, 10): | ||
i = float(i) | ||
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i) | ||
self.assertAlmostEqual(sf.evaluate(i), exp_score, delta=1e-6) | ||
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i) | ||
exp_deriv = _derv_sum_of_exponent([f1, f2], [.3, .7], i) | ||
score, deriv = sf.evaluate_with_derivative(i) | ||
self.assertAlmostEqual(score, exp_score, delta=1e-6) | ||
self.assertAlmostEqual(deriv, exp_deriv, delta=1e-6) | ||
|
||
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7], 20.) | ||
for i in range(-10, 10): | ||
i = float(i) | ||
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i, d=20.) | ||
self.assertAlmostEqual(sf.evaluate(i), exp_score, delta=1e-4) | ||
exp_score = _sum_of_exponent([f1, f2], [.3, .7], i, d=20.) | ||
exp_deriv = _derv_sum_of_exponent([f1, f2], [.3, .7], i, d=20.) | ||
score, deriv = sf.evaluate_with_derivative(i) | ||
self.assertAlmostEqual(score, exp_score, delta=1e-4) | ||
self.assertAlmostEqual(deriv, exp_deriv, delta=1e-4) | ||
|
||
def test_update_functions(self): | ||
f1 = IMP.core.Harmonic(0., 1.) | ||
f2 = IMP.core.Harmonic(2., 2.) | ||
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.5, .5]) | ||
self.assertAlmostEqual(sf.evaluate(0), .674997, delta=1e-6) | ||
f2.set_k(1.) | ||
self.assertAlmostEqual(sf.evaluate(0), .566219, delta=1e-6) | ||
|
||
def test_accessors(self): | ||
f1 = IMP.core.Harmonic(0., 1.) | ||
f2 = IMP.core.Harmonic(2., 3.) | ||
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7], 2.) | ||
self.assertAlmostEqual(sf.get_weight(0), .3) | ||
self.assertAlmostEqual(sf.get_weight(1), .7) | ||
self.assertAlmostEqual(sf.get_weights()[0], .3) | ||
sf.set_weights([.4, .6]) | ||
self.assertAlmostEqual(sf.get_weight(0), .4) | ||
self.assertAlmostEqual(sf.get_weight(1), .6) | ||
self.assertAlmostEqual(sf.get_denominator(), 2.) | ||
sf.set_denominator(3.) | ||
self.assertAlmostEqual(sf.get_denominator(), 3.) | ||
|
||
def test_errors(self): | ||
f1 = IMP.core.Harmonic(0., 1.) | ||
f2 = IMP.core.Harmonic(2., 3.) | ||
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential, | ||
[f1], [1.]) | ||
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential, | ||
[f1, f2], [1.]) | ||
self.assertRaisesUsageException(IMP.core.WeightedSumOfExponential, | ||
[f1, f2], [1.], .0) | ||
sf = IMP.core.WeightedSumOfExponential([f1, f2], [.3, .7]) | ||
self.assertRaisesUsageException(sf.set_weights, [1.]) | ||
self.assertRaisesUsageException(sf.set_denominator, 0.) | ||
|
||
|
||
if __name__ == '__main__': | ||
IMP.test.main() |