Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
jschueller committed May 6, 2024
1 parent ecffbf1 commit 5ee6d1c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 56 deletions.
85 changes: 39 additions & 46 deletions lib/src/SlicedInverseRegression.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ SlicedInverseRegression * SlicedInverseRegression::clone() const
void SlicedInverseRegression::run()
{
const UnsignedInteger size = inputSample_.getSize();

const UnsignedInteger inputDimension = inputSample_.getDimension();
const Indices supervisionIndices = outputSample_.argsort();
Collection<Indices> list_chunk;
UnsignedInteger offset = 0;
Expand All @@ -74,70 +74,63 @@ void SlicedInverseRegression::run()
chunk_population.add(localSize);
offset += localSize;
}
const Point center(inputSample_.computeMean());
const Point mean(inputSample_.computeMean());
Sample X_centered(inputSample_);
X_centered -= center;
CovarianceMatrix input_covariance(X_centered.computeCovariance());
input_covariance.getImplementation()->symmetrize();
Matrix u;
Matrix vT;
X_centered -= mean;
Matrix Q(size, inputDimension);
for (UnsignedInteger i = 0; i < size; ++ i)
for(UnsignedInteger j = 0; j < inputDimension; ++ j)
Q(i, j) = inputSample_(i, j);
Matrix R;
#if OPENTURNS_VERSION >= 102300
const Point singularValues = input_covariance.computeSVDInPlace(u, vT, false); // fullSVD
Q.computeQRInPlace(R);
#else
const Point singularValues = input_covariance.computeSVD(u, vT, false, false); // fullSVD, keepIntact
Q.computeQR(R, true);
#endif
Point s1(singularValues.getSize());
for(UnsignedInteger i = 0; i < s1.getSize(); ++ i)
s1[i] = 1.0 / singularValues[i];
Matrix us1(u);
for(UnsignedInteger j = 0; j < s1.getSize(); ++ j)
for(UnsignedInteger i = 0; i < s1.getSize(); ++ i)
us1(i, j) *= s1[j];
Matrix inverseCovariance(us1 * vT);
const UnsignedInteger inputDimension = inputSample_.getDimension();
Matrix weighted_covariance(inputDimension, inputDimension);
Sample Z(size, inputDimension);
for (UnsignedInteger i = 0; i < size; ++ i)
for(UnsignedInteger j = 0; j < inputDimension; ++ j)
Z(i, j) = Q(i, j) * std::sqrt(1.0 * size);
Matrix zMeans(sliceNumber_, inputDimension);
for(UnsignedInteger j = 0; j < sliceNumber_; ++ j)
{
const Point meanSlice(inputSample_.select(list_chunk[j]).computeMean());
Matrix slice_moment(inputDimension, 1);
const Point zMean(Z.select(list_chunk[j]).computeMean() * std::sqrt(1.0 * chunk_population[j]));
for(UnsignedInteger i = 0; i < inputDimension; ++ i)
slice_moment(i, 0) = meanSlice[i];
Matrix slice_covariance(slice_moment * slice_moment.transpose());
const Scalar w = chunk_population[j] * 1.0 / size;
weighted_covariance = weighted_covariance + slice_covariance * w;
zMeans(j, i) = zMean[i];
}
SquareComplexMatrix eigen_vector;
SquareMatrix icwc((inverseCovariance * weighted_covariance).getImplementation());
SymmetricMatrix M((zMeans.transpose() * zMeans / size).getImplementation());
SquareMatrix eigenVectors;
#if OPENTURNS_VERSION >= 102300
const SquareMatrix::ComplexCollection eigen_value = icwc.computeEVInPlace(eigen_vector);
const Point eigenValues = M.computeEVInPlace(eigenVectors);
#else
const SquareMatrix::ComplexCollection eigen_value = icwc.computeEV(eigen_vector, false); // keepIntact
const Point eigenValues = M.computeEV(eigenVectors, false); // keepIntact
#endif
Sample evs(eigen_value.getSize(), 1);
Scalar max_imag = 0.0;
for(UnsignedInteger i = 0; i < eigen_value.getSize(); ++ i)
SquareMatrix eigenVectorsRev(inputDimension, inputDimension);
Point eigenValuesRev(inputDimension);
for(UnsignedInteger j = 0; j < inputDimension; ++ j)
{
evs(i, 0) = eigen_value[i].real();
max_imag = std::max(max_imag, std::abs(eigen_value[i].imag()));
eigenValuesRev[j] = eigenValues[inputDimension - 1 - j];
for(UnsignedInteger i = 0; i < inputDimension; ++ i)
eigenVectorsRev(i, j) = eigenVectors(i, inputDimension - 1 - j);
}
TriangularMatrix R2((R * std::sqrt(1.0 * size)).getImplementation(), false);
Matrix directions = R2.solveLinearSystemInPlace(eigenVectorsRev);

if (max_imag > std::sqrt(SpecFunc::Precision))
throw NotDefinedException(HERE) << "complex eigen-values during SIR";
else if (max_imag > 0.0)
LOGWARN("negligible complex eigen-values during SIR");

// Matrix eigen_vector2(eigen_vector.real());
const Indices order(evs.argsort(false));
// prune directions and associated eigenvalues
directions.getImplementation()->resize(inputDimension, modesNumber_);
eigenValuesRev.resize(modesNumber_);

Point singular_values(modesNumber_);
Matrix basis(inputDimension, modesNumber_);
// normalize directions
for(UnsignedInteger j = 0; j < modesNumber_; ++ j)
{
singular_values[j] = eigen_value[order[j]].real();
Scalar normJ = 0.0;
for(UnsignedInteger i = 0; i < inputDimension; ++ i)
normJ += directions(i, j) * directions(i, j);
normJ = 1.0 / std::sqrt(normJ);
for(UnsignedInteger i = 0; i < inputDimension; ++ i)
basis(i, j) = eigen_vector(i, order[j]).real();
directions(i, j) *= normJ;
}
result_ = SlicedInverseRegressionResult(basis, center);
result_ = SlicedInverseRegressionResult(directions, mean, eigenValuesRev);
}

SlicedInverseRegressionResult SlicedInverseRegression::getResult() const
Expand Down
27 changes: 20 additions & 7 deletions lib/src/SlicedInverseRegressionResult.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ SlicedInverseRegressionResult::SlicedInverseRegressionResult()
// Nothing to do
}

SlicedInverseRegressionResult::SlicedInverseRegressionResult(const OT::Matrix & basis,
const OT::Point & center)
SlicedInverseRegressionResult::SlicedInverseRegressionResult(const OT::Matrix & directions,
const OT::Point & center, const OT::Point & eigenValues)
: PersistentObject()
, basis_(basis)
, directions_(directions)
, center_(center)
, eigenValues_(eigenValues)
{
}

Expand All @@ -65,29 +66,41 @@ String SlicedInverseRegressionResult::__repr__() const

Function SlicedInverseRegressionResult::getTransformation() const
{
return LinearFunction(center_, Point(center_.getDimension()), basis_);
return LinearFunction(center_, Point(center_.getDimension()), directions_);
}

Function SlicedInverseRegressionResult::getInverseTransformation() const
{
Matrix inv(basis_.solveLinearSystem(IdentityMatrix(center_.getDimension())));
Matrix inv(directions_.solveLinearSystem(IdentityMatrix(center_.getDimension())));
return LinearFunction(-center_, Point(center_.getDimension()), inv);
}

Matrix SlicedInverseRegressionResult::getDirections() const
{
return directions_;
}

Point SlicedInverseRegressionResult::getEigenvalues() const
{
return eigenValues_;
}

/* Method save() stores the object through the StorageManager */
void SlicedInverseRegressionResult::save(Advocate & adv) const
{
PersistentObject::save(adv);
adv.saveAttribute( "basis_", basis_ );
adv.saveAttribute( "directions_", directions_ );
adv.saveAttribute( "center_", center_ );
adv.saveAttribute( "eigenValues_", eigenValues_ );
}

/* Method load() reloads the object from the StorageManager */
void SlicedInverseRegressionResult::load(Advocate & adv)
{
PersistentObject::load(adv);
adv.loadAttribute( "basis_", basis_ );
adv.loadAttribute( "directions_", directions_ );
adv.loadAttribute( "center_", center_ );
adv.loadAttribute( "eigenValues_", eigenValues_ );
}


Expand Down
11 changes: 8 additions & 3 deletions lib/src/otsliced/SlicedInverseRegressionResult.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,16 @@ public:
/** Default constructor */
SlicedInverseRegressionResult();

SlicedInverseRegressionResult(const OT::Matrix & basis,
const OT::Point & center);
SlicedInverseRegressionResult(const OT::Matrix & linear,
const OT::Point & center,
const OT::Point & eigenValues);

/** Virtual constructor method */
SlicedInverseRegressionResult * clone() const override;

OT::Matrix getDirections() const;
OT::Point getEigenvalues() const;

OT::Function getTransformation() const;
OT::Function getInverseTransformation() const;

Expand All @@ -63,8 +67,9 @@ public:
void load(OT::Advocate & adv) override;

private:
OT::Matrix basis_;
OT::Matrix directions_;
OT::Point center_;
OT::Point eigenValues_;

}; /* class SlicedInverseRegressionResult */

Expand Down
22 changes: 22 additions & 0 deletions python/src/SlicedInverseRegressionResult_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,25 @@ Returns
inverseTransformation : :py:class:`openturns.Function`
Inverse transformation function
"

// ---------------------------------------------------------------------

%feature("docstring") OTSLICED::SlicedInverseRegressionResult::getDirections
"Directions accessor.

Returns
-------
directions : :py:class:`openturns.Matrix`
Directions matrix
"

// ---------------------------------------------------------------------

%feature("docstring") OTSLICED::SlicedInverseRegressionResult::getEigenvalues
"Eigen values accessor.

Returns
-------
eigenvalues : :py:class:`openturns.Point`
Eigen values
"

0 comments on commit 5ee6d1c

Please sign in to comment.