Skip to content

Commit b5eeb94

Browse files
authored
[NDTensors] UniformDiagBlockSparse norm (#1622)
1 parent 4339067 commit b5eeb94

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

NDTensors/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/blocksparse/diagblocksparse.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using LinearAlgebra: LinearAlgebra
12
using TypeParameterAccessors: similartype
23

34
export DiagBlockSparse, DiagBlockSparseTensor
@@ -582,8 +583,13 @@ function _contract!!(
582583
return R
583584
end
584585

585-
# TODO: Improve this with FillArrays.jl
586-
norm(S::UniformDiagBlockSparseTensor) = sqrt(mindim(S) * abs2(data(S)))
586+
function LinearAlgebra.norm(D::UniformDiagBlockSparseTensor)
587+
normD² = zero(eltype(D))
588+
for b in nzblocks(D)
589+
normD² += norm(D[b])^2
590+
end
591+
return (abs(normD²))
592+
end
587593

588594
function contraction_output(
589595
T1::TensorT1, labelsT1, T2::TensorT2, labelsT2, labelsR

NDTensors/test/test_diagblocksparse.jl

+13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@eval module $(gensym())
22
using Dictionaries: Dictionary
33
using GPUArraysCore: @allowscalar
4+
using LinearAlgebra: norm
45
using NDTensors:
56
NDTensors,
67
Block,
@@ -87,6 +88,18 @@ using .NDTensorsTestUtils: devices_list
8788
contract(dense(A), (-1, -2), dense(t), (-1, -2))[]
8889
end
8990

91+
@testset "UniformDiagBlockSparse norm" begin
92+
elt = Float64
93+
storage = DiagBlockSparse(one(elt), Dictionary([Block(1, 1), Block(2, 2)], [0, 2]))
94+
tensor = Tensor(storage, ([2, 2], [2, 2]))
95+
@test norm(tensor) norm(dense(tensor))
96+
97+
elt = Float64
98+
storage = DiagBlockSparse(one(elt), Dictionary([Block(1, 1)], [0]))
99+
tensor = Tensor(storage, ([2], [1, 1]))
100+
@test norm(tensor) norm(dense(tensor))
101+
end
102+
90103
@testset "DiagBlockSparse denseblocks" begin
91104
elt = Float64
92105
blockoffsets_a = Dictionary([Block(1, 1), Block(2, 2)], [0, 2])

0 commit comments

Comments
 (0)