diff --git a/src/ParMatrix.jl b/src/ParMatrix.jl index 193475e..3a19faf 100644 --- a/src/ParMatrix.jl +++ b/src/ParMatrix.jl @@ -64,3 +64,26 @@ function from_Dict(::Type{ParMatrix}, d) end ParMatrix(dtype, d["m"], d["n"], mid) end + +function Base.getindex(A::ParMatrix{T}, rows, cols) where T + row_range = isa(rows, Colon) ? (1:Range(A)) : (isa(rows, Integer) ? (rows:rows) : rows) + col_range = isa(cols, Colon) ? (1:Domain(A)) : (isa(rows, Integer) ? (cols:cols) : cols) + + new_m = length(row_range) + new_n = length(col_range) + + return ParMatrix(T, new_m, new_n) +end + +function Base.getindex(A::ParParameterized{T,T,Linear,ParMatrix{T},V}, rows, cols) where {T,V} + row_range = isa(rows, Colon) ? (1:Range(A)) : (isa(rows, Integer) ? (rows:rows) : rows) + col_range = isa(cols, Colon) ? (1:Domain(A)) : (isa(rows, Integer) ? (cols:cols) : cols) + + new_m = length(row_range) + new_n = length(col_range) + + new_params = reshape(A.params[rows, cols], new_m, new_n) + new_matrix = ParMatrix(T, new_m, new_n) + + return ParParameterized(new_matrix, new_params) +end