Skip to content

Commit e835a39

Browse files
committed
Update truncation
1 parent 720d5ec commit e835a39

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/factorizations/truncation.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
4545
return indexmask
4646
end
4747

48+
function similar_truncate(
49+
::typeof(svd_trunc!),
50+
(U, S, Vᴴ)::TBlockUSVᴴ,
51+
strategy::BlockPermutedDiagonalTruncationStrategy,
52+
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
53+
)
54+
ax = axes(S, 1)
55+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
56+
s_lengths = filter!(>(0), map(counter, blocks(ax)))
57+
s_axis = blockedrange(s_lengths)
58+
= similar(U, axes(U, 1), s_axis)
59+
= similar(S, s_axis, s_axis)
60+
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
61+
return Ũ, S̃, Ṽᴴ
62+
end
63+
4864
function MatrixAlgebraKit.truncate!(
4965
::typeof(svd_trunc!),
5066
(U, S, Vᴴ)::TBlockUSVᴴ,
@@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(
5470

5571
# first determine the block structure of the output to avoid having assumptions on the
5672
# data structures
57-
ax = axes(S, 1)
58-
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
59-
Slengths = filter!(>(0), map(counter, blocks(ax)))
60-
Sax = blockedrange(Slengths)
61-
= similar(U, axes(U, 1), Sax)
62-
= similar(S, Sax, Sax)
63-
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
73+
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
6474

6575
# then loop over the blocks and assign the data
6676
# TODO: figure out if we can presort and loop over the blocks -
@@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
7080
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
7181

7282
I′ = 0 # number of skipped blocks that got fully truncated
83+
ax = axes(S, 1)
7384
for I in 1:blocksize(ax, 1)
7485
b = ax[Block(I)]
7586
mask = indexmask[b]

0 commit comments

Comments
 (0)