Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PermMatrixCSC #78

Merged
merged 13 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
name: CI
on:
- push
- pull_request
push:
branches:
- master
pull_request:
branches:
- master
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
Expand All @@ -10,7 +14,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
StaticArrays = "1"
julia = "1"
julia = "1.10"
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 2 additions & 0 deletions src/LuxurySparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ using SparseArrays: SparseMatrixCSC
using SparseArrays.HigherOrderFns
using Base: @propagate_inbounds
using LinearAlgebra
import SparseArrays: findnz, nnz
using LinearAlgebra: StructuredMatrixStyle
using Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!

# static types
export SDPermMatrix, SPermMatrix, PermMatrix, pmrand,
SDPermMatrixCSC, SPermMatrixCSC, PermMatrixCSC, pmcscrand,
SDSparseMatrixCSC, SSparseMatrixCSC, SparseMatrixCSC, sprand,
SparseMatrixCOO,
SDMatrix, SDVector,
Expand Down
123 changes: 83 additions & 40 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
abstract type AbstractPermMatrix{Tv, Ti} <: AbstractMatrix{Tv} end
"""
PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer}
PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti}
Expand All @@ -24,7 +25,7 @@
```
"""
struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractMatrix{Tv}
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

Expand All @@ -42,26 +43,74 @@
new{Tv,Ti,Vv,Vi}(perm, vals)
end
end

function PermMatrix{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer}
PermMatrix{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals))
basetype(pm::PermMatrix) = PermMatrix
Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
M.perm[i] == j ? M.vals[i] : zero(Tv)
@propagate_inbounds function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
@assert M.perm[i] == j "Can not set index due to the absense of entry: ($i, $j)"
@inbounds M.vals[i] = val
end

function PermMatrix(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
PermMatrix{Tv,Ti,Vv,Vi}(perm, vals)
# the column major version of `PermMatrix`
struct PermMatrixCSC{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

function PermMatrixCSC{Tv,Ti,Vv,Vi}(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
if length(perm) != length(vals)
throw(
DimensionMismatch(
"permutation ($(length(perm))) and multiply ($(length(vals))) length mismatch.",
),
)
end
new{Tv,Ti,Vv,Vi}(perm, vals)
end
end
basetype(pm::PermMatrixCSC) = PermMatrixCSC
@propagate_inbounds function Base.getindex(M::PermMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv}
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
@boundscheck 0 < j <= size(M, 2)
@inbounds M.perm[j] == i ? M.vals[j] : zero(Tv)
end
@propagate_inbounds function Base.setindex!(M::PermMatrixCSC, val, i::Integer, j::Integer)
@assert M.perm[j] == i "Can not set index due to the absense of entry: ($i, $j)"
@inbounds M.vals[j] = val
end

Base.:(==)(d1::PermMatrix, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.isapprox(d1::PermMatrix, d2::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...)
Base.zero(pm::PermMatrix) = PermMatrix(pm.perm, zero(pm.vals))
for MT in [:PermMatrix, :PermMatrixCSC]
@eval begin
function $MT{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer}
$MT{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals))
end

function $MT(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
$MT{Tv,Ti,Vv,Vi}(perm, vals)
end
end
end
Base.zero(pm::AbstractPermMatrix) = basetype(pm)(pm.perm, zero(pm.vals))
Base.similar(x::AbstractPermMatrix{Tv,Ti}) where {Tv,Ti} =
typeof(x)(copy(x.perm), similar(x.vals))
Base.similar(x::AbstractPermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
basetype(x){T,Ti}(copy(x.perm), similar(x.vals, T))

################# Comparison ##################
Base.:(==)(d1::AbstractPermMatrix, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.isapprox(d1::AbstractPermMatrix, d2::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...)
Base.copyto!(A::AbstractPermMatrix, B::AbstractPermMatrix) =
(copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A)

################# Array Functions ##################

Base.size(M::PermMatrix) = (length(M.perm), length(M.perm))
function Base.size(A::PermMatrix, d::Integer)
Base.size(M::AbstractPermMatrix) = (length(M.perm), length(M.perm))
function Base.size(A::AbstractPermMatrix, d::Integer)
if d < 1
throw(ArgumentError("dimension must be ≥ 1, got $d"))
elseif d <= 2
Expand All @@ -70,18 +119,6 @@
return 1
end
end
Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
M.perm[i] == j ? M.vals[i] : zero(Tv)
function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
if M.perm[i] == j
@inbounds M.vals[i] = val
else
throw(BoundsError(M, (i, j)))
end
end

Base.copyto!(A::PermMatrix, B::PermMatrix) =
(copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A)

"""
pmrand(T::Type, n::Int) -> PermMatrix
Expand All @@ -105,20 +142,26 @@
pmrand(::Type{T}, n::Int) where {T} = PermMatrix(randperm(n), randn(T, n))
pmrand(n::Int) = pmrand(Float64, n)

Base.similar(x::PermMatrix{Tv,Ti}) where {Tv,Ti} =
PermMatrix{Tv,Ti}(copy(x.perm), similar(x.vals))
Base.similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
PermMatrix{T,Ti}(copy(x.perm), similar(x.vals, T))

# TODO: rewrite this
# function show(io::IO, M::PermMatrix)
# println("PermMatrix")
# for item in zip(M.perm, M.vals)
# i, p = item
# println("- ($i) * $p")
# end
# end
pmcscrand(::Type{T}, n::Int) where {T} = PermMatrixCSC(randperm(n), randn(T, n))
pmcscrand(n::Int) = pmcscrand(Float64, n)

Base.show(io::IO, ::MIME"text/plain", M::AbstractPermMatrix) = show(io, M)

Check warning on line 148 in src/PermMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/PermMatrix.jl#L148

Added line #L148 was not covered by tests
function Base.show(io::IO, M::AbstractPermMatrix)
n = size(M, 1)
println(io, typeof(M))
nmax = 20
for (k, (i, j, p)) in enumerate(IterNz(M))
if k <= nmax || k > n-nmax
print(io, "($i, $j) = $p")
k < n && println(io)
elseif k == nmax+1
println(io, "...")

Check warning on line 158 in src/PermMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/PermMatrix.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
end
end
Base.hash(pm::AbstractPermMatrix) = hash((pm.perm, pm.vals))

######### sparse array interfaces #########
nnz(M::PermMatrix) = length(M.vals)
nnz(M::AbstractPermMatrix) = length(M.vals)
findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals)
findnz(M::PermMatrixCSC) = (M.perm, collect(1:size(M, 1)), M.vals)
80 changes: 24 additions & 56 deletions src/SSparseMatrixCSC.jl
Original file line number Diff line number Diff line change
@@ -1,62 +1,30 @@
@static if VERSION < v"1.4.0"
"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}

"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}
static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: AbstractSparseMatrix{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end

else
# NOTE: from 1.4.0, by subtyping AbstractSparseMatrixCSC, things like sparse broadcast
# should just work.

"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}

static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval
end # @static
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval

function SSparseMatrixCSC(
m::Integer,
Expand Down
47 changes: 24 additions & 23 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,33 @@

# PermMatrix
for func in (:conj, :real, :imag)
@eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals))
@eval (Base.$func)(M::AbstractPermMatrix) = basetype(M)(M.perm, ($func)(M.vals))
end
Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals))
Base.copy(M::AbstractPermMatrix) = basetype(M)(copy(M.perm), copy(M.vals))
Base.conj!(M::AbstractPermMatrix) = (conj!(M.vals); M)

function Base.transpose(M::PermMatrix)
function Base.transpose(M::AbstractPermMatrix)
new_perm = fast_invperm(M.perm)
return PermMatrix(new_perm, M.vals[new_perm])
return basetype(M)(new_perm, M.vals[new_perm])
end

Base.adjoint(S::PermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S))
Base.adjoint(S::AbstractPermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::AbstractPermMatrix{<:Complex}) = conj!(transpose(S))

# scalar
Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:/(A::IMatrix{T}, B::Number) where {T} =
Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n))

Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B)
Base.:*(B::Number, A::PermMatrix) = A * B
Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B)
Base.:*(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals * B)
Base.:*(B::Number, A::AbstractPermMatrix) = A * B

Check warning on line 32 in src/arraymath.jl

View check run for this annotation

Codecov / codecov/patch

src/arraymath.jl#L32

Added line #L32 was not covered by tests
Base.:/(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals / B)
#+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev)
#-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev)

for op in [:+, :-]
for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval begin
# IMatrix, PermMatrix - SparseMatrixCSC
Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B)
Expand All @@ -45,12 +46,12 @@
# IMatrix, PermMatrix - Diagonal
Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag))
Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2)))
Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2))
Base.$op(d1::AbstractPermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::AbstractPermMatrix) = $op(d1, SparseMatrixCSC(d2))
# PermMatrix - IMatrix
Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
end
end
# NOTE: promote to integer
Expand All @@ -59,22 +60,22 @@
Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch())

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B
@eval Base.:(==)(A::SparseMatrixCSC, B::$MT) = A == SparseMatrixCSC(B)
end
Base.:(==)(d1::IMatrix, d2::Diagonal) = all(isone, d2.diag)
Base.:(==)(d1::Diagonal, d2::IMatrix) = all(isone, d1.diag)
Base.:(==)(d1::PermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::PermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::PermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(d1::AbstractPermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::AbstractPermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::AbstractPermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.isapprox(A::$MT, B::SparseMatrixCSC; kwargs...) = isapprox(SparseMatrixCSC(A), B)
@eval Base.isapprox(A::SparseMatrixCSC, B::$MT; kwargs...) = isapprox(A, SparseMatrixCSC(B))
@eval Base.isapprox(d1::$MT, d2::Diagonal; kwargs...) = isapprox(diag(d1), d2.diag)
@eval Base.isapprox(d1::Diagonal, d2::$MT; kwargs...) = isapprox(d1.diag, diag(d2))
end
Base.isapprox(A::IMatrix, B::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::PermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::IMatrix, B::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::AbstractPermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)

Check warning on line 81 in src/arraymath.jl

View check run for this annotation

Codecov / codecov/patch

src/arraymath.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
Loading
Loading