diff --git a/src/main/java/org/flag4j/linalg/decompositions/hess/SymmHess.java b/src/main/java/org/flag4j/linalg/decompositions/hess/SymmHess.java
index 3d1e302ed..0e5870e00 100644
--- a/src/main/java/org/flag4j/linalg/decompositions/hess/SymmHess.java
+++ b/src/main/java/org/flag4j/linalg/decompositions/hess/SymmHess.java
@@ -26,9 +26,7 @@
import org.flag4j.dense.Matrix;
-import org.flag4j.linalg.decompositions.Decomposition;
import org.flag4j.linalg.transformations.Householder;
-import org.flag4j.util.Flag4jConstants;
import org.flag4j.util.ParameterChecks;
import org.flag4j.util.exceptions.LinearAlgebraException;
@@ -52,7 +50,7 @@
* [ 0 0 0 x x ]]
*
*/
-public class SymmHess implements Decomposition {
+public class SymmHess extends RealHess {
/**
* Flag indicating if an explicit check should be made that the matrix to be decomposed is symmetric.
@@ -60,59 +58,11 @@ public class SymmHess implements Decomposition {
protected boolean enforceSymmetric;
/**
- * Storage for the symmetric tri-diagonal matrix and if requested, the Householder vectors used to bring the original matrix
- * into upper Hessenberg form. The symmetric tri-diagonal matrix will be stored in the principle diagonal and the first
- * super-diagonal (Since the matrix is symmetric there is no need to store the first sub-diagonal). The rows of the strictly
- * lower-triangular portion of the matrix will be used to store the Householder vectors used to transform the source matrix
- * to upper Hessenburg form if it is requested via {@link #computeQ}. These can be used to compute the full orthogonal matrix
- * {@code Q} of the Hessenberg decomposition.
- */
- protected Matrix transformMatrix;
- /**
- * For storing norms of columns in A when computing Householder reflectors.
- */
- double norm;
- /**
- * Size of the symmetric matrix to be decomposed. That is, the number of rows and columns.
- */
- protected int size;
- /**
- * Flag indicating if the orthogonal transformation matrix from the Hessenburg decomposition should be explicitly computed.
- */
- protected boolean computeQ;
- /**
- * Stores the scalar factor for the current Householder reflector.
- */
- double currentFactor;
- /**
- * Storage of the scalar factors for the Householder reflectors used in the decomposition.
- */
- protected double[] qFactors;
- /**
- * For storing a Householder vectors.
- */
- protected double[] householderVector;
- /**
- * For temporarily storage when applying Householder vectors. This is useful for
- * avoiding unneeded garbage collection and for improving cache performance when traversing columns.
- */
- protected double[] workArray;
- /**
- * Flag indicating if a Householder reflector was needed for the current column meaning an update needs to be applied.
- */
- protected boolean applyUpdate;
- /**
- * Stores the shifted value of the first entry in a Householder vector.
- */
- private double shift;
-
-
- /**
- * Constructs a Hessenberg decomposer for symmetric matrices. To compute the
+ * Constructs a Hessenberg decomposer for symmetric matrices. By default, the Householder vectors used in the decomposition will be
+ * stored so that the full orthogonal {@code Q} matrix can be formed by calling {@link #getQ()}.
*/
public SymmHess() {
- computeQ = false;
- enforceSymmetric = false;
+ super();
}
@@ -122,8 +72,7 @@ public SymmHess() {
* If true, then the {@code Q} matrix will be computed explicitly.
*/
public SymmHess(boolean computeQ) {
- enforceSymmetric = false;
- this.computeQ = computeQ;
+ super(computeQ);
}
@@ -137,8 +86,8 @@ public SymmHess(boolean computeQ) {
* matrix is not symmetric, then the values in the upper triangular portion of the matrix are taken to be the values.
*/
public SymmHess(boolean computeQ, boolean enforceSymmetric) {
+ super(computeQ);
this.enforceSymmetric = enforceSymmetric;
- this.computeQ = computeQ;
}
@@ -150,14 +99,7 @@ public SymmHess(boolean computeQ, boolean enforceSymmetric) {
*/
@Override
public SymmHess decompose(Matrix src) {
- setUp(src);
- int stop = size-2;
-
- for(int k=0; k 1) {
- int rowColBase = size*size - 1;
+ if(numRows > 1) {
+ int rowColBase = numRows*numRows - 1;
H.entries[rowColBase] = transformMatrix.entries[rowColBase];
- H.entries[rowColBase - 1] = transformMatrix.entries[rowColBase - size];
+ H.entries[rowColBase - 1] = transformMatrix.entries[rowColBase - numRows];
}
return H;
}
- /**
- * Gets the unitary {@code Q} matrix from the Hessenberg decomposition.
- *
- * Note, if the reflectors for this decomposition were not saved, then {@code Q} can not be computed and this method will be
- * null.
- *
- * @return The {@code Q} matrix from the {@code QR} decomposition. Note, if the reflectors for this decomposition were not saved,
- * then {@code Q} can not be computed and this method will return {@code null}.
- */
- public Matrix getQ() {
- if(!computeQ)
- return null;
-
- Matrix Q = Matrix.I(size);
-
- for(int j=size - 1; j>=1; j--) {
- householderVector[j] = 1.0; // Ensure first value of reflector is 1.
-
- for(int i=j + 1; i= Flag4jConstants.EPS_F64;
-
- if(!applyUpdate) {
- currentFactor = 0;
- } else {
- computePhasedNorm(j, maxAbs);
-
- householderVector[j] = 1.0; // Ensure first value in Householder vector is one.
- for(int i=j+1; i {
* For computing determinant of coefficient matrix during solve.
*/
protected CNumber det;
- /**
- * For checking against other values.
- */
- private final CNumber z = CNumber.zero();
/**
@@ -92,7 +88,7 @@ public CVector solve(CMatrix U, CVector b) {
int uIndex;
int n = b.size;
x = new CVector(U.numRows);
- det = U.entries[n*n-1];
+ det = U.entries[n*n-1].copy();
x.entries[n-1] = b.entries[n-1].div(det);
@@ -134,12 +130,13 @@ public CMatrix solve(CMatrix U, CMatrix B) {
int uIndex, xIndex;
int n = B.numRows;
X = new CMatrix(B.shape);
- det = U.entries[n*n-1].copy();
+ det = U.entries[U.entries.length-1].copy();
xCol = new CNumber[n];
for(int j=0; j-1; i--) {
+ for(int i=n-2; i>=0; i--) {
sum = 0;
uIndex = i*U.numCols;
@@ -130,7 +130,7 @@ public Matrix solve(Matrix U, Matrix B) {
double uValue = U.entries[n*n-1];
int rowOffset = (n-1)*B.numCols;
X = new Matrix(B.shape);
- det = 1;
+ det = U.entries[n*n-1];
xCol = new double[n];
@@ -148,7 +148,7 @@ public Matrix solve(Matrix U, Matrix B) {
xIndex = i*X.numCols + j;
diag = U.entries[i*(n+1)];
- det*=diag;
+ if(j==0) det *= diag;
for(int k=i+1; k-1; i--) {
+ for(int i=n-2; i>=0; i--) {
sum = (i == j) ? 1 : 0;
- uIndex = i*U.numCols;
+ uIndex = i*n;
xIndex = uIndex + j;
uIndex += i+1;
diag = U.entries[i*(n+1)];
- det*=diag;
+ if(j==0) det *= diag;
for(int k=i+1; khess.decompose(A));
+ }
+}
diff --git a/src/test/java/org/flag4j/linalg/solvers/ComplexForwardSolverTests.java b/src/test/java/org/flag4j/linalg/solvers/ComplexForwardSolverTests.java
index 0016ab247..ad56f746d 100644
--- a/src/test/java/org/flag4j/linalg/solvers/ComplexForwardSolverTests.java
+++ b/src/test/java/org/flag4j/linalg/solvers/ComplexForwardSolverTests.java
@@ -125,7 +125,7 @@ void solveMatrixTestCase() {
lEntries = new CNumber[][]{
{new CNumber(1.25, -9.25), new CNumber(), new CNumber()},
{new CNumber(-815.5, 1.444), new CNumber(2.45, 15.5), new CNumber()},
- {new CNumber(0, -9.256), new CNumber(2.45, -83.2), new CNumber()}
+ {new CNumber(0, -9.256), new CNumber(2.45, -83.2), new CNumber(1)}
};
L = new CMatrix(lEntries);
bEntries = new CNumber[][]{
diff --git a/target/flag4j-v0.0.1-beta.jar b/target/flag4j-v0.0.1-beta.jar
index 0e63e9921..bd0e3e4c4 100644
Binary files a/target/flag4j-v0.0.1-beta.jar and b/target/flag4j-v0.0.1-beta.jar differ