@@ -1583,6 +1583,30 @@ function assemble!(B::PSparseMatrix,A::PSparseMatrix,cache)
1583
1583
psparse_assemble_impl! (B,A,T,cache)
1584
1584
end
1585
1585
"""
    assemble!([f,]A::PSparseMatrix;kwargs...)

Assemble the matrix `A`, combining values with the two-argument function `f`.
When `f` is omitted, `+` is used. All keyword arguments are passed on
unchanged to the method that takes `f` explicitly.
"""
assemble!(A::PSparseMatrix;kwargs...) = assemble!(+,A;kwargs...)
1592
+
1593
# Select the implementation from the element type of `partition(A)` and
# forward `f` and every keyword argument to it.
function assemble!(f,A::PSparseMatrix;kwargs...)
    local_matrix_type = eltype(partition(A))
    return psparse_assemble_impl!(f,A,local_matrix_type;kwargs...)
end
1597
+
"""
    assemble!([f,]A::PSparseMatrix,cache)

Assemble the matrix `A` using a previously built `cache`, combining values
with the two-argument function `f`. When `f` is omitted, `+` is used.
"""
assemble!(A::PSparseMatrix,cache) = assemble!(+,A,cache)
1604
+
1605
# Select the cached implementation from the element type of `partition(A)`.
function assemble!(f,A::PSparseMatrix,cache)
    local_matrix_type = eltype(partition(A))
    return psparse_assemble_impl!(f,A,local_matrix_type,cache)
end
1609
+
1586
1610
# Fallback for local matrix types without a dedicated method: always throws.
psparse_assemble_impl(A,::Type,rows;kwargs...) = error(" Case not implemented yet")
@@ -1755,6 +1779,132 @@ function psparse_assemble_impl(
1755
1779
end
1756
1780
end
1757
1781
1782
# Out-of-place assembly when the per-part storage is an `AbstractSparseMatrix`.
#
# Stored entries of `A` whose row is owned by another part are packed as
# (global row id, global column id, value) and exchanged with the parts given
# by `assembly_neighbors`. Entries whose row has a nonzero `local_to_own` id
# are kept locally. Both sets are compressed into one new local matrix per
# part and wrapped in a `PSparseMatrix` flagged as assembled.
#
# Arguments
# - `A`: input matrix, read through `partition(A)` and `partition(axes(A,d))`.
# - `rows`: row partition handed to the resulting matrix.
# - `reuse`: when `val_parameter(reuse)` is true, the block also yields the
#   cache consumed by the cached in-place method.
# - `assembly_neighbors_options_cols`: keyword options splatted into
#   `assembly_neighbors` when it is called on the new column partition.
#
# The function evaluates to the result of `@fake_async` applied to a block
# whose value is `B`, or `(B, cache)` when reuse is requested.
# NOTE(review): the `_sa` / `_fa` suffixes presumably mean sub-assembled and
# fully assembled -- confirm.
function psparse_assemble_impl(
        A,
        ::Type{<:AbstractSparseMatrix},
        rows;
        reuse=Val(false),
        assembly_neighbors_options_cols=(;))

    # Pack, grouped by destination part, the stored entries of `A_part` whose
    # row owner differs from this part.
    function setup_cache_snd(A_part,parts_snd,rows_sa,cols_sa)
        row_owner = local_to_owner(rows_sa)
        row_gid = local_to_global(rows_sa)
        col_gid = local_to_global(cols_sa)
        me = part_id(rows_sa)
        slot_of_owner = Dict((owner=>slot for (slot,owner) in enumerate(parts_snd)))
        ptrs = zeros(Int32,length(parts_snd)+1)
        # First pass: count the entries going to each destination.
        for (i,_,_) in nziterator(A_part)
            owner = row_owner[i]
            owner == me && continue
            ptrs[slot_of_owner[owner]+1] += 1
        end
        length_to_ptrs!(ptrs)
        ndata = ptrs[end]-1
        I_snd_data = zeros(Int,ndata)
        J_snd_data = zeros(Int,ndata)
        V_snd_data = zeros(eltype(A_part),ndata)
        k_snd_data = zeros(Int32,ndata)
        # Second pass: fill the buffers, using `ptrs` as moving write cursors.
        # `k` is the position of the entry in the `nziterator` traversal.
        for (k,(i,j,v)) in enumerate(nziterator(A_part))
            owner = row_owner[i]
            owner == me && continue
            slot = slot_of_owner[owner]
            p = ptrs[slot]
            I_snd_data[p] = row_gid[i]
            J_snd_data[p] = col_gid[j]
            V_snd_data[p] = v
            k_snd_data[p] = k
            ptrs[slot] += 1
        end
        rewind_ptrs!(ptrs)
        return (;
            I_snd=JaggedArray(I_snd_data,ptrs),
            J_snd=JaggedArray(J_snd_data,ptrs),
            V_snd=JaggedArray(V_snd_data,ptrs),
            k_snd=JaggedArray(k_snd_data,ptrs),
            parts_snd)
    end

    # Bundle the received buffers with a zero-initialized index buffer that
    # shares the `ptrs` of `I_rcv`.
    function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
        k_rcv = JaggedArray(zeros(Int32,length(I_rcv.data)),I_rcv.ptrs)
        return (;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
    end

    # Build COO triplets made of the locally kept entries followed by the
    # received ones. Rows end up as local ids in `rows_sa`, columns as global ids.
    function setup_own_triplets(A_part,cache_rcv,rows_sa,cols_sa)
        row_own_id = local_to_own(rows_sa)
        I_sa, J_sa, V_sa = findnz(A_part)
        k_own_to_sa = findall(i -> !iszero(row_own_id[i]), I_sa)

        # The received row ids are converted in place, inside the cache.
        rcv_rows = cache_rcv.I_rcv.data
        map_global_to_local!(rcv_rows,rows_sa)
        # This writes through the view into `J_sa`.
        own_cols = view(J_sa,k_own_to_sa)
        map_local_to_global!(own_cols,cols_sa)

        I = vcat(view(I_sa,k_own_to_sa),rcv_rows)
        J = vcat(own_cols,cache_rcv.J_rcv.data)
        V = vcat(view(V_sa,k_own_to_sa),cache_rcv.V_rcv.data)

        # `J` is returned twice on purpose: inside the triplets and on its own.
        return (I,J,V), J, k_own_to_sa
    end

    # Compress the triplets into the final local matrix and record, for every
    # kept and every received entry, its position in the nonzeros of the result.
    function finalize_values(A_part,rows_fa,cols_fa,cache_snd,cache_rcv,triplets,aux)
        I, J, V = triplets
        k_own_to_sa = aux
        k_rcv = cache_rcv.k_rcv.data
        map_global_to_local!(J,cols_fa)
        vals = compresscoo(typeof(A_part),I,J,V,length(rows_fa),length(cols_fa))

        # Positions of `k_sa` not listed in `k_own_to_sa` keep the value zero.
        k_sa = zeros(Int32,nnz(A_part))
        n_own = length(k_own_to_sa)
        own_range = 1:n_own
        precompute_nzindex!(view(k_sa,k_own_to_sa),vals,view(I,own_range),view(J,own_range))

        rcv_range = (n_own+1):length(I)
        precompute_nzindex!(k_rcv,vals,view(I,rcv_range),view(J,rcv_range))

        cache = (;k_sa,cache_snd...,cache_rcv...)
        return vals, cache
    end

    rows_sa = partition(axes(A,1))
    cols_sa = partition(axes(A,2))
    cols = map(remove_ghost,cols_sa)
    parts_snd, parts_rcv = assembly_neighbors(rows_sa)
    cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows_sa,cols_sa)
    I_snd = map(c->c.I_snd,cache_snd)
    J_snd = map(c->c.J_snd,cache_snd)
    V_snd = map(c->c.V_snd,cache_snd)
    graph = ExchangeGraph(parts_snd,parts_rcv)
    t_I = exchange(I_snd,graph)
    t_J = exchange(J_snd,graph)
    t_V = exchange(V_snd,graph)
    @fake_async begin
        I_rcv = fetch(t_I)
        J_rcv = fetch(t_J)
        V_rcv = fetch(t_V)
        cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
        triplets,J,aux = tuple_of_arrays(
            map(setup_own_triplets,partition(A),cache_rcv,rows_sa,cols_sa))
        # The owners are looked up before `finalize_values` rewrites `J` in place.
        J_owner = find_owner(cols_sa,J)
        cols_fa = map(union_ghost,cols,J,J_owner)
        # The result of this call is not used here.
        assembly_neighbors(cols_fa;assembly_neighbors_options_cols...)
        vals_fa, cache = tuple_of_arrays(
            map(finalize_values,partition(A),rows,cols_fa,cache_snd,cache_rcv,triplets,aux))
        B = PSparseMatrix(vals_fa,rows,cols_fa,true)
        val_parameter(reuse) ? (B, cache) : B
    end
end
1907
+
1758
1908
# Fallback for local matrix types without a dedicated cached method: always throws.
psparse_assemble_impl!(B,A,::Type,cache) = error(" case not implemented")
@@ -1815,6 +1965,177 @@ function psparse_assemble_impl!(B,A,::Type{<:AbstractSplitMatrix},cache)
1815
1965
end
1816
1966
end
1817
1967
1968
# Cached assembly of `A` into `B` when the per-part storage is an
# `AbstractSparseMatrix`.
#
# `cache` provides the buffers `V_snd` / `V_rcv`, the index arrays `k_snd`,
# `k_rcv` and `k_sa`, and the neighbor lists `parts_snd` / `parts_rcv`.
# The send buffers are refreshed from the current nonzeros of `A` and the
# exchange is started. The locally stored part is then copied into `B` with
# `setcoofast!`. Inside the `@fake_async` block the exchange is waited for and
# the received values are added (`+=`) to the nonzeros of `B`. The block's
# value is `B`.
function psparse_assemble_impl!(B,A,::Type{<:AbstractSparseMatrix},cache)
    # Gather the nonzeros of `A_part` selected by `k_snd` into the send buffer.
    function setup_snd(A_part,cache_part)
        buffer = cache_part.V_snd.data
        A_nz = nonzeros(A_part)
        for (p,k) in pairs(cache_part.k_snd.data)
            buffer[p] = A_nz[k]
        end
    end
    # Copy the nonzeros of `A_part` into `B_part` at the positions in `k_sa`.
    setup_sa(B_part,A_part,cache_part) = setcoofast!(B_part,nonzeros(A_part),cache_part.k_sa)
    # Add each received value to the nonzero of `B_part` selected by `k_rcv`.
    function setup_rcv(B_part,cache_part)
        received = cache_part.V_rcv.data
        B_nz = nonzeros(B_part)
        for (p,k) in pairs(cache_part.k_rcv.data)
            B_nz[k] += received[p]
        end
    end
    map(setup_snd,partition(A),cache)
    graph = ExchangeGraph(
        map(c->c.parts_snd,cache),
        map(c->c.parts_rcv,cache))
    V_snd = map(c->c.V_snd,cache)
    V_rcv = map(c->c.V_rcv,cache)
    t = exchange!(V_rcv,V_snd,graph)
    # This runs after the exchange is started and before it is waited for.
    map(setup_sa,partition(B),partition(A),cache)
    @fake_async begin
        wait(t)
        map(setup_rcv,partition(B),cache)
        B
    end
end
2004
+
2005
# Fallback for local matrix types without a dedicated in-place method: always throws.
psparse_assemble_impl!(f,A,::Type;kwargs...) = error(" case not implemented")
2008
+
2009
# In-place assembly when the per-part storage is an `AbstractSparseMatrix`.
#
# Stored entries of `A` whose row is owned by another part are packed as
# (global row id, global column id, value) and exchanged with the parts given
# by `assembly_neighbors`. Each received value `v` is merged into the matching
# stored entry `a` of the local matrix as `f(a, v)`. The sparsity pattern of
# `A` is not modified: a received entry must already be stored locally.
#
# Arguments
# - `f`: two-argument combination function, called as `f(current, received)`.
# - `A`: matrix to assemble; its nonzeros are updated in place.
# - `reuse`: when `val_parameter(reuse)` is true, the block also yields the
#   cache consumed by the cached method.
#
# The function evaluates to the result of `@fake_async` applied to a block
# whose value is `A`, or `(A, cache)` when reuse is requested.
function psparse_assemble_impl!(
        f,
        A,
        ::Type{<:AbstractSparseMatrix};
        reuse=Val(false))

    # Pack, grouped by destination part, the stored entries of `A_part` whose
    # row owner differs from this part.
    function setup_cache_snd(A_part,parts_snd,rows_part,cols_part)
        row_owner = local_to_owner(rows_part)
        row_gid = local_to_global(rows_part)
        col_gid = local_to_global(cols_part)
        me = part_id(rows_part)
        slot_of_owner = Dict((owner=>slot for (slot,owner) in enumerate(parts_snd)))
        ptrs = zeros(Int32,length(parts_snd)+1)
        # First pass: count the entries going to each destination.
        for (i,_,_) in nziterator(A_part)
            owner = row_owner[i]
            owner == me && continue
            ptrs[slot_of_owner[owner]+1] += 1
        end
        length_to_ptrs!(ptrs)
        ndata = ptrs[end]-1
        I_snd_data = zeros(Int,ndata)
        J_snd_data = zeros(Int,ndata)
        V_snd_data = zeros(eltype(A_part),ndata)
        k_snd_data = zeros(Int32,ndata)
        # Second pass: fill the buffers, using `ptrs` as moving write cursors.
        # `k` is the position of the entry in the `nziterator` traversal.
        for (k,(i,j,v)) in enumerate(nziterator(A_part))
            owner = row_owner[i]
            owner == me && continue
            slot = slot_of_owner[owner]
            p = ptrs[slot]
            I_snd_data[p] = row_gid[i]
            J_snd_data[p] = col_gid[j]
            V_snd_data[p] = v
            k_snd_data[p] = k
            ptrs[slot] += 1
        end
        rewind_ptrs!(ptrs)
        return (;
            I_snd=JaggedArray(I_snd_data,ptrs),
            J_snd=JaggedArray(J_snd_data,ptrs),
            V_snd=JaggedArray(V_snd_data,ptrs),
            k_snd=JaggedArray(k_snd_data,ptrs),
            parts_snd)
    end

    # Bundle the received buffers with a zero-initialized index buffer that
    # shares the `ptrs` of `I_rcv`.
    function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
        k_rcv = JaggedArray(zeros(Int32,length(I_rcv.data)),I_rcv.ptrs)
        return (;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
    end

    # Convert the received ids to local ids (in place, inside the cache),
    # locate each received entry in `A_part`, store its nonzero index in
    # `k_rcv` and merge the received value with `f`.
    function finalize_values!(A_part,rows_part,cols_part,cache_snd,cache_rcv)
        rcv_rows = cache_rcv.I_rcv.data
        rcv_cols = cache_rcv.J_rcv.data
        rcv_vals = cache_rcv.V_rcv.data
        rcv_nzind = cache_rcv.k_rcv.data
        A_nz = nonzeros(A_part)
        map_global_to_local!(rcv_rows,rows_part)
        map_global_to_local!(rcv_cols,cols_part)
        for p in eachindex(rcv_nzind)
            k = nzindex(A_part,rcv_rows[p],rcv_cols[p])
            @boundscheck @assert k > 0 " The sparsity pattern of the ghost layer is inconsistent"
            rcv_nzind[p] = k
            A_nz[k] = f(A_nz[k],rcv_vals[p])
        end
        return (;cache_snd...,cache_rcv...)
    end

    rows = partition(axes(A,1))
    cols = partition(axes(A,2))
    parts_snd, parts_rcv = assembly_neighbors(rows)
    cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows,cols)
    I_snd = map(c->c.I_snd,cache_snd)
    J_snd = map(c->c.J_snd,cache_snd)
    V_snd = map(c->c.V_snd,cache_snd)
    graph = ExchangeGraph(parts_snd,parts_rcv)
    t_I = exchange(I_snd,graph)
    t_J = exchange(J_snd,graph)
    t_V = exchange(V_snd,graph)
    @fake_async begin
        I_rcv = fetch(t_I)
        J_rcv = fetch(t_J)
        V_rcv = fetch(t_V)
        cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
        cache = map(finalize_values!,partition(A),rows,cols,cache_snd,cache_rcv)
        val_parameter(reuse) ? (A, cache) : A
    end
end
2101
+
2102
# Fallback for local matrix types without a dedicated cached in-place method
# taking a combination function: always throws.
psparse_assemble_impl!(f::Function,A,::Type,cache) = error(" case not implemented")
2105
+
2106
# Cached in-place assembly when the per-part storage is an
# `AbstractSparseMatrix`.
#
# `cache` provides the buffers `V_snd` / `V_rcv`, the index arrays `k_snd` /
# `k_rcv`, and the neighbor lists `parts_snd` / `parts_rcv`. The send buffers
# are refreshed from the current nonzeros of `A` and exchanged. Inside the
# `@fake_async` block each received value `v` is merged into the nonzero `a`
# selected by `k_rcv` as `f(a, v)`. The block's value is `A`.
function psparse_assemble_impl!(f::Function,A,::Type{<:AbstractSparseMatrix},cache)
    # Gather the nonzeros of `A_part` selected by `k_snd` into the send buffer.
    function setup_snd(A_part,cache_part)
        buffer = cache_part.V_snd.data
        A_nz = nonzeros(A_part)
        for (p,k) in pairs(cache_part.k_snd.data)
            buffer[p] = A_nz[k]
        end
    end
    # Merge each received value into the nonzero of `A_part` selected by `k_rcv`.
    function setup_rcv(A_part,cache_part)
        received = cache_part.V_rcv.data
        A_nz = nonzeros(A_part)
        for (p,k) in pairs(cache_part.k_rcv.data)
            A_nz[k] = f(A_nz[k],received[p])
        end
    end
    map(setup_snd,partition(A),cache)
    graph = ExchangeGraph(
        map(c->c.parts_snd,cache),
        map(c->c.parts_rcv,cache))
    V_snd = map(c->c.V_snd,cache)
    V_rcv = map(c->c.V_rcv,cache)
    t = exchange!(V_rcv,V_snd,graph)
    @fake_async begin
        wait(t)
        map(setup_rcv,partition(A),cache)
        A
    end
end
2138
+
1818
2139
"""
1819
2140
consistent(A::PSparseMatrix,rows;kwargs...)
1820
2141
"""
0 commit comments