Skip to content

Commit

Permalink
Add tests for sparse tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdwatters committed Aug 13, 2024
1 parent 6288165 commit c74ca83
Show file tree
Hide file tree
Showing 7 changed files with 863 additions and 17 deletions.
30 changes: 20 additions & 10 deletions src/main/java/org/flag4j/arrays/sparse/CooCTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,21 @@ public CooCTensor reshape(Shape newShape) {
newShape.makeStridesIfNull();

int rank = indices[0].length;
int newRank = newShape.getRank();
int nnz = entries.length;

int[] oldStrides = shape.getStrides();
int[] newStrides = newShape.getStrides();

int[][] newIndices = new int[nnz][rank];
int[][] newIndices = new int[nnz][newRank];

for (int i = 0; i < nnz; i++) {
for(int i=0; i<nnz; i++) {
int flatIndex = 0;
for (int j = 0; j < rank; j++) {
for(int j=0; j < rank; j++) {
flatIndex += indices[i][j] * oldStrides[j];
}

for (int j = 0; j < rank; j++) {
for(int j=0; j<newRank; j++) {
newIndices[i][j] = flatIndex / newStrides[j];
flatIndex %= newStrides[j];
}
Expand Down Expand Up @@ -331,6 +332,7 @@ public CooCTensor flatten(int axis) {

// Compute new shape.
int[] destShape = new int[indices[0].length];
Arrays.fill(destShape, 1);
destShape[axis] = shape.totalEntries().intValueExact();

for(int i=0, size=entries.length; i<size; i++)
Expand Down Expand Up @@ -389,7 +391,7 @@ public CooCTensor T(int axis1, int axis2) {
}

// Create sparse coo tensor and sort values lexicographically.
CooCTensor transpose = new CooCTensor(shape, transposeEntries, transposeIndices);
CooCTensor transpose = new CooCTensor(shape.swapAxes(axis1, axis2), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down Expand Up @@ -417,7 +419,7 @@ public CooCTensor T(int... axes) {
CNumber[] transposeEntries = new CNumber[nnz];

// Permute the indices according to the permutation array.
for(int i = 0; i < nnz; i++) {
for(int i=0; i < nnz; i++) {
transposeEntries[i] = entries[i];
transposeIndices[i] = indices[i].clone();

Expand All @@ -427,7 +429,7 @@ public CooCTensor T(int... axes) {
}

// Create sparse coo tensor and sort values lexicographically.
CooCTensor transpose = new CooCTensor(shape, transposeEntries, transposeIndices);
CooCTensor transpose = new CooCTensor(shape.swapAxes(axes), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down Expand Up @@ -655,9 +657,17 @@ public CooCTensor transpose() {


/**
* Computes the transpose of a tensor. Same as {@link #transpose()}.
* <p>Computes the transpose of a tensor. Same as {@link #transpose()}.</p>
*
* <p>This method transposes the tensor by exchanges the first and last index
* of the tensor. Thus, for a rank 2 tensor, this method is equivalent to a matrix transpose.</p>
*
* <p>{@link #T(int, int)} and {@link #T(int...)} offer more general tensor transposes.</p>
*
* @return The transpose of this tensor.
* @see #transpose()
* @see #T(int, int)
* @see #T(int...)
*/
@Override
public CooCTensor T() {
Expand Down Expand Up @@ -870,7 +880,7 @@ public CooCTensor H(int axis1, int axis2) {
}

// Create sparse coo tensor and sort values lexicographically.
CooCTensor transpose = new CooCTensor(shape, transposeEntries, transposeIndices);
CooCTensor transpose = new CooCTensor(shape.swapAxes(axis1, axis2), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down Expand Up @@ -908,7 +918,7 @@ public CooCTensor H(int... axes) {
}

// Create sparse coo tensor and sort values lexicographically.
CooCTensor transpose = new CooCTensor(shape, transposeEntries, transposeIndices);
CooCTensor transpose = new CooCTensor(shape.swapAxes(axes), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down
24 changes: 17 additions & 7 deletions src/main/java/org/flag4j/arrays/sparse/CooTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,25 @@ public CooTensor reshape(Shape newShape) {
newShape.makeStridesIfNull(); // Ensure this shape object has strides computed.

int rank = indices[0].length;
int newRank = newShape.getRank();
int nnz = entries.length;

int[] oldStrides = shape.getStrides();
int[] newStrides = newShape.getStrides();

int[][] newIndices = new int[nnz][rank];
int[][] newIndices = new int[nnz][newRank];

for (int i = 0; i < nnz; i++) {
int flatIndex = 0;
for (int j = 0; j < rank; j++) {
flatIndex += indices[i][j] * oldStrides[j];
}

for (int j = 0; j < rank; j++) {
for (int j = 0; j < newRank; j++) {
int[] arr1 = newIndices[i];
int v1 = newIndices[i][j];
int v2 = newStrides[j];

newIndices[i][j] = flatIndex / newStrides[j];
flatIndex %= newStrides[j];
}
Expand Down Expand Up @@ -337,6 +342,7 @@ public CooTensor flatten(int axis) {

// Compute new shape.
int[] destShape = new int[indices[0].length];
Arrays.fill(destShape, 1);
destShape[axis] = shape.totalEntries().intValueExact();

for(int i=0, size=entries.length; i<size; i++)
Expand Down Expand Up @@ -395,7 +401,7 @@ public CooTensor T(int axis1, int axis2) {
}

// Create sparse coo tensor and sort values lexicographically.
CooTensor transpose = new CooTensor(shape, transposeEntries, transposeIndices);
CooTensor transpose = new CooTensor(shape.swapAxes(axis1, axis2), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down Expand Up @@ -423,7 +429,7 @@ public CooTensor T(int... axes) {
double[] transposeEntries = new double[nnz];

// Permute the indices according to the permutation array.
for(int i = 0; i < nnz; i++) {
for(int i=0; i<nnz; i++) {
transposeEntries[i] = entries[i];
transposeIndices[i] = indices[i].clone();

Expand All @@ -433,7 +439,7 @@ public CooTensor T(int... axes) {
}

// Create sparse coo tensor and sort values lexicographically.
CooTensor transpose = new CooTensor(shape, transposeEntries, transposeIndices);
CooTensor transpose = new CooTensor(shape.swapAxes(axes), transposeEntries, transposeIndices);
transpose.sortIndices();

return transpose;
Expand Down Expand Up @@ -636,8 +642,12 @@ public CooTensor transpose() {


/**
* Computes the transpose of a tensor. Same as {@link #transpose()}.
* In the context of a tensor, this exchanges the first and last axis of the tensor.
* <p>Computes the transpose of a tensor. Same as {@link #transpose()}.</p>
*
* <p>This method transposes the tensor by exchanges the first and last index
* of the tensor. Thus, for a rank 2 tensor, this method is equivalent to a matrix transpose.</p>
*
* <p>{@link #T(int, int)} and {@link #T(int...)} offer more general tensor transposes.</p>
*
* @return The transpose of this tensor.
* @see #transpose()
Expand Down
Loading

0 comments on commit c74ca83

Please sign in to comment.