Skip to content

Commit

Permalink
Add range-based Matrix subscripts (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgromov authored May 24, 2020
1 parent 0cf733d commit c69bf39
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
54 changes: 53 additions & 1 deletion Example/Tests/MatrixTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,59 @@ class MatrixSpec: QuickSpec {
expect(m1[1, 0]) == 30.0
}
}


describe("Matrix range-based subscript") {
it("[i,j] -> Matrix") {
let m1 = Matrix([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
let m2 = Matrix([[10.0, 11.0],
[12.0, 13.0]])
expect(m1[0, 1].flat) == [2.0, 3.0, 5.0, 6.0, 8.0, 9.0]
m1[1, 0] = m2
expect(m1[1..<3, 0...1].flat) == [10.0, 11.0, 12.0, 13.0]
}
it("closed") {
let m1 = Matrix([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
let m2 = Matrix([[10.0, 11.0],
[12.0, 13.0]])
expect(m1[1...2, 0...1].flat) == [4.0, 5.0, 7.0, 8.0]
m1[0...1, 1...2] = m2
expect(m1.flat) == [1.0, 10.0, 11.0, 4.0, 12.0, 13.0, 7.0, 8.0, 9.0]
}
it("open") {
let m1 = Matrix([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
let m2 = Matrix([[10.0, 11.0],
[12.0, 13.0]])
expect(m1[1..<3, 0..<2].flat) == [4.0, 5.0, 7.0, 8.0]
m1[0..<2, 1..<3] = m2
expect(m1.flat) == [1.0, 10.0, 11.0, 4.0, 12.0, 13.0, 7.0, 8.0, 9.0]
}
it("partial") {
let m1 = Matrix([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
let m2 = Matrix([[10.0, 11.0],
[12.0, 13.0]])
expect(m1[1..., ..<2].flat) == [4.0, 5.0, 7.0, 8.0]
m1[1..., 1...] = m2
expect(m1.flat) == [1.0, 2.0, 3.0, 4.0, 10.0, 11.0, 7.0, 12.0, 13.0]
}
it("unbounded") {
let m1 = Matrix([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
let m2 = Matrix([[10.0, 11.0, 12.0]])
expect(m1[..., 2...2].flat) == [3.0, 6.0, 9.0]
m1[1...1, ...] = m2
expect(m1.flat) == [1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 7.0, 8.0, 9.0]
}
}

describe("Matrix map/reduce") {
it("map") {
let m1 = Matrix([[1.0, 2.0], [3.0, 4.0]])
Expand Down
123 changes: 123 additions & 0 deletions Sources/Matrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,129 @@ extension Matrix {
}
}
}

/// Get and set M(row, col) submatrix of Matrix.
///
/// The range-based subscript methods for getting and setting submatricies.
///
/// var M = Matrix([[1, 2, 3, 4],
/// [5, 6, 7, 8],
/// [9, 10, 11, 12])
///
/// var K = Matrix([[1, 0],
/// [0, 1])
///
/// - Using bounded ranges, including partial (e.g. `..<3`):
///
/// M[1..<3, 0..1] = K
/// // M is now:
/// // [[1, 2, 3, 4],
/// // [1, 0, 7, 8],
/// // [0, 1, 11, 12]]
///
/// K = M[0...1, 2...]
/// // K is now:
/// // [[3, 4]
/// // [7, 8]]
///
/// - Using unbounded ranges:
///
/// K = M[..., 1..2]
/// // K is now:
/// // [[ 2, 3],
/// // [ 6, 7],
/// // [10, 11]]
///
/// - Parameters:
/// - row: Range for rows (0-based)
/// - col: Range for cols (0-based)
///
/// - Returns: submatrix of size `row.count` by `col.count`
///
public subscript<A: RangeExpression, B: RangeExpression>(_ row: A, _ col: B) -> Matrix where A.Bound == Int, B.Bound == Int {
get {
return self[ClosedRange<Int>(row.relative(to: self[row: 0])), ClosedRange<Int>(col.relative(to: self[col: 0]))]
}

set {
self[ClosedRange<Int>(row.relative(to: self[row: 0])), ClosedRange<Int>(col.relative(to: self[col: 0]))] = newValue
}
}

public subscript<B: RangeExpression>(_ : UnboundedRange, _ col: B) -> Matrix where B.Bound == Int {
get { return self[0..<rows, ClosedRange<Int>(col.relative(to: self[col: 0]))] }
set { self[0..<rows, ClosedRange<Int>(col.relative(to: self[col: 0]))] = newValue }
}

public subscript<A: RangeExpression>(_ row: A, _ : UnboundedRange) -> Matrix where A.Bound == Int {
get { return self[ClosedRange<Int>(row.relative(to: self[row: 0])), 0..<cols] }
set { self[ClosedRange<Int>(row.relative(to: self[row: 0])), 0..<cols] = newValue }
}

public subscript(_ : UnboundedRange, _ : UnboundedRange) -> Matrix {
get { return self}
set { self[0..<rows, 0..<cols] = newValue}
}

/// Get and set submatrix with a top-left corner at (`row`, `col`).
///
/// The method allows to get and set a submatrix using just the coordinates of its top-left corner.
///
/// var M = Matrix([[1, 2, 3, 4],
/// [5, 6, 7, 8],
/// [9, 10, 11, 12])
///
/// var K = Matrix([[1, 0],
/// [0, 1])
///
/// M[1, 2] = K
/// // M is now:
/// // [[1, 2, 3, 4],
/// // [5, 6, 1, 0],
/// // [9, 10, 0, 1]]
///
/// - Warning: Getter and setter are using submatrices of different dimensions.
///
/// - Parameters:
/// - row: row index (0-based)
/// - col: column index (0-based)
/// - Returns: Submatrix equivalent to `M[row..., col...]`
///
public subscript(_ row: Int, _ col: Int) -> Matrix {
get { return self[row..., col...]}
set { self[row..<(row + newValue.rows), col..<(col + newValue.cols)] = newValue}
}

public subscript(_ row: ClosedRange<Int>, _ col: ClosedRange<Int>) -> Matrix {
get {
precondition(indexIsValidForRow(row.lowerBound, col.lowerBound), "Invalid range")
precondition(indexIsValidForRow(row.upperBound, col.upperBound), "Invalid range")

let dst = Matrix(row.count, col.count)

flat.withUnsafeBufferPointer { srcBuf in
let srcPtr = srcBuf.baseAddress! + row.lowerBound * rows + col.lowerBound
vDSP_mmovD(srcPtr, &dst.flat,
vDSP_Length(col.count), vDSP_Length(row.count),
vDSP_Length(cols), vDSP_Length(col.count))
}

return dst
}

set {
precondition(indexIsValidForRow(row.lowerBound, col.lowerBound), "Invalid range")
precondition(indexIsValidForRow(row.upperBound, col.upperBound), "Invalid range")
precondition(newValue.cols == col.count && newValue.rows == row.count, "Matrix dimensions must agree")

flat.withUnsafeMutableBufferPointer { dstBuf in
let dstPtr = dstBuf.baseAddress! + row.lowerBound * rows + col.lowerBound
vDSP_mmovD(newValue.flat, dstPtr,
vDSP_Length(col.count), vDSP_Length(row.count),
vDSP_Length(newValue.cols), vDSP_Length(cols))
}
}
}

/// Construct new matrix from source using specified extractor.
///
Expand Down

0 comments on commit c69bf39

Please sign in to comment.