Skip to content

Commit 34a826b

Browse files
authored
[NestedPermutedDimsArrays] Fix setindex! (#1593)
1 parent 549e7f1 commit 34a826b

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.71"
4+
version = "0.3.72"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,45 @@
11
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
22
# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
33
# (though those are fully recursive).
4+
#=
5+
TODO: Investigate replacing this with a `PermutedDimsArray` wrapped around a `MappedArrays.MappedArray`.
6+
There are a few issues with that:
7+
1. Just using a type alias leads to type piracy, for example the constructor is type piracy.
8+
2. `setindex!(::NestedPermutedDimsArray, I...)` fails because no conversion is defined between `Array`
9+
and `PermutedDimsArray`.
10+
3. The type alias is tricky to define, ideally it would have similar type parameters to the current
11+
`NestedPermutedDimsArrays.NestedPermutedDimsArray` definition which matches the type parameters
12+
of `PermutedDimsArrays.PermutedDimsArray` but that seems to be difficult to achieve.
13+
```julia
14+
module NestedPermutedDimsArrays
15+
16+
using MappedArrays: MultiMappedArray, mappedarray
17+
export NestedPermutedDimsArray
18+
19+
const NestedPermutedDimsArray{TT,T<:AbstractArray{TT},N,perm,iperm,AA<:AbstractArray{T}} = PermutedDimsArray{
20+
PermutedDimsArray{TT,N,perm,iperm,T},
21+
N,
22+
perm,
23+
iperm,
24+
MultiMappedArray{
25+
PermutedDimsArray{TT,N,perm,iperm,T},
26+
N,
27+
Tuple{AA},
28+
Type{PermutedDimsArray{TT,N,perm,iperm,T}},
29+
Type{PermutedDimsArray{TT,N,iperm,perm,T}},
30+
},
31+
}
32+
33+
function NestedPermutedDimsArray(a::AbstractArray, perm)
34+
iperm = invperm(perm)
35+
f = PermutedDimsArray{eltype(eltype(a)),ndims(a),perm,iperm,eltype(a)}
36+
finv = PermutedDimsArray{eltype(eltype(a)),ndims(a),iperm,perm,eltype(a)}
37+
return PermutedDimsArray(mappedarray(f, finv, a), perm)
38+
end
39+
40+
end
41+
```
42+
=#
443
module NestedPermutedDimsArrays
544

645
import Base: permutedims, permutedims!
@@ -107,7 +146,7 @@ end
107146
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
108147
) where {T,N,perm,iperm}
109148
@boundscheck checkbounds(A, I...)
110-
@inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...)
149+
@inbounds setindex!(A.parent, PermutedDimsArray(val, iperm), genperm(I, iperm)...)
111150
return val
112151
end
113152

NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@ using Test: @test, @testset
55
Float32, Float64, Complex{Float32}, Complex{Float64}
66
)
77
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
8-
perm = (3, 2, 1)
8+
perm = (3, 1, 2)
99
p = NestedPermutedDimsArray(a, perm)
1010
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
1111
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
12-
@test size(p) == (4, 3, 2)
12+
@test size(p) == (4, 2, 3)
1313
@test eltype(p) === T
1414
for I in eachindex(p)
15-
@test size(p[I]) == (4, 3, 2)
16-
@test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm)
15+
@test size(p[I]) == (4, 2, 3)
16+
@test p[I] == permutedims(a[CartesianIndex(map(i -> Tuple(I)[i], invperm(perm)))], perm)
1717
end
18-
x = randn(elt, 4, 3, 2)
19-
p[3, 2, 1] = x
20-
@test p[3, 2, 1] == x
21-
@test a[1, 2, 3] == permutedims(x, perm)
18+
x = randn(elt, 4, 2, 3)
19+
p[3, 1, 2] = x
20+
@test p[3, 1, 2] == x
21+
@test a[1, 2, 3] == permutedims(x, invperm(perm))
2222
end
2323
end

0 commit comments

Comments
 (0)