Skip to content

Commit

Permalink
Add various matrix norms (#2)
Browse files Browse the repository at this point in the history
* feat: add functions for dim-wise sum

* feat: add diverse matrix norm

* docs: typo

* zig fmt

---------

Co-authored-by: Adrià Arrufat <[email protected]>
  • Loading branch information
cih9088 and arrufat authored Apr 22, 2024
1 parent 897ed2a commit 8a3c350
Showing 1 changed file with 98 additions and 7 deletions.
105 changes: 98 additions & 7 deletions src/matrix.zig
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ pub fn Matrix(comptime T: type, comptime rows: usize, comptime cols: usize) type
return self.items[0][0];
}

/// Sums all the elements in a matrix
/// Sums all the elements in a matrix.
pub fn sum(self: Self) T {
var accum: T = 0;
for (self.items) |row| {
Expand All @@ -264,15 +264,83 @@ pub fn Matrix(comptime T: type, comptime rows: usize, comptime cols: usize) type
return accum;
}

/// Computes the norm of the matrix as the square root of the sum of its squared values.
pub fn norm(self: Self) T {
var sum_sq: T = 0;
/// Sums all the elements in columns.
pub fn sumCols(self: Self) Matrix(T, rows, 1) {
var result = Matrix(T, rows, 1).initAll(0);
for (0..self.rows) |r| {
for (0..self.cols) |c| {
result.items[r][0] += self.items[r][c];
}
}
return result;
}

/// Sums all the elements in rows.
pub fn sumRows(self: Self) Matrix(T, 1, cols) {
var result = Matrix(T, 1, cols).initAll(0);
for (0..self.rows) |r| {
for (0..self.cols) |c| {
result.items[0][c] += self.items[r][c];
}
}
return result;
}

/// Computes the "element-wise" matrix norm of the matrix.
pub fn norm(self: Self, p: T) T {
assert(p >= 1);
if (p == std.math.inf(T)) {
return self.maxNorm();
} else if (p == -std.math.inf(T)) {
return self.minNorm();
} else {
var result: T = 0;
for (self.items) |row| {
for (row) |col| {
result += std.math.pow(T, @abs(col), p);
}
}
result = std.math.pow(T, result, (1 / p));
return result;
}
}

/// Computes the Frobenius norm of the matrix as the square root of the sum of its squared values.
pub fn frobeniusNorm(self: Self) T {
return self.norm(2);
}

/// Computes the Nuclear norm of the matrix as the sum of its absolute values.
pub fn nuclearNorm(self: Self) T {
return self.norm(1);
}

/// Computes the Max norm of the matrix as the maximum absolute value.
pub fn maxNorm(self: Self) T {
var result: T = -std.math.inf(T);
for (self.items) |row| {
for (row) |col| {
sum_sq += col * col;
const val = @abs(col);
if (val > result) {
result = val;
}
}
}
return @sqrt(sum_sq);
return result;
}

/// Computes the Min norm of the matrix as the minimum absolute value.
pub fn minNorm(self: Self) T {
var result: T = std.math.inf(T);
for (self.items) |row| {
for (row) |col| {
const val = @abs(col);
if (val < result) {
result = val;
}
}
}
return result;
}
};
}
Expand Down Expand Up @@ -338,5 +406,28 @@ test "apply" {

test "norm" {
var matrix = Matrix(f32, 3, 4).random(null);
try expectEqual(matrix.norm(), @sqrt(matrix.times(matrix).sum()));
try expectEqual(matrix.frobeniusNorm(), @sqrt(matrix.times(matrix).sum()));

const f = struct {
fn f(x: f32) f32 {
return @abs(x);
}
}.f;
try expectEqual(matrix.nuclearNorm(), matrix.apply(f).sum());

matrix.set(2, 3, 1000000);
try expectEqual(matrix.maxNorm(), 1000000);

matrix = matrix.offset(10);
matrix.set(2, 3, -5);
try expectEqual(matrix.minNorm(), 5);
}

test "sum" {
var matrix = Matrix(f32, 3, 4).initAll(1);
const matrixSumCols = Matrix(f32, 3, 1).initAll(4);
const matrixSumRows = Matrix(f32, 1, 4).initAll(3);
try expectEqual(matrix.sumRows(), matrixSumRows);
try expectEqual(matrix.sumCols(), matrixSumCols);
try expectEqual(matrix.sumCols().sumRows().item(), matrix.sum());
}

0 comments on commit 8a3c350

Please sign in to comment.