Skip to content

Commit bb6c8cd

Browse files
authored
Merge pull request #193 from JordiManyer/inplace-assembly
Issue #192: In-place assembly and non-split assembly
2 parents f9bda08 + 0a59925 commit bb6c8cd

File tree

6 files changed

+433
-6
lines changed

6 files changed

+433
-6
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,19 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [Unreleased]
9+
10+
### Added
11+
12+
- Added support for in-place assembly of `PSparseMatrix`, in the case where all entries assembled are already in the matrix.
13+
- Added support for regular assembly for non-split `PSparseMatrix`.
814

915
## [0.5.9] - 2025-02-07
1016

1117
### Added
1218

1319
- Zenodo integration.
1420

15-
1621
## [0.5.8] - 2025-02-05
1722

1823
### Added

docs/src/reference/psparsematrix.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ psystem!
3030
```@docs
3131
assemble(::PSparseMatrix,rows)
3232
assemble!(::PSparseMatrix,::PSparseMatrix,cache)
33+
assemble!(::PSparseMatrix)
34+
assemble!(::PSparseMatrix,cache)
3335
consistent(::PSparseMatrix,rows)
3436
consistent!(::PSparseMatrix,::PSparseMatrix,cache)
3537
```

src/p_sparse_matrix.jl

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,30 @@ function assemble!(B::PSparseMatrix,A::PSparseMatrix,cache)
15831583
psparse_assemble_impl!(B,A,T,cache)
15841584
end
15851585

1586+
"""
1587+
assemble!([f,]A::PSparseMatrix;kwargs...)
1588+
"""
1589+
function assemble!(A::PSparseMatrix;kwargs...)
1590+
assemble!(+,A;kwargs...)
1591+
end
1592+
1593+
function assemble!(f,A::PSparseMatrix;kwargs...)
1594+
T = eltype(partition(A))
1595+
psparse_assemble_impl!(f,A,T;kwargs...)
1596+
end
1597+
1598+
"""
1599+
assemble!([f,]A::PSparseMatrix,cache)
1600+
"""
1601+
function assemble!(A::PSparseMatrix,cache)
1602+
assemble!(+,A,cache)
1603+
end
1604+
1605+
function assemble!(f,A::PSparseMatrix,cache)
1606+
T = eltype(partition(A))
1607+
psparse_assemble_impl!(f,A,T,cache)
1608+
end
1609+
15861610
function psparse_assemble_impl(A,::Type,rows;kwargs...)
15871611
error("Case not implemented yet")
15881612
end
@@ -1755,6 +1779,132 @@ function psparse_assemble_impl(
17551779
end
17561780
end
17571781

1782+
function psparse_assemble_impl(
1783+
A,
1784+
::Type{<:AbstractSparseMatrix},
1785+
rows;
1786+
reuse=Val(false),
1787+
assembly_neighbors_options_cols=(;))
1788+
1789+
function setup_cache_snd(A,parts_snd,rows_sa,cols_sa)
1790+
local_to_owner_row = local_to_owner(rows_sa)
1791+
local_to_global_row = local_to_global(rows_sa)
1792+
local_to_global_col = local_to_global(cols_sa)
1793+
me = part_id(rows_sa)
1794+
owner_to_p = Dict(( owner=>i for (i,owner) in enumerate(parts_snd) ))
1795+
ptrs = zeros(Int32,length(parts_snd)+1)
1796+
for (i,_,_) in nziterator(A)
1797+
owner = local_to_owner_row[i]
1798+
if owner != me
1799+
ptrs[owner_to_p[owner]+1] += 1
1800+
end
1801+
end
1802+
length_to_ptrs!(ptrs)
1803+
Tv = eltype(A)
1804+
ndata = ptrs[end]-1
1805+
I_snd_data = zeros(Int,ndata)
1806+
J_snd_data = zeros(Int,ndata)
1807+
V_snd_data = zeros(Tv,ndata)
1808+
k_snd_data = zeros(Int32,ndata)
1809+
for (k,(i,j,v)) in enumerate(nziterator(A))
1810+
owner = local_to_owner_row[i]
1811+
if owner != me
1812+
p = ptrs[owner_to_p[owner]]
1813+
I_snd_data[p] = local_to_global_row[i]
1814+
J_snd_data[p] = local_to_global_col[j]
1815+
V_snd_data[p] = v
1816+
k_snd_data[p] = k
1817+
ptrs[owner_to_p[owner]] += 1
1818+
end
1819+
end
1820+
rewind_ptrs!(ptrs)
1821+
I_snd = JaggedArray(I_snd_data,ptrs)
1822+
J_snd = JaggedArray(J_snd_data,ptrs)
1823+
V_snd = JaggedArray(V_snd_data,ptrs)
1824+
k_snd = JaggedArray(k_snd_data,ptrs)
1825+
(;I_snd,J_snd,V_snd,k_snd,parts_snd)
1826+
end
1827+
function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
1828+
k_rcv_data = zeros(Int32,length(I_rcv.data))
1829+
k_rcv = JaggedArray(k_rcv_data,I_rcv.ptrs)
1830+
(;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
1831+
end
1832+
function setup_own_triplets(A,cache_rcv,rows_sa,cols_sa)
1833+
local_to_own_rows = local_to_own(rows_sa)
1834+
I_sa, J_sa, V_sa = findnz(A)
1835+
k_own_to_sa = findall(i -> !iszero(local_to_own_rows[i]), I_sa)
1836+
1837+
I_own = view(I_sa,k_own_to_sa)
1838+
I_rcv = cache_rcv.I_rcv.data
1839+
map_global_to_local!(I_rcv,rows_sa)
1840+
I = vcat(I_own,I_rcv)
1841+
1842+
J_own = view(J_sa,k_own_to_sa)
1843+
J_rcv = cache_rcv.J_rcv.data
1844+
map_local_to_global!(J_own,cols_sa)
1845+
J = vcat(J_own,J_rcv)
1846+
1847+
V_own = view(V_sa,k_own_to_sa)
1848+
V_rcv = cache_rcv.V_rcv.data
1849+
V = vcat(V_own,V_rcv)
1850+
1851+
(I,J,V), J, k_own_to_sa
1852+
end
1853+
function finalize_values(A,rows_fa,cols_fa,cache_snd,cache_rcv,triplets,aux)
1854+
I, J, V = triplets
1855+
k_own_to_sa = aux
1856+
I_rcv, J_rcv, k_rcv = cache_rcv.I_rcv.data, cache_rcv.J_rcv.data, cache_rcv.k_rcv.data
1857+
map_global_to_local!(J,cols_fa)
1858+
values = compresscoo(typeof(A),I,J,V,length(rows_fa),length(cols_fa))
1859+
1860+
k_sa = zeros(Int32,nnz(A))
1861+
n_own = length(k_own_to_sa)
1862+
I_own = view(I,1:n_own)
1863+
J_own = view(J,1:n_own)
1864+
k_own = view(k_sa,k_own_to_sa)
1865+
precompute_nzindex!(k_own,values,I_own,J_own)
1866+
1867+
n_tot = length(I)
1868+
I_rcv = view(I,n_own+1:n_tot)
1869+
J_rcv = view(J,n_own+1:n_tot)
1870+
precompute_nzindex!(k_rcv,values,I_rcv,J_rcv)
1871+
1872+
cache = (;k_sa,cache_snd...,cache_rcv...)
1873+
values, cache
1874+
end
1875+
rows_sa = partition(axes(A,1))
1876+
cols_sa = partition(axes(A,2))
1877+
cols = map(remove_ghost,cols_sa)
1878+
parts_snd, parts_rcv = assembly_neighbors(rows_sa)
1879+
cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows_sa,cols_sa)
1880+
I_snd = map(i->i.I_snd,cache_snd)
1881+
J_snd = map(i->i.J_snd,cache_snd)
1882+
V_snd = map(i->i.V_snd,cache_snd)
1883+
graph = ExchangeGraph(parts_snd,parts_rcv)
1884+
t_I = exchange(I_snd,graph)
1885+
t_J = exchange(J_snd,graph)
1886+
t_V = exchange(V_snd,graph)
1887+
@fake_async begin
1888+
I_rcv = fetch(t_I)
1889+
J_rcv = fetch(t_J)
1890+
V_rcv = fetch(t_V)
1891+
cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
1892+
triplets,J,aux = map(setup_own_triplets,partition(A),cache_rcv,rows_sa,cols_sa) |> tuple_of_arrays
1893+
J_owner = find_owner(cols_sa,J)
1894+
rows_fa = rows
1895+
cols_fa = map(union_ghost,cols,J,J_owner)
1896+
assembly_neighbors(cols_fa;assembly_neighbors_options_cols...)
1897+
vals_fa, cache = map(finalize_values,partition(A),rows_fa,cols_fa,cache_snd,cache_rcv,triplets,aux) |> tuple_of_arrays
1898+
assembled = true
1899+
B = PSparseMatrix(vals_fa,rows_fa,cols_fa,assembled)
1900+
if !val_parameter(reuse)
1901+
B
1902+
else
1903+
B, cache
1904+
end
1905+
end
1906+
end
1907+
17581908
function psparse_assemble_impl!(B,A,::Type,cache)
17591909
error("case not implemented")
17601910
end
@@ -1815,6 +1965,177 @@ function psparse_assemble_impl!(B,A,::Type{<:AbstractSplitMatrix},cache)
18151965
end
18161966
end
18171967

1968+
function psparse_assemble_impl!(B,A,::Type{<:AbstractSparseMatrix},cache)
1969+
function setup_snd(A,cache)
1970+
V_snd = cache.V_snd.data
1971+
k_snd = cache.k_snd.data
1972+
nz = nonzeros(A)
1973+
for p in eachindex(k_snd)
1974+
k = k_snd[p]
1975+
V_snd[p] = nz[k]
1976+
end
1977+
end
1978+
function setup_sa(B,A,cache)
1979+
setcoofast!(B,nonzeros(A),cache.k_sa)
1980+
end
1981+
function setup_rcv(B,cache)
1982+
V_rcv = cache.V_rcv.data
1983+
k_rcv = cache.k_rcv.data
1984+
nz = nonzeros(B)
1985+
for p in eachindex(k_rcv)
1986+
k = k_rcv[p]
1987+
nz[k] += V_rcv[p]
1988+
end
1989+
end
1990+
map(setup_snd,partition(A),cache)
1991+
parts_snd = map(i->i.parts_snd,cache)
1992+
parts_rcv = map(i->i.parts_rcv,cache)
1993+
V_snd = map(i->i.V_snd,cache)
1994+
V_rcv = map(i->i.V_rcv,cache)
1995+
graph = ExchangeGraph(parts_snd,parts_rcv)
1996+
t = exchange!(V_rcv,V_snd,graph)
1997+
map(setup_sa,partition(B),partition(A),cache)
1998+
@fake_async begin
1999+
wait(t)
2000+
map(setup_rcv,partition(B),cache)
2001+
B
2002+
end
2003+
end
2004+
2005+
function psparse_assemble_impl!(f,A,::Type;kwargs...)
2006+
error("case not implemented")
2007+
end
2008+
2009+
function psparse_assemble_impl!(
2010+
f,
2011+
A,
2012+
::Type{<:AbstractSparseMatrix};
2013+
reuse=Val(false))
2014+
2015+
function setup_cache_snd(A,parts_snd,rows,cols)
2016+
local_to_owner_row = local_to_owner(rows)
2017+
local_to_global_row = local_to_global(rows)
2018+
local_to_global_col = local_to_global(cols)
2019+
me = part_id(rows)
2020+
owner_to_p = Dict(( owner=>i for (i,owner) in enumerate(parts_snd) ))
2021+
ptrs = zeros(Int32,length(parts_snd)+1)
2022+
for (i,_,_) in nziterator(A)
2023+
owner = local_to_owner_row[i]
2024+
if owner != me
2025+
ptrs[owner_to_p[owner]+1] += 1
2026+
end
2027+
end
2028+
length_to_ptrs!(ptrs)
2029+
Tv = eltype(A)
2030+
ndata = ptrs[end]-1
2031+
I_snd_data = zeros(Int,ndata)
2032+
J_snd_data = zeros(Int,ndata)
2033+
V_snd_data = zeros(Tv,ndata)
2034+
k_snd_data = zeros(Int32,ndata)
2035+
for (k,(i,j,v)) in enumerate(nziterator(A))
2036+
owner = local_to_owner_row[i]
2037+
if owner != me
2038+
p = ptrs[owner_to_p[owner]]
2039+
I_snd_data[p] = local_to_global_row[i]
2040+
J_snd_data[p] = local_to_global_col[j]
2041+
V_snd_data[p] = v
2042+
k_snd_data[p] = k
2043+
ptrs[owner_to_p[owner]] += 1
2044+
end
2045+
end
2046+
rewind_ptrs!(ptrs)
2047+
I_snd = JaggedArray(I_snd_data,ptrs)
2048+
J_snd = JaggedArray(J_snd_data,ptrs)
2049+
V_snd = JaggedArray(V_snd_data,ptrs)
2050+
k_snd = JaggedArray(k_snd_data,ptrs)
2051+
(;I_snd,J_snd,V_snd,k_snd,parts_snd)
2052+
end
2053+
function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
2054+
k_rcv_data = zeros(Int32,length(I_rcv.data))
2055+
k_rcv = JaggedArray(k_rcv_data,I_rcv.ptrs)
2056+
(;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
2057+
end
2058+
function finalize_values!(A,rows,cols,cache_snd,cache_rcv)
2059+
I_rcv_data = cache_rcv.I_rcv.data
2060+
J_rcv_data = cache_rcv.J_rcv.data
2061+
V_rcv_data = cache_rcv.V_rcv.data
2062+
k_rcv_data = cache_rcv.k_rcv.data
2063+
A_nonzeros = nonzeros(A)
2064+
map_global_to_local!(I_rcv_data,rows)
2065+
map_global_to_local!(J_rcv_data,cols)
2066+
for p in eachindex(k_rcv_data)
2067+
i = I_rcv_data[p]
2068+
j = J_rcv_data[p]
2069+
k = nzindex(A,i,j)
2070+
@boundscheck @assert k > 0 "The sparsity pattern of the ghost layer is inconsistent"
2071+
k_rcv_data[p] = k
2072+
A_nonzeros[k] = f(A_nonzeros[k],V_rcv_data[p])
2073+
end
2074+
cache = (;cache_snd...,cache_rcv...)
2075+
cache
2076+
end
2077+
rows = partition(axes(A,1))
2078+
cols = partition(axes(A,2))
2079+
parts_snd, parts_rcv = assembly_neighbors(rows)
2080+
cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows,cols)
2081+
I_snd = map(i->i.I_snd,cache_snd)
2082+
J_snd = map(i->i.J_snd,cache_snd)
2083+
V_snd = map(i->i.V_snd,cache_snd)
2084+
graph = ExchangeGraph(parts_snd,parts_rcv)
2085+
t_I = exchange(I_snd,graph)
2086+
t_J = exchange(J_snd,graph)
2087+
t_V = exchange(V_snd,graph)
2088+
@fake_async begin
2089+
I_rcv = fetch(t_I)
2090+
J_rcv = fetch(t_J)
2091+
V_rcv = fetch(t_V)
2092+
cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
2093+
cache = map(finalize_values!,partition(A),rows,cols,cache_snd,cache_rcv)
2094+
if !val_parameter(reuse)
2095+
A
2096+
else
2097+
A, cache
2098+
end
2099+
end
2100+
end
2101+
2102+
function psparse_assemble_impl!(f::Function,A,::Type,cache)
2103+
error("case not implemented")
2104+
end
2105+
2106+
function psparse_assemble_impl!(f::Function,A,::Type{<:AbstractSparseMatrix},cache)
2107+
function setup_snd(A,cache)
2108+
V_snd_data = cache.V_snd.data
2109+
k_snd_data = cache.k_snd.data
2110+
A_nonzeros = nonzeros(A)
2111+
for p in eachindex(k_snd_data)
2112+
k = k_snd_data[p]
2113+
V_snd_data[p] = A_nonzeros[k]
2114+
end
2115+
end
2116+
function setup_rcv(A,cache)
2117+
V_rcv_data = cache.V_rcv.data
2118+
k_rcv_data = cache.k_rcv.data
2119+
A_nonzeros = nonzeros(A)
2120+
for p in eachindex(k_rcv_data)
2121+
k = k_rcv_data[p]
2122+
A_nonzeros[k] = f(A_nonzeros[k],V_rcv_data[p])
2123+
end
2124+
end
2125+
map(setup_snd,partition(A),cache)
2126+
parts_snd = map(i->i.parts_snd,cache)
2127+
parts_rcv = map(i->i.parts_rcv,cache)
2128+
V_snd = map(i->i.V_snd,cache)
2129+
V_rcv = map(i->i.V_rcv,cache)
2130+
graph = ExchangeGraph(parts_snd,parts_rcv)
2131+
t = exchange!(V_rcv,V_snd,graph)
2132+
@fake_async begin
2133+
wait(t)
2134+
map(setup_rcv,partition(A),cache)
2135+
A
2136+
end
2137+
end
2138+
18182139
"""
18192140
consistent(A::PSparseMatrix,rows;kwargs...)
18202141
"""

0 commit comments

Comments
 (0)