@@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
45
45
return indexmask
46
46
end
47
47
48
+ function similar_truncate (
49
+ :: typeof (svd_trunc!),
50
+ (U, S, Vᴴ):: TBlockUSV ᴴ,
51
+ strategy:: BlockPermutedDiagonalTruncationStrategy ,
52
+ indexmask= MatrixAlgebraKit. findtruncated (diagview (S), strategy),
53
+ )
54
+ ax = axes (S, 1 )
55
+ counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
56
+ s_lengths = filter! (> (0 ), map (counter, blocks (ax)))
57
+ s_axis = blockedrange (s_lengths)
58
+ Ũ = similar (U, axes (U, 1 ), s_axis)
59
+ S̃ = similar (S, s_axis, s_axis)
60
+ Ṽᴴ = similar (Vᴴ, s_axis, axes (Vᴴ, 2 ))
61
+ return Ũ, S̃, Ṽᴴ
62
+ end
63
+
48
64
function MatrixAlgebraKit. truncate! (
49
65
:: typeof (svd_trunc!),
50
66
(U, S, Vᴴ):: TBlockUSV ᴴ,
@@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(
54
70
55
71
# first determine the block structure of the output to avoid having assumptions on the
56
72
# data structures
57
- ax = axes (S, 1 )
58
- counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
59
- Slengths = filter! (> (0 ), map (counter, blocks (ax)))
60
- Sax = blockedrange (Slengths)
61
- Ũ = similar (U, axes (U, 1 ), Sax)
62
- S̃ = similar (S, Sax, Sax)
63
- Ṽᴴ = similar (Vᴴ, Sax, axes (Vᴴ, 2 ))
73
+ Ũ, S̃, Ṽᴴ = similar_truncate (svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
64
74
65
75
# then loop over the blocks and assign the data
66
76
# TODO : figure out if we can presort and loop over the blocks -
@@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
70
80
bI_Vᴴs = collect (eachblockstoredindex (Vᴴ))
71
81
72
82
I′ = 0 # number of skipped blocks that got fully truncated
83
+ ax = axes (S, 1 )
73
84
for I in 1 : blocksize (ax, 1 )
74
85
b = ax[Block (I)]
75
86
mask = indexmask[b]
0 commit comments