Skip to content

Commit

Permalink
Added tests for matrix inversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdwatters committed Jul 11, 2024
1 parent 5887265 commit 3a7341d
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/flag4j/dense/CMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ public CMatrix(CNumber[][] entries) {
}



/**
* Creates a complex dense matrix whose entries are specified by a double array.
* @param entries Entries of the real dense matrix.
Expand Down
54 changes: 45 additions & 9 deletions src/main/java/org/flag4j/linalg/Invert.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.flag4j.dense.CMatrix;
import org.flag4j.dense.Matrix;
import org.flag4j.linalg.decompositions.chol.Cholesky;
import org.flag4j.linalg.decompositions.chol.ComplexCholesky;
import org.flag4j.linalg.decompositions.chol.RealCholesky;
import org.flag4j.linalg.decompositions.lu.ComplexLU;
import org.flag4j.linalg.decompositions.lu.LU;
Expand All @@ -43,6 +44,8 @@
import org.flag4j.util.ParameterChecks;
import org.flag4j.util.exceptions.SingularMatrixException;

import static org.flag4j.util.Flag4jConstants.EPS_F64;

/**
* This class provides methods for computing the inverse of a matrix. Specialized methods are provided for inverting triangular,
* diagonal, and symmetric positive definite matrices.
Expand Down Expand Up @@ -143,21 +146,19 @@ public static Matrix invTriL(Matrix src) {
*/
public static Matrix invDiag(Matrix src) {
ParameterChecks.assertSquareMatrix(src.shape);

Matrix inverse = new Matrix(src.shape);

double value;
int step = src.numCols+1;
double rank_condition = Math.ulp(1d);
double det = 1;

for(int i=0; i<src.numRows; i+=step) {
for(int i=0; i<src.entries.length; i+=step) {
value = src.entries[i];
det *= value;
inverse.entries[i] = 1.0/value;
}

if(Math.abs(det) <= rank_condition*Math.max(src.numRows, src.numCols)) {
if(Math.abs(det) <= EPS_F64*Math.max(src.numRows, src.numCols)) {
throw new SingularMatrixException("Could not invert.");
}

Expand Down Expand Up @@ -206,16 +207,15 @@ public static CMatrix invDiag(CMatrix src) {

CNumber value;
int step = src.numCols+1;
double rank_condition = Math.ulp(1.0d);
CNumber det = CNumber.one();

for(int i=0; i<src.numRows; i+=step) {
for(int i=0; i<src.entries.length; i+=step) {
value = src.entries[i];
det.multEq(value);
inverse.entries[i] = value.multInv();
}

if(det.mag() <= rank_condition*Math.max(src.numRows, src.numCols)) {
if(det.mag() <= EPS_F64*Math.max(src.numRows, src.numCols)) {
throw new SingularMatrixException("Could not invert.");
}

Expand Down Expand Up @@ -249,14 +249,50 @@ public static Matrix invSymPosDef(Matrix src) {
public static Matrix invSymPosDef(Matrix src, boolean checkPosDef) {
Cholesky<Matrix> chol = new RealCholesky(checkPosDef).decompose(src);
RealBackSolver backSolver = new RealBackSolver();
RealForwardSolver forwardSolver = new RealForwardSolver(true);
RealForwardSolver forwardSolver = new RealForwardSolver();

// Compute the inverse of unit lower triangular matrix L.
// Compute the inverse of lower triangular matrix L.
Matrix Linv = forwardSolver.solveIdentity(chol.getL());

return backSolver.solveLower(chol.getLH(), Linv); // Compute inverse of src.
}


/**
* Inverts a hermation positive definite matrix.
* @param src Positive definite matrix. It will <i>not</i> be verified if {@code src} is actually hermation positive definite.
* @return The inverse of the {@code src} matrix.
* @throws IllegalArgumentException If the matrix is not square.
* @throws SingularMatrixException If the {@code src} matrix is singular.
* @see #invSymPosDef(Matrix, boolean)
*/
public static CMatrix invHermPosDef(CMatrix src) {
return invHermPosDef(src, false);
}


/**
* Inverts a hermation positive definite matrix.
* @param src Positive definite matrix.
* @param checkPosDef Flag indicating if a check should be made to see if {@code src} is actually hermation
* positive definite. <b>WARNING</b>: Checking if the matrix is positive definite can be very computationally
* expensive.
* @return The inverse of the {@code src} matrix.
* @throws IllegalArgumentException If the matrix is not square.
* @throws SingularMatrixException If the {@code src} matrix is singular.
*/
public static CMatrix invHermPosDef(CMatrix src, boolean checkPosDef) {
Cholesky<CMatrix> chol = new ComplexCholesky(checkPosDef).decompose(src);
ComplexBackSolver backSolver = new ComplexBackSolver();
ComplexForwardSolver forwardSolver = new ComplexForwardSolver();

// Compute the inverse of lower triangular matrix L.
CMatrix Linv = forwardSolver.solveIdentity(chol.getL());

return backSolver.solveLower(chol.getLH(), Linv); // Compute inverse of src.
}


// ------------------------------------------- Pseudo-inverses below -------------------------------------------

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ public RealCholesky(boolean enforceSymmetric) {
@Override
public RealCholesky decompose(Matrix src) {
if(enforceHermation && src.isSymmetric()) {
throw new IllegalArgumentException("Matrix must be symmetric positive-definite.");
throw new LinearAlgebraException("Matrix must be symmetric positive-definite.");
} else {
ParameterChecks.assertSquareMatrix(src.shape);
}

L = new Matrix(src.numRows);
double posDefTolerance = Math.max(L.numRows* Flag4jConstants.EPS_F64, DEFAULT_POS_DEF_TOLERANCE);
double posDefTolerance = Math.max(L.numRows*Flag4jConstants.EPS_F64, DEFAULT_POS_DEF_TOLERANCE);
double sum;

int lIndex1;
Expand Down
109 changes: 109 additions & 0 deletions src/test/java/org/flag4j/linalg/CMatrixInvertTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package org.flag4j.linalg;

import org.flag4j.complex_numbers.CNumber;
import org.flag4j.dense.CMatrix;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class CMatrixInvertTests {

static CMatrix A;
static CMatrix exp;
static CNumber[][] entries;
static CNumber[][] expEntries;


@Test
void invTriUTests() {
// --------------------- Sub-case 1 ---------------------
entries = new CNumber[][]{
{new CNumber("0.72964+0.77161i"), new CNumber("0.04017+0.35861i")},
{new CNumber("0.0"), new CNumber("0.02885+0.61375i")}};
A = new CMatrix(entries);
expEntries = new CNumber[][]{
{new CNumber(0.6469836227593948, -0.6841991025127141), new CNumber(-0.35324900268232806, 0.4255132196349395)},
{new CNumber(0.0, 0.0), new CNumber(0.07641951197016296, -1.6257357182560666)}};
exp = new CMatrix(expEntries);

assertEquals(exp, Invert.invTriU(A));

// --------------------- Sub-case 2 ---------------------
entries = new CNumber[][]{
{new CNumber("0.88984+0.74576i"), new CNumber("0.35899+0.67567i"), new CNumber("0.67057+0.03486i"), new CNumber("0.27239+0.14667i")},
{new CNumber("0.0"), new CNumber("0.29455+0.74263i"), new CNumber("0.99331+0.19522i"), new CNumber("0.89592+0.76741i")},
{new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.70384+0.74177i"), new CNumber("0.59081+0.73518i")},
{new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.97421+0.61448i")}};
A = new CMatrix(entries);
expEntries = new CNumber[][]{
{new CNumber(0.6601318170773559, -0.5532454192929167), new CNumber(-0.569755474276714, 0.5964889359681658), new CNumber(0.0538261672936495, -0.24610680231262014), new CNumber(0.39210142039761287, -0.18312414720644973)},
{new CNumber(0.0, 0.0), new CNumber(0.46148975736667824, -1.1635244899447168), new CNumber(0.29451765428159893, 1.2036600935687118), new CNumber(-0.5481358478041402, 0.10001294703720921)},
{new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.6731359287847971, -0.709411283664894), new CNumber(-0.7101078401778858, 0.3701442953886289)},
{new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.7343268609204316, -0.4631744382611417)}};
exp = new CMatrix(expEntries);

assertEquals(exp, Invert.invTriU(A));
}


@Test
void invTriLTests() {
// --------------------- Sub-case 1 ---------------------
entries = new CNumber[][]{
{new CNumber("0.72964+0.77161i"), new CNumber("0.04017+0.35861i")},
{new CNumber("0.0"), new CNumber("0.02885+0.61375i")}};
A = new CMatrix(entries).T();
expEntries = new CNumber[][]{
{new CNumber(0.6469836227593948, -0.6841991025127141), new CNumber(-0.35324900268232806, 0.4255132196349395)},
{new CNumber(0.0, 0.0), new CNumber(0.07641951197016296, -1.6257357182560666)}};
exp = new CMatrix(expEntries).T();

assertEquals(exp, Invert.invTriL(A));

// --------------------- Sub-case 2 ---------------------
entries = new CNumber[][]{
{new CNumber("0.88984+0.74576i"), new CNumber("0.35899+0.67567i"), new CNumber("0.67057+0.03486i"), new CNumber("0.27239+0.14667i")},
{new CNumber("0.0"), new CNumber("0.29455+0.74263i"), new CNumber("0.99331+0.19522i"), new CNumber("0.89592+0.76741i")},
{new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.70384+0.74177i"), new CNumber("0.59081+0.73518i")},
{new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.0"), new CNumber("0.97421+0.61448i")}};
A = new CMatrix(entries).T();
expEntries = new CNumber[][]{
{new CNumber(0.6601318170773559, -0.5532454192929167), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)},
{new CNumber(-0.5697554742767139, 0.5964889359681657), new CNumber(0.46148975736667824, -1.1635244899447168), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)},
{new CNumber(0.05382616729364944, -0.2461068023126199), new CNumber(0.29451765428159893, 1.203660093568712), new CNumber(0.6731359287847971, -0.709411283664894), new CNumber(0.0, 0.0)},
{new CNumber(0.39210142039761287, -0.18312414720644984), new CNumber(-0.5481358478041398, 0.10001294703720916), new CNumber(-0.7101078401778858, 0.3701442953886289), new CNumber(0.7343268609204316, -0.4631744382611417)}};
exp = new CMatrix(expEntries);

assertEquals(exp, Invert.invTriL(A));
}


@Test
void invDiagTests() {
// --------------------- Sub-case 1 ---------------------
entries = new CNumber[][]{
{new CNumber(-14.43, 95.1), new CNumber()},
{new CNumber(), new CNumber(0, 1.45)}};
A = new CMatrix(entries);
expEntries = new CNumber[][]{
{new CNumber(-14.43, 95.1).multInv(), new CNumber()},
{new CNumber(), new CNumber(0, 1.45).multInv()}};
exp = new CMatrix(expEntries);

assertEquals(exp, Invert.invDiag(A));

// --------------------- Sub-case 2 ---------------------
entries = new CNumber[][]{
{new CNumber(-14.43, 95.1), new CNumber(), new CNumber()},
{new CNumber(), new CNumber(0, 1.45), new CNumber()},
{new CNumber(), new CNumber(), new CNumber(234.156)}};
A = new CMatrix(entries);
expEntries = new CNumber[][]{
{new CNumber(-14.43, 95.1).multInv(), new CNumber(), new CNumber()},
{new CNumber(), new CNumber(0, 1.45).multInv(), new CNumber()},
{new CNumber(), new CNumber(), new CNumber(234.156).multInv()}};
exp = new CMatrix(expEntries);

assertEquals(exp, Invert.invDiag(A));
}
}
Loading

0 comments on commit 3a7341d

Please sign in to comment.