diff --git a/modules/core/include/WeightedSum.h b/modules/core/include/WeightedSum.h new file mode 100644 index 0000000000..2aba2c679b --- /dev/null +++ b/modules/core/include/WeightedSum.h @@ -0,0 +1,84 @@ +/** + * \file IMP/core/WeightedSum.h \brief Weighted sum of unary functions. + * + * Copyright 2007-2016 IMP Inventors. All rights reserved. + */ + +#ifndef IMPCORE_WEIGHTED_SUM_H +#define IMPCORE_WEIGHTED_SUM_H + +#include +#include + +IMPCORE_BEGIN_NAMESPACE + +//! Weighted sum of unary functions. +/** A unary function that computes the weighted sum of multiple functions. + */ +class WeightedSum : public UnaryFunction { + public: + /** Create with the functions and their respective weights */ + WeightedSum(UnaryFunctions funcs, Floats weights) : funcs_(funcs), weights_(weights) { + IMP_USAGE_CHECK(weights.size() == funcs.size(), + "Number of functions and weights must match."); + IMP_USAGE_CHECK(funcs.size() > 1, + "More than one function and weight must be provided."); + } + + virtual DerivativePair evaluate_with_derivative(double feature) const { + double eval = 0; + double derv = 0; + DerivativePair fout; + for (unsigned int i = 0; i < funcs_.size(); ++i) { + fout = funcs_[i]->evaluate_with_derivative(feature); + eval += weights_[i] * fout.first; + derv += weights_[i] * fout.second; + } + return DerivativePair(eval, derv); + } + + virtual double evaluate(double feature) const { + double ret = 0; + for (unsigned int i = 0; i < funcs_.size(); ++i) { + ret += weights_[i] * funcs_[i]->evaluate(feature); + } + return ret; + } + + //! Get the number of functions + unsigned int get_function_number() { + return funcs_.size(); + } + + //! Set the function weights + void set_weights(Floats weights) { + IMP_USAGE_CHECK(weights.size() == get_function_number(), + "Number of weights and functions must match."); + weights_ = weights; + } + + //! Get the function weights + Floats get_weights() { return weights_; } + + //! Get function weight at index + double get_weight(unsigned int i) const { + IMP_USAGE_CHECK(i < weights_.size(), "Invalid weight index"); + return weights_[i]; + } + + //! Get function at index + UnaryFunction* get_function(unsigned int i) { + IMP_USAGE_CHECK(i < get_function_number(), "Invalid function index"); + return funcs_[i]; + } + + IMP_OBJECT_METHODS(WeightedSum); + + private: + UnaryFunctions funcs_; + Floats weights_; +}; + +IMPCORE_END_NAMESPACE + +#endif /* IMPCORE_WEIGHTED_SUM_H */ diff --git a/modules/core/pyext/swig.i-in b/modules/core/pyext/swig.i-in index a7f6a4884f..41b99776db 100644 --- a/modules/core/pyext/swig.i-in +++ b/modules/core/pyext/swig.i-in @@ -43,6 +43,7 @@ IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBound, HarmonicUpperBounds); IMP_SWIG_OBJECT( IMP::core, HarmonicSphereDistancePairScore, HarmonicSphereDistancePairScores); IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDistancePairScore, HarmonicUpperBoundSphereDistancePairScores); IMP_SWIG_OBJECT( IMP::core, HarmonicUpperBoundSphereDiameterPairScore, HarmonicUpperBoundSphereDiameterPairScores); +IMP_SWIG_OBJECT( IMP::core, WeightedSum, WeightedSums); IMP_SWIG_OBJECT( IMP::core, IncrementalScoringFunction, IncrementalScoringFunctions); IMP_SWIG_OBJECT( IMP::core, KClosePairsPairScore, KClosePairsPairScores); IMP_SWIG_OBJECT( IMP::core, LeavesRefiner, LeavesRefiners); @@ -210,6 +211,7 @@ IMP_SWIG_OBJECT(IMP::core, MultipleBinormalRestraint, MultipleBinormalRestraints %include "IMP/core/HarmonicWell.h" %include "IMP/core/HarmonicLowerBound.h" %include "IMP/core/HarmonicUpperBound.h" +%include "IMP/core/WeightedSum.h" %include "IMP/core/PeriodicOptimizerState.h" %include "IMP/core/MSConnectivityRestraint.h" %inline %{ diff --git a/modules/core/test/test_weighted_sum.py b/modules/core/test/test_weighted_sum.py new file mode 100644 index 0000000000..6d8e938814 --- /dev/null +++ b/modules/core/test/test_weighted_sum.py @@ -0,0 +1,50 @@ +import IMP +import IMP.test +import IMP.core + + +class Tests(IMP.test.TestCase): + + def test_values(self): + f1 = IMP.core.Harmonic(0., 1.) + f2 = IMP.core.Harmonic(2., 3.) + sf = IMP.core.WeightedSum([f1, f2], [.3, .7]) + for i in range(-10, 10): + i = float(i) + self.assertAlmostEqual( + sf.evaluate(i), .3 * f1.evaluate(i) + .7 * f2.evaluate(i)) + score, deriv = sf.evaluate_with_derivative(i) + score_sum = 0 + deriv_sum = 0 + for w, f in zip([.3, .7], [f1, f2]): + s, d = f.evaluate_with_derivative(i) + score_sum += w * s + deriv_sum += w * d + self.assertAlmostEqual(score, score_sum, delta=1e-4) + self.assertAlmostEqual(deriv, deriv_sum, delta=1e-4) + + def test_accessors(self): + f1 = IMP.core.Harmonic(0., 1.) + f2 = IMP.core.Harmonic(2., 3.) + sf = IMP.core.WeightedSum([f1, f2], [.3, .7]) + self.assertAlmostEqual(sf.get_weight(0), .3) + self.assertAlmostEqual(sf.get_weight(1), .7) + self.assertAlmostEqual(sf.get_weights()[0], .3) + sf.set_weights([.4, .6]) + self.assertAlmostEqual(sf.get_weight(0), .4) + self.assertAlmostEqual(sf.get_weight(1), .6) + + def test_errors(self): + f1 = IMP.core.Harmonic(0., 1.) + f2 = IMP.core.Harmonic(2., 3.) + self.assertRaisesUsageException(IMP.core.WeightedSum, + [f1], [1.]) + self.assertRaisesUsageException(IMP.core.WeightedSum, + [f1, f2], [1.]) + sf = IMP.core.WeightedSum([f1, f2], [.3, .7]) + self.assertRaisesUsageException(sf.set_weights, + [1.]) + + +if __name__ == '__main__': + IMP.test.main()