Skip to content

Commit

Permalink
added dispatch rules for indexing into ParMatrix (#8)
Browse files Browse the repository at this point in the history
Co-authored-by: turquoisedragon2926 <[email protected]>
  • Loading branch information
turquoisedragon2926 and Richard2926 authored Feb 1, 2024
1 parent d879085 commit 4e90ece
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/ParMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,26 @@ function from_Dict(::Type{ParMatrix}, d)
end
ParMatrix(dtype, d["m"], d["n"], mid)
end

function Base.getindex(A::ParMatrix{T}, rows, cols) where T
row_range = isa(rows, Colon) ? (1:Range(A)) : (isa(rows, Integer) ? (rows:rows) : rows)
col_range = isa(cols, Colon) ? (1:Domain(A)) : (isa(rows, Integer) ? (cols:cols) : cols)

new_m = length(row_range)
new_n = length(col_range)

return ParMatrix(T, new_m, new_n)
end

function Base.getindex(A::ParParameterized{T,T,Linear,ParMatrix{T},V}, rows, cols) where {T,V}
row_range = isa(rows, Colon) ? (1:Range(A)) : (isa(rows, Integer) ? (rows:rows) : rows)
col_range = isa(cols, Colon) ? (1:Domain(A)) : (isa(rows, Integer) ? (cols:cols) : cols)

new_m = length(row_range)
new_n = length(col_range)

new_params = reshape(A.params[rows, cols], new_m, new_n)
new_matrix = ParMatrix(T, new_m, new_n)

return ParParameterized(new_matrix, new_params)
end

0 comments on commit 4e90ece

Please sign in to comment.