Skip to content

Commit

Permalink
Add weighted sum unary function
Browse files Browse the repository at this point in the history
This unary function will permit, for example, multi-well harmonic
potentials. Added to generalize
`IMP.pmi.restraints.basic.BiStableDistanceRestraint`.
Relates salilab/pmi#211
  • Loading branch information
sethaxen committed Oct 13, 2016
1 parent 79ff0f7 commit 9e42be4
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
84 changes: 84 additions & 0 deletions modules/core/include/WeightedSum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/**
* \file IMP/core/WeightedSum.h \brief Weighted sum of unary functions.
*
* Copyright 2007-2016 IMP Inventors. All rights reserved.
*/

#ifndef IMPCORE_WEIGHTED_SUM_H
#define IMPCORE_WEIGHTED_SUM_H

#include <IMP/core/core_config.h>
#include <IMP/UnaryFunction.h>

IMPCORE_BEGIN_NAMESPACE

//! Weighted sum of unary functions.
/** A unary function that computes the weighted sum of multiple functions.
*/
class WeightedSum : public UnaryFunction {
public:
/** Create with the functions and their respective weights */
WeightedSum(UnaryFunctions funcs, Floats weights) : funcs_(funcs), weights_(weights) {
IMP_USAGE_CHECK(weights.size() == funcs.size(),
"Number of functions and weights must match.");
IMP_USAGE_CHECK(funcs.size() > 1,
"More than one function and weight must be provided.");
}

virtual DerivativePair evaluate_with_derivative(double feature) const {
double eval = 0;
double derv = 0;
DerivativePair fout;
for (unsigned int i = 0; i < funcs_.size(); ++i) {
fout = funcs_[i]->evaluate_with_derivative(feature);
eval += weights_[i] * fout.first;
derv += weights_[i] * fout.second;
}
return DerivativePair(eval, derv);
}

virtual double evaluate(double feature) const {
double ret = 0;
for (unsigned int i = 0; i < funcs_.size(); ++i) {
ret += weights_[i] * funcs_[i]->evaluate(feature);
}
return ret;
}

//! Get the number of functions
unsigned int get_function_number() {
return funcs_.size();
}

//! Set the function weights
void set_weights(Floats weights) {
IMP_USAGE_CHECK(weights.size() == get_function_number(),
"Number of weights and functions must match.");
weights_ = weights;
}

//! Get the function weights
Floats get_weights() { return weights_; }

//! Get function weight at index
double get_weight(unsigned int i) const {
IMP_USAGE_CHECK(i < weights_.size(), "Invalid weight index");
return weights_[i];
}

//! Get function at index
UnaryFunction* get_function(unsigned int i) {
IMP_USAGE_CHECK(i < get_function_number(), "Invalid function index");
return funcs_[i];
}

IMP_OBJECT_METHODS(WeightedSum);

private:
UnaryFunctions funcs_;
Floats weights_;
};

IMPCORE_END_NAMESPACE

#endif /* IMPCORE_WEIGHTED_SUM_H */
2 changes: 2 additions & 0 deletions modules/core/pyext/swig.i-in
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBound, HarmonicUpperBounds);
IMP_SWIG_OBJECT( IMP::core, HarmonicSphereDistancePairScore, HarmonicSphereDistancePairScores);
IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDistancePairScore, HarmonicUpperBoundSphereDistancePairScores);
IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDiameterPairScore, HarmonicUpperBoundSphereDiameterPairScores);
IMP_SWIG_OBJECT( IMP::core, WeightedSum, WeightedSums);
IMP_SWIG_OBJECT( IMP::core, IncrementalScoringFunction, IncrementalScoringFunctions);
IMP_SWIG_OBJECT( IMP::core, KClosePairsPairScore, KClosePairsPairScores);
IMP_SWIG_OBJECT( IMP::core, LeavesRefiner, LeavesRefiners);
Expand Down Expand Up @@ -210,6 +211,7 @@ IMP_SWIG_OBJECT(IMP::core, MultipleBinormalRestraint, MultipleBinormalRestraints
%include "IMP/core/HarmonicWell.h"
%include "IMP/core/HarmonicLowerBound.h"
%include "IMP/core/HarmonicUpperBound.h"
%include "IMP/core/WeightedSum.h"
%include "IMP/core/PeriodicOptimizerState.h"
%include "IMP/core/MSConnectivityRestraint.h"
%inline %{
Expand Down
50 changes: 50 additions & 0 deletions modules/core/test/test_weighted_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import IMP
import IMP.test
import IMP.core


class Tests(IMP.test.TestCase):

def test_values(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
sf = IMP.core.WeightedSum([f1, f2], [.3, .7])
for i in range(-10, 10):
i = float(i)
self.assertAlmostEqual(
sf.evaluate(i), .3 * f1.evaluate(i) + .7 * f2.evaluate(i))
score, deriv = sf.evaluate_with_derivative(i)
score_sum = 0
deriv_sum = 0
for w, f in zip([.3, .7], [f1, f2]):
s, d = f.evaluate_with_derivative(i)
score_sum += w * s
deriv_sum += w * d
self.assertAlmostEqual(score, score_sum, delta=1e-4)
self.assertAlmostEqual(deriv, deriv_sum, delta=1e-4)

def test_accessors(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
sf = IMP.core.WeightedSum([f1, f2], [.3, .7])
self.assertAlmostEqual(sf.get_weight(0), .3)
self.assertAlmostEqual(sf.get_weight(1), .7)
self.assertAlmostEqual(sf.get_weights()[0], .3)
sf.set_weights([.4, .6])
self.assertAlmostEqual(sf.get_weight(0), .4)
self.assertAlmostEqual(sf.get_weight(1), .6)

def test_errors(self):
f1 = IMP.core.Harmonic(0., 1.)
f2 = IMP.core.Harmonic(2., 3.)
self.assertRaisesUsageException(IMP.core.WeightedSum,
[f1], [1.])
self.assertRaisesUsageException(IMP.core.WeightedSum,
[f1, f2], [1.])
sf = IMP.core.WeightedSum([f1, f2], [.3, .7])
self.assertRaisesUsageException(sf.set_weights,
[1.])


if __name__ == '__main__':
IMP.test.main()

1 comment on commit 9e42be4

@sethaxen
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The UnaryFunctions should probably be passed by reference so that they can potentially be updated. I had compilation errors while trying to do this though.

Please sign in to comment.