Skip to content

Commit

Permalink
Added Tests for CSRCMatrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdwatters committed Jul 28, 2024
1 parent aabbd12 commit 3d4d588
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 18 deletions.
35 changes: 33 additions & 2 deletions src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,37 @@ public class CsrCMatrix
public final int nnz;


/**
* Constructs an empty sparse CSR matrix with the specified shape.
* @param shape Shape of the CSR matrix.
*/
public CsrCMatrix(Shape shape) {
super(shape, 0, new CNumber[0], new int[shape.dims[0]+1], new int[0]);

numRows = shape.dims[0];
numCols = shape.dims[1];
this.rowPointers = indices[0];
this.colIndices = indices[1];
nnz = entries.length;
}


/**
* Constructs an empty sparse CSR matrix with the specified shape.
* @param numRows Number of rows in the CSR matrix.
* @param numCols Number of columns in the CSR matrix.
*/
public CsrCMatrix(int numRows, int numCols) {
super(new Shape(numRows, numCols), 0, new CNumber[0], new int[numRows+1], new int[0]);

this.numRows = shape.dims[0];
this.numCols = shape.dims[1];
this.rowPointers = indices[0];
this.colIndices = indices[1];
nnz = entries.length;
}


/**
* Constructs a sparse matrix in CSR format with specified row-pointers, column indices and non-zero entries.
* @param shape Shape of the matrix.
Expand Down Expand Up @@ -1841,9 +1872,9 @@ public CooCVector getDiag() {
int start = rowPointers[i];
int stop = rowPointers[i+1];

int loc = Arrays.binarySearch(colIndices, i, start, stop); // Search for matching column index
int loc = Arrays.binarySearch(colIndices, start, stop, i); // Search for matching column index

if(loc > 0) {
if(loc >= 0) {
destEntries.add(entries[loc].copy());
destIndices.add(i);
}
Expand Down
37 changes: 34 additions & 3 deletions src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,37 @@ public class CsrMatrix
public final int nnz;


/**
* Constructs an empty sparse CSR matrix with the specified shape.
* @param shape Shape of the CSR matrix.
*/
public CsrMatrix(Shape shape) {
super(shape, 0, new double[0], new int[shape.dims[0]+1], new int[0]);

numRows = shape.dims[0];
numCols = shape.dims[1];
this.rowPointers = indices[0];
this.colIndices = indices[1];
nnz = entries.length;
}


/**
* Constructs an empty sparse CSR matrix with the specified shape.
* @param numRows Number of rows in the CSR matrix.
* @param numCols Number of columns in the CSR matrix.
*/
public CsrMatrix(int numRows, int numCols) {
super(new Shape(numRows, numCols), 0, new double[0], new int[numRows+1], new int[0]);

this.numRows = shape.dims[0];
this.numCols = shape.dims[1];
this.rowPointers = indices[0];
this.colIndices = indices[1];
nnz = entries.length;
}


/**
* Constructs a sparse matrix in CSR format with specified row-pointers, column indices and non-zero entries.
* @param shape Shape of the matrix.
Expand All @@ -107,7 +138,7 @@ public class CsrMatrix
* @param colIndices Column indices for CSR matrix.
*/
public CsrMatrix(Shape shape, double[] entries, int[] rowPointers, int[] colIndices) {
super(shape, entries.length, entries, new int[colIndices.length], colIndices);
super(shape, entries.length, entries, rowPointers, colIndices);

this.rowPointers = rowPointers;
this.colIndices = colIndices;
Expand Down Expand Up @@ -1373,9 +1404,9 @@ public CooVector getDiag() {
int start = rowPointers[i];
int stop = rowPointers[i+1];

int loc = Arrays.binarySearch(colIndices, i, start, stop); // Search for matching column index
int loc = Arrays.binarySearch(colIndices, start, stop, i); // Search for matching column index

if(loc > 0) {
if(loc >= 0) {
destEntries.add(entries[loc]);
destIndices.add(i);
}
Expand Down
13 changes: 0 additions & 13 deletions src/main/java/org/flag4j/core/sparse_base/SparseTensorBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,11 @@ protected SparseTensorBase(Shape shape, int nonZeroEntries, D entries, int[] ini
int totalIndices = restIndices.length + 1;
ParameterChecks.assertEquals(totalIndices, shape.getRank());

// TODO: This needs to be enforced in CooMatrix class
// ParameterChecks.assertArrayLengthsEq(nonZeroEntries, initIndices.length);

this.indices = new int[totalIndices][];
this.indices[0] = initIndices;

// TODO: This needs to be enforced in CooMatrix class.
// for(int i=1; i<totalIndices; i++) {
// if(restIndices[i-1].length != initIndices.length) {
// throw new IllegalArgumentException(
// String.format("All index array must have the same length but got %d and %d.",
// initIndices.length, restIndices[i-1].length));
// }
// }

System.arraycopy(restIndices, 0, this.indices, 1, totalIndices - 1);


this.nonZeroEntries = nonZeroEntries;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package org.flag4j.sparse_csr_complex_matrix;

import org.flag4j.arrays.dense.CVector;
import org.flag4j.arrays.sparse.CooCVector;
import org.flag4j.arrays.sparse.CsrCMatrix;
import org.flag4j.complex_numbers.CNumber;
import org.flag4j.core.Shape;
import org.flag4j.util.exceptions.LinearAlgebraException;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class CsrCMatrixTriDiagTests {

static CsrCMatrix A;
static Shape aShape;
static CNumber[] aEntries;
static int[] aRowPointers;
static int[] aColIndices;

@Test
void traceTest() {
// ----------------------- sub-case 1 -----------------------
aShape = new Shape(6, 6);
aEntries = new CNumber[]{new CNumber(0.13392, 0.67581), new CNumber(0.08953, 0.43612),
new CNumber(0.68598, 0.57467), new CNumber(0.19996, 0.4443), new CNumber(0.1453, 0.59643)};
aRowPointers = new int[]{0, 1, 1, 2, 2, 5, 5};
aColIndices = new int[]{5, 0, 0, 1, 3};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(new CNumber(0), A.trace());

// ----------------------- sub-case 2 -----------------------
aShape = new Shape(6, 6);
aEntries = new CNumber[]{new CNumber(0.90327, 0.45253), new CNumber(0.21721, 0.28695),
new CNumber(0.65185, 0.93707), new CNumber(0.48592, 0.63105), new CNumber(0.96722, 0.76818)};
aRowPointers = new int[]{0, 1, 1, 2, 4, 5, 5};
aColIndices = new int[]{3, 5, 3, 4, 5};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(new CNumber(0.65185, 0.93707), A.trace());

// ----------------------- sub-case 3 -----------------------
aShape = new Shape(12, 12);
aEntries = new CNumber[]{new CNumber(0.52333, 0.67155), new CNumber(0.79849, 0.05489),
new CNumber(0.2229, 0.67036), new CNumber(0.65271, 0.51699), new CNumber(0.63722, 0.55373),
new CNumber(0.21806, 0.5938), new CNumber(0.06624, 0.41699), new CNumber(0.32211, 0.1279),
new CNumber(0.11324, 0.21277), new CNumber(0.45704, 0.75931), new CNumber(0.13948, 0.53299),
new CNumber(0.13934, 0.59231), new CNumber(0.30193, 0.98664), new CNumber(0.11591, 0.98686),
new CNumber(0.29993, 0.8055), new CNumber(0.43436, 0.64936), new CNumber(0.9495, 0.32514),
new CNumber(0.22636, 0.4559), new CNumber(0.58931, 0.3885), new CNumber(0.332, 0.82381),
new CNumber(0.11975, 0.11127), new CNumber(0.46906, 0.80406)};
aRowPointers = new int[]{0, 4, 5, 7, 8, 9, 10, 12, 13, 15, 19, 21, 22};
aColIndices = new int[]{2, 3, 5, 11, 10, 2, 11, 4, 8, 11, 8, 9, 0, 3, 10, 3, 6, 7, 11, 1, 4, 2};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(new CNumber(0.21806, 0.5938), A.trace());

// ----------------------- sub-case 4 -----------------------
assertThrows(LinearAlgebraException.class, ()->new CsrCMatrix(12, 4).tr());
}


@Test
void getDiagTests() {
CooCVector exp;

// ----------------------- sub-case 1 -----------------------
aShape = new Shape(6, 6);
aEntries = new CNumber[]{new CNumber(0.16488, 0.46447), new CNumber(0.3774, 0.35246),
new CNumber(0.48798, 0.43226), new CNumber(0.48544, 0.0083), new CNumber(0.77148, 0.92185)};
aRowPointers = new int[]{0, 0, 0, 1, 3, 5, 5};
aColIndices = new int[]{2, 0, 5, 2, 4};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new CVector(new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.16488, 0.46447),
new CNumber(0.0, 0.0), new CNumber(0.77148, 0.92185), new CNumber(0.0, 0.0)).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 2 -----------------------
aShape = new Shape(6, 6);
aEntries = new CNumber[]{new CNumber(0.3583, 0.11616), new CNumber(0.38055, 0.12452),
new CNumber(0.97075, 0.0994), new CNumber(0.43841, 0.0319), new CNumber(0.78954, 0.72668)};
aRowPointers = new int[]{0, 1, 3, 4, 4, 4, 5};
aColIndices = new int[]{4, 0, 3, 5, 2};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new CVector(0, 0, 0, 0, 0, 0).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 3 -----------------------
aShape = new Shape(12, 12);
aEntries = new CNumber[]{new CNumber(0.13994, 0.01355), new CNumber(0.5696, 0.69936),
new CNumber(0.1473, 0.92664), new CNumber(0.40113, 0.88561), new CNumber(0.45177, 0.86077),
new CNumber(0.91248, 0.28304), new CNumber(0.71058, 0.73788), new CNumber(0.04337, 0.79422),
new CNumber(0.26867, 0.54323), new CNumber(0.28884, 0.58676), new CNumber(0.62439, 0.17585),
new CNumber(0.76891, 0.03482), new CNumber(0.02111, 0.58864), new CNumber(0.36376, 0.44496)};
aRowPointers = new int[]{0, 1, 4, 5, 6, 8, 10, 11, 12, 13, 13, 14, 14};
aColIndices = new int[]{7, 3, 5, 8, 6, 3, 6, 7, 3, 5, 11, 10, 8, 11};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new CVector(new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0),
new CNumber(0.91248, 0.28304), new CNumber(0.0, 0.0), new CNumber(0.28884, 0.58676),
new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.02111, 0.58864),
new CNumber(0.0, 0.0), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 4 -----------------------
aShape = new Shape(3, 7);
aEntries = new CNumber[]{new CNumber(0.85218, 0.08775), new CNumber(0.16499, 0.26153), new CNumber(0.36642, 0.63926), new CNumber(0.63494, 0.79019)};
aRowPointers = new int[]{0, 1, 3, 4};
aColIndices = new int[]{0, 3, 5, 1};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new CVector(new CNumber(0.85218, 0.08775), new CNumber(0.0, 0.0), new CNumber(0.0, 0.0)).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 5 -----------------------
aShape = new Shape(7, 3);
aEntries = new CNumber[]{new CNumber(0.38881, 0.35982), new CNumber(0.82873, 0.85744),
new CNumber(0.22975, 0.7268), new CNumber(0.9345, 0.92117)};
aRowPointers = new int[]{0, 1, 1, 2, 2, 2, 3, 4};
aColIndices = new int[]{0, 2, 2, 0};
A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new CVector(new CNumber(0.38881, 0.35982), new CNumber(0.0, 0.0),
new CNumber(0.82873, 0.85744)).toCoo();
assertEquals(exp, A.getDiag());
}
}
105 changes: 105 additions & 0 deletions src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package org.flag4j.sparse_csr_matrix;

import org.flag4j.arrays.dense.Vector;
import org.flag4j.arrays.sparse.CooVector;
import org.flag4j.arrays.sparse.CsrMatrix;
import org.flag4j.core.Shape;
import org.flag4j.util.exceptions.LinearAlgebraException;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class CsrMatrixTriDiagTests {

static CsrMatrix A;
static Shape aShape;
static double[] aEntries;
static int[] aRowPointers;
static int[] aColIndices;

@Test
void traceTest() {
// ----------------------- sub-case 1 -----------------------
aShape = new Shape(6, 6);
aEntries = new double[]{0.550651808914819, 0.25248902903617765, 0.7347711759989852, 0.6355561756397741, 0.4229489382733195};
aRowPointers = new int[]{0, 1, 3, 3, 4, 4, 5};
aColIndices = new int[]{1, 0, 3, 4, 2};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(0.0, A.trace());

// ----------------------- sub-case 2 -----------------------
aShape = new Shape(6, 6);
aEntries = new double[]{0.3947350230967598, 0.057362465663433504, 0.9347345786066202, 0.43387856849809536, 0.644454913626892};
aRowPointers = new int[]{0, 2, 2, 2, 3, 3, 5};
aColIndices = new int[]{2, 5, 3, 0, 3};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(0.9347345786066202, A.trace());

// ----------------------- sub-case 3 -----------------------
aShape = new Shape(15, 15);
aEntries = new double[]{0.8086884822224377, 0.8263266272760342, 0.32525649651881317, 0.9608789667749497, 0.4395049603468911,
0.07305897091345792, 0.7678491080726771, 0.2294487507606635, 0.1382399153004038, 0.920903220380491, 0.6895253484430522,
0.9256640817501549, 0.8520473820197568, 0.6642547057142691, 0.9992544305209986, 0.6144764840301362,
0.38299169954163803, 0.04890094352081187, 0.07595056410539092, 0.6778873127860436, 0.5725538366855386,
0.633943368653018};
aRowPointers = new int[]{0, 2, 5, 5, 6, 9, 10, 13, 14, 15, 15, 16, 17, 20, 21, 22};
aColIndices = new int[]{8, 9, 1, 2, 13, 14, 0, 11, 13, 7, 8, 12, 14, 8, 14, 5, 4, 1, 4, 12, 9, 9};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
assertEquals(1.0031438093048568, A.trace());

// ----------------------- sub-case 4 -----------------------
assertThrows(LinearAlgebraException.class, ()->new CsrMatrix(12, 4).tr());
}


@Test
void getDiagTests() {
CooVector exp;

// ----------------------- sub-case 1 -----------------------
aShape = new Shape(6, 6);
aEntries = new double[]{0.2742696420912182, 0.07751690909235676, 0.921830928702064, 0.1935807720906103, 0.7540620527220376};
aRowPointers = new int[]{0, 0, 3, 3, 4, 5, 5};
aColIndices = new int[]{0, 2, 4, 5, 4};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new Vector(0.0, 0.0, 0.0, 0.0, 0.7540620527220376, 0.0).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 2 -----------------------
aShape = new Shape(6, 6);
aEntries = new double[]{0.5117934899539107, 0.5232442605382139, 0.6241020346691114, 0.5149866370978797, 0.9899310737230431, 0.18946387228386474, 0.2824144399493015, 0.25137905573116825, 0.9609185162323945, 0.6746428794940532, 0.47997193451501086};
aRowPointers = new int[]{0, 2, 5, 7, 9, 10, 11};
aColIndices = new int[]{2, 5, 0, 1, 3, 2, 5, 0, 4, 1, 5};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new Vector(0.0, 0.5149866370978797, 0.18946387228386474, 0.0, 0.0, 0.47997193451501086).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 3 -----------------------
aShape = new Shape(12, 12);
aEntries = new double[]{0.989336440083382, 0.5784226184530201, 0.4128615264146236, 0.8720518156451399, 0.015332897118263578, 0.30602701869807314, 0.5445047402785209, 0.6799626535731805, 0.8160519253929756, 0.44554133866387846, 0.5507111109163054, 0.5127437616463539, 0.12630665534801888, 0.8348974142473102};
aRowPointers = new int[]{0, 1, 3, 5, 8, 9, 9, 10, 12, 12, 13, 14, 14};
aColIndices = new int[]{6, 6, 8, 7, 10, 3, 6, 8, 4, 5, 3, 5, 3, 4};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new Vector(0.0, 0.0, 0.0, 0.30602701869807314, 0.8160519253929756, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 4 -----------------------
aShape = new Shape(3, 7);
aEntries = new double[]{0.46451579858863223, 0.761818146777147, 0.4513656938478876, 0.8045385713145264};
aRowPointers = new int[]{0, 0, 2, 4};
aColIndices = new int[]{1, 4, 1, 2};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new Vector(0.0, 0.46451579858863223, 0.8045385713145264).toCoo();
assertEquals(exp, A.getDiag());

// ----------------------- sub-case 5 -----------------------
aShape = new Shape(7, 3);
aEntries = new double[]{0.11343649359086139, 0.6082353444109962, 0.8703305006626817, 0.9546036038577662};
aRowPointers = new int[]{0, 0, 2, 3, 4, 4, 4, 4};
aColIndices = new int[]{1, 2, 2, 0};
A = new CsrMatrix(aShape, aEntries, aRowPointers, aColIndices);
exp = new Vector(0.0, 0.11343649359086139, 0.8703305006626817).toCoo();
assertEquals(exp, A.getDiag());
}
}
Binary file modified target/flag4j-v0.1.0-beta.jar
Binary file not shown.

0 comments on commit 3d4d588

Please sign in to comment.