diff --git a/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java index 3bca422d5..a2a1b0635 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java @@ -96,6 +96,37 @@ public class CsrCMatrix public final int nnz; + /** + * Constructs an empty sparse CSR matrix with the specified shape. + * @param shape Shape of the CSR matrix. + */ + public CsrCMatrix(Shape shape) { + super(shape, 0, new CNumber[0], new int[shape.dims[0]+1], new int[0]); + + numRows = shape.dims[0]; + numCols = shape.dims[1]; + this.rowPointers = indices[0]; + this.colIndices = indices[1]; + nnz = entries.length; + } + + + /** + * Constructs an empty sparse CSR matrix with the specified shape. + * @param numRows Number of rows in the CSR matrix. + * @param numCols Number of columns in the CSR matrix. + */ + public CsrCMatrix(int numRows, int numCols) { + super(new Shape(numRows, numCols), 0, new CNumber[0], new int[numRows+1], new int[0]); + + this.numRows = shape.dims[0]; + this.numCols = shape.dims[1]; + this.rowPointers = indices[0]; + this.colIndices = indices[1]; + nnz = entries.length; + } + + /** * Constructs a sparse matrix in CSR format with specified row-pointers, column indices and non-zero entries. * @param shape Shape of the matrix. @@ -1841,9 +1872,9 @@ public CooCVector getDiag() { int start = rowPointers[i]; int stop = rowPointers[i+1]; - int loc = Arrays.binarySearch(colIndices, i, start, stop); // Search for matching column index + int loc = Arrays.binarySearch(colIndices, start, stop, i); // Search for matching column index - if(loc > 0) { + if(loc >= 0) { destEntries.add(entries[loc].copy()); destIndices.add(i); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java index d2c956843..d787eaaf1 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java @@ -99,6 +99,37 @@ public class CsrMatrix public final int nnz; + /** + * Constructs an empty sparse CSR matrix with the specified shape. + * @param shape Shape of the CSR matrix. + */ + public CsrMatrix(Shape shape) { + super(shape, 0, new double[0], new int[shape.dims[0]+1], new int[0]); + + numRows = shape.dims[0]; + numCols = shape.dims[1]; + this.rowPointers = indices[0]; + this.colIndices = indices[1]; + nnz = entries.length; + } + + + /** + * Constructs an empty sparse CSR matrix with the specified shape. + * @param numRows Number of rows in the CSR matrix. + * @param numCols Number of columns in the CSR matrix. + */ + public CsrMatrix(int numRows, int numCols) { + super(new Shape(numRows, numCols), 0, new double[0], new int[numRows+1], new int[0]); + + this.numRows = shape.dims[0]; + this.numCols = shape.dims[1]; + this.rowPointers = indices[0]; + this.colIndices = indices[1]; + nnz = entries.length; + } + + /** * Constructs a sparse matrix in CSR format with specified row-pointers, column indices and non-zero entries. * @param shape Shape of the matrix. @@ -107,7 +138,7 @@ public class CsrMatrix * @param colIndices Column indices for CSR matrix. */ public CsrMatrix(Shape shape, double[] entries, int[] rowPointers, int[] colIndices) { - super(shape, entries.length, entries, new int[colIndices.length], colIndices); + super(shape, entries.length, entries, rowPointers, colIndices); this.rowPointers = rowPointers; this.colIndices = colIndices; @@ -1373,9 +1404,9 @@ public CooVector getDiag() { int start = rowPointers[i]; int stop = rowPointers[i+1]; - int loc = Arrays.binarySearch(colIndices, i, start, stop); // Search for matching column index + int loc = Arrays.binarySearch(colIndices, start, stop, i); // Search for matching column index - if(loc > 0) { + if(loc >= 0) { destEntries.add(entries[loc]); destIndices.add(i); } diff --git a/src/main/java/org/flag4j/core/sparse_base/SparseTensorBase.java b/src/main/java/org/flag4j/core/sparse_base/SparseTensorBase.java index 7f8a7be59..9415175e6 100644 --- a/src/main/java/org/flag4j/core/sparse_base/SparseTensorBase.java +++ b/src/main/java/org/flag4j/core/sparse_base/SparseTensorBase.java @@ -92,24 +92,11 @@ protected SparseTensorBase(Shape shape, int nonZeroEntries, D entries, int[] ini int totalIndices = restIndices.length + 1; ParameterChecks.assertEquals(totalIndices, shape.getRank()); - // TODO: This needs to be enforced in CooMatrix class -// ParameterChecks.assertArrayLengthsEq(nonZeroEntries, initIndices.length); - this.indices = new int[totalIndices][]; this.indices[0] = initIndices; - // TODO: This needs to be enforced in CooMatrix class. -// for(int i=1; inew CsrCMatrix(12, 4).tr()); + } + + + @Test + void getDiagTests() { + CooCVector exp; + + // ----------------------- sub-case 1 ----------------------- + aShape = new Shape(6, 6); + aEntries = new CNumber[]{new CNumber(0.16488, 0.46447), new CNumber(0.3774, 0.35246), + new CNumber(0.48798, 0.43226), new CNumber(0.48544, 0.0083), new CNumber(0.77148, 0.92185)}; + aRowPointers = new int[]{0, 0, 0, 1, 3, 5, 5}; + aColIndices = new int[]{2, 0, 5, 2, 4}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new CVector(new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.16488, 0.46447), + new CNumber(0.0, 0.0), new CNumber(0.77148, 0.92185), new CNumber(0.0, 0.0)).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 2 ----------------------- + aShape = new Shape(6, 6); + aEntries = new CNumber[]{new CNumber(0.3583, 0.11616), new CNumber(0.38055, 0.12452), + new CNumber(0.97075, 0.0994), new CNumber(0.43841, 0.0319), new CNumber(0.78954, 0.72668)}; + aRowPointers = new int[]{0, 1, 3, 4, 4, 4, 5}; + aColIndices = new int[]{4, 0, 3, 5, 2}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new CVector(0, 0, 0, 0, 0, 0).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 3 ----------------------- + aShape = new Shape(12, 12); + aEntries = new CNumber[]{new CNumber(0.13994, 0.01355), new CNumber(0.5696, 0.69936), + new CNumber(0.1473, 0.92664), new CNumber(0.40113, 0.88561), new CNumber(0.45177, 0.86077), + new CNumber(0.91248, 0.28304), new CNumber(0.71058, 0.73788), new CNumber(0.04337, 0.79422), + new CNumber(0.26867, 0.54323), new CNumber(0.28884, 0.58676), new CNumber(0.62439, 0.17585), + new CNumber(0.76891, 0.03482), new CNumber(0.02111, 0.58864), new CNumber(0.36376, 0.44496)}; + aRowPointers = new int[]{0, 1, 4, 5, 6, 8, 10, 11, 12, 13, 13, 14, 14}; + aColIndices = new int[]{7, 3, 5, 8, 6, 3, 6, 7, 3, 5, 11, 10, 8, 11}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new CVector(new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), + new CNumber(0.91248, 0.28304), new CNumber(0.0, 0.0), new CNumber(0.28884, 0.58676), + new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.02111, 0.58864), + new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 4 ----------------------- + aShape = new Shape(3, 7); + aEntries = new CNumber[]{new CNumber(0.85218, 0.08775), new CNumber(0.16499, 0.26153), new CNumber(0.36642, 0.63926), new CNumber(0.63494, 0.79019)}; + aRowPointers = new int[]{0, 1, 3, 4}; + aColIndices = new int[]{0, 3, 5, 1}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new CVector(new CNumber(0.85218, 0.08775), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 5 ----------------------- + aShape = new Shape(7, 3); + aEntries = new CNumber[]{new CNumber(0.38881, 0.35982), new CNumber(0.82873, 0.85744), + new CNumber(0.22975, 0.7268), new CNumber(0.9345, 0.92117)}; + aRowPointers = new int[]{0, 1, 1, 2, 2, 2, 3, 4}; + aColIndices = new int[]{0, 2, 2, 0}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new CVector(new CNumber(0.38881, 0.35982), new CNumber(0.0, 0.0), + new CNumber(0.82873, 0.85744)).toCoo(); + assertEquals(exp, A.getDiag()); + } +} diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java b/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java new file mode 100644 index 000000000..948faaadd --- /dev/null +++ b/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java @@ -0,0 +1,105 @@ +package org.flag4j.sparse_csr_matrix; + +import org.flag4j.arrays.dense.Vector; +import org.flag4j.arrays.sparse.CooVector; +import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.core.Shape; +import org.flag4j.util.exceptions.LinearAlgebraException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrMatrixTriDiagTests { + + static CsrMatrix A; + static Shape aShape; + static double[] aEntries; + static int[] aRowPointers; + static int[] aColIndices; + + @Test + void traceTest() { + // ----------------------- sub-case 1 ----------------------- + aShape = new Shape(6, 6); + aEntries = new double[]{0.550651808914819, 0.25248902903617765, 0.7347711759989852, 0.6355561756397741, 0.4229489382733195}; + aRowPointers = new int[]{0, 1, 3, 3, 4, 4, 5}; + aColIndices = new int[]{1, 0, 3, 4, 2}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + assertEquals(0.0, A.trace()); + + // ----------------------- sub-case 2 ----------------------- + aShape = new Shape(6, 6); + aEntries = new double[]{0.3947350230967598, 0.057362465663433504, 0.9347345786066202, 0.43387856849809536, 0.644454913626892}; + aRowPointers = new int[]{0, 2, 2, 2, 3, 3, 5}; + aColIndices = new int[]{2, 5, 3, 0, 3}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + assertEquals(0.9347345786066202, A.trace()); + + // ----------------------- sub-case 3 ----------------------- + aShape = new Shape(15, 15); + aEntries = new double[]{0.8086884822224377, 0.8263266272760342, 0.32525649651881317, 0.9608789667749497, 0.4395049603468911, + 0.07305897091345792, 0.7678491080726771, 0.2294487507606635, 0.1382399153004038, 0.920903220380491, 0.6895253484430522, + 0.9256640817501549, 0.8520473820197568, 0.6642547057142691, 0.9992544305209986, 0.6144764840301362, + 0.38299169954163803, 0.04890094352081187, 0.07595056410539092, 0.6778873127860436, 0.5725538366855386, + 0.633943368653018}; + aRowPointers = new int[]{0, 2, 5, 5, 6, 9, 10, 13, 14, 15, 15, 16, 17, 20, 21, 22}; + aColIndices = new int[]{8, 9, 1, 2, 13, 14, 0, 11, 13, 7, 8, 12, 14, 8, 14, 5, 4, 1, 4, 12, 9, 9}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + assertEquals(1.0031438093048568, A.trace()); + + // ----------------------- sub-case 4 ----------------------- + assertThrows(LinearAlgebraException.class, ()->new CsrMatrix(12, 4).tr()); + } + + + @Test + void getDiagTests() { + CooVector exp; + + // ----------------------- sub-case 1 ----------------------- + aShape = new Shape(6, 6); + aEntries = new double[]{0.2742696420912182, 0.07751690909235676, 0.921830928702064, 0.1935807720906103, 0.7540620527220376}; + aRowPointers = new int[]{0, 0, 3, 3, 4, 5, 5}; + aColIndices = new int[]{0, 2, 4, 5, 4}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new Vector(0.0, 0.0, 0.0, 0.0, 0.7540620527220376, 0.0).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 2 ----------------------- + aShape = new Shape(6, 6); + aEntries = new double[]{0.5117934899539107, 0.5232442605382139, 0.6241020346691114, 0.5149866370978797, 0.9899310737230431, 0.18946387228386474, 0.2824144399493015, 0.25137905573116825, 0.9609185162323945, 0.6746428794940532, 0.47997193451501086}; + aRowPointers = new int[]{0, 2, 5, 7, 9, 10, 11}; + aColIndices = new int[]{2, 5, 0, 1, 3, 2, 5, 0, 4, 1, 5}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new Vector(0.0, 0.5149866370978797, 0.18946387228386474, 0.0, 0.0, 0.47997193451501086).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 3 ----------------------- + aShape = new Shape(12, 12); + aEntries = new double[]{0.989336440083382, 0.5784226184530201, 0.4128615264146236, 0.8720518156451399, 0.015332897118263578, 0.30602701869807314, 0.5445047402785209, 0.6799626535731805, 0.8160519253929756, 0.44554133866387846, 0.5507111109163054, 0.5127437616463539, 0.12630665534801888, 0.8348974142473102}; + aRowPointers = new int[]{0, 1, 3, 5, 8, 9, 9, 10, 12, 12, 13, 14, 14}; + aColIndices = new int[]{6, 6, 8, 7, 10, 3, 6, 8, 4, 5, 3, 5, 3, 4}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new Vector(0.0, 0.0, 0.0, 0.30602701869807314, 0.8160519253929756, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 4 ----------------------- + aShape = new Shape(3, 7); + aEntries = new double[]{0.46451579858863223, 0.761818146777147, 0.4513656938478876, 0.8045385713145264}; + aRowPointers = new int[]{0, 0, 2, 4}; + aColIndices = new int[]{1, 4, 1, 2}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new Vector(0.0, 0.46451579858863223, 0.8045385713145264).toCoo(); + assertEquals(exp, A.getDiag()); + + // ----------------------- sub-case 5 ----------------------- + aShape = new Shape(7, 3); + aEntries = new double[]{0.11343649359086139, 0.6082353444109962, 0.8703305006626817, 0.9546036038577662}; + aRowPointers = new int[]{0, 0, 2, 3, 4, 4, 4, 4}; + aColIndices = new int[]{1, 2, 2, 0}; + A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices); + exp = new Vector(0.0, 0.11343649359086139, 0.8703305006626817).toCoo(); + assertEquals(exp, A.getDiag()); + } +} diff --git a/target/flag4j-v0.1.0-beta.jar b/target/flag4j-v0.1.0-beta.jar index b9345180b..b713fa317 100644 Binary files a/target/flag4j-v0.1.0-beta.jar and b/target/flag4j-v0.1.0-beta.jar differ