Skip to content

Commit 9a084eb

Browse files
authored
fix array2py and friends for adjoints etc. (#653)
* fix array2py and friends for adjoints etc. * fallback for missing numpy * whoops
1 parent e137bb0 commit 9a084eb

File tree

4 files changed

+38
-25
lines changed

4 files changed

+38
-25
lines changed

src/conversions.jl

+16-22
Original file line numberDiff line numberDiff line change
@@ -292,37 +292,31 @@ append!(a::PyVector{T}, items) where {T} = PyVector{T}(append!(a.o, items))
292292
#########################################################################
293293
# Lists and 1d arrays.
294294

295+
if VERSION < v"1.1.0-DEV.392" # #29440
296+
cirange(I,J) = CartesianIndices(map((i,j) -> i:j, Tuple(I), Tuple(J)))
297+
else
298+
cirange(I,J) = I:J
299+
end
300+
295301
# recursive conversion of A to a list of list of lists... starting
296-
# with dimension dim and index i in A.
297-
function array2py(A::AbstractArray{T, N}, dim::Integer, i::Integer) where {T, N}
298-
if dim > N
302+
# with dimension dim and Cartesian index i in A.
303+
function array2py(A::AbstractArray{<:Any, N}, dim::Integer, i::CartesianIndex{N}) where {N}
304+
if dim > N # base case
299305
return PyObject(A[i])
300-
elseif dim == N # special case last dim to coarsen recursion leaves
301-
len = size(A, dim)
302-
s = N == 1 ? 1 : stride(A, dim)
303-
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), len))
304-
for j = 0:len-1
305-
oi = PyObject(A[i+j*s])
306-
@pycheckz ccall((@pysym :PyList_SetItem), Cint, (PyPtr,Int,PyPtr),
307-
o, j, oi)
308-
pyincref(oi) # PyList_SetItem steals the reference
309-
end
310-
return o
311-
else # dim < N: store multidimensional array as list of lists
312-
len = size(A, dim)
313-
s = stride(A, dim)
314-
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), len))
315-
for j = 0:len-1
316-
oi = array2py(A, dim+1, i+j*s)
306+
else # recursively store multidimensional array as list of lists
307+
ilast = CartesianIndex(ntuple(j -> j == dim ? lastindex(A, dim) : i[j], Val{N}()))
308+
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), size(A, dim)))
309+
for icur in cirange(i,ilast)
310+
oi = array2py(A, dim+1, icur)
317311
@pycheckz ccall((@pysym :PyList_SetItem), Cint, (PyPtr,Int,PyPtr),
318-
o, j, oi)
312+
o, icur[dim]-i[dim], oi)
319313
pyincref(oi) # PyList_SetItem steals the reference
320314
end
321315
return o
322316
end
323317
end
324318

325-
array2py(A::AbstractArray) = array2py(A, 1, 1)
319+
array2py(A::AbstractArray) = array2py(A, 1, first(CartesianIndices(A)))
326320

327321
PyObject(A::AbstractArray) =
328322
ndims(A) <= 1 || hasmethod(stride, Tuple{typeof(A),Int}) ? array2py(A) :

src/numpy.jl

+18-2
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,23 @@ function PyObject(a::StridedArray{T}) where T<:PYARR_TYPES
190190
try
191191
return NpyArray(a, false)
192192
catch
193-
array2py(a) # fallback to non-NumPy version
193+
return array2py(a) # fallback to non-NumPy version
194194
end
195195
end
196196

197-
PyReverseDims(a::StridedArray{T}) where {T<:PYARR_TYPES} = NpyArray(a, true)
197+
function PyReverseDims(a::StridedArray{T,N}) where {T<:PYARR_TYPES,N}
198+
try
199+
return NpyArray(a, true)
200+
catch
201+
return array2py(permutedims(a, N:-1:1)) # fallback to non-NumPy version
202+
end
203+
end
198204
PyReverseDims(a::BitArray) = PyReverseDims(Array(a))
199205

206+
# fallback to physically transposing the array
207+
PyReverseDims(a::AbstractArray{<:Any,N}) where {N} = PyObject(permutedims(a, N:-1:1))
208+
PyReverseDims(a::AbstractMatrix) = PyObject(permutedims(a))
209+
200210
"""
201211
PyReverseDims(array)
202212
@@ -209,3 +219,9 @@ libraries that expect row-major data.
209219
PyReverseDims(a::AbstractArray)
210220

211221
#########################################################################
222+
223+
# transposed arrays can be passed to NumPy without copying
224+
PyObject(a::Union{LinearAlgebra.Adjoint{<:Real},LinearAlgebra.Transpose}) =
225+
PyReverseDims(a.parent)
226+
227+
PyObject(a::LinearAlgebra.Adjoint) = PyObject(Matrix(a)) # non-real arrays require a copy

src/pybuffer.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function array_format(pybuf::PyBuffer)
251251
use_native_sizes = false
252252
elseif fmt_str[1] == '='
253253
use_native_sizes = false
254-
elseif fmt_str[1] == "Z"
254+
elseif fmt_str[1] == 'Z'
255255
type_start_idx = 1
256256
else
257257
error("Unsupported format string: \"$fmt_str\"")

test/runtests.jl

+3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ const PyInt = pyversion < v"3" ? Int : Clonglong
6060
@test roundtripeq(C_NULL) && roundtripeq(convert(Ptr{Cvoid}, 12345))
6161
@test roundtripeq([1,3,4,5]) && roundtripeq([1,3.2,"hello",true])
6262
@test roundtripeq([1 2 3;4 5 6]) && roundtripeq([1. 2 3;4 5 6])
63+
@test roundtripeq([1. 2 3;4 5 6]')
64+
@test roundtripeq([1.0+2im 2+3im 3;4 5 6]')
6365
@test roundtripeq((1,(3.2,"hello"),true)) && roundtripeq(())
6466
@test roundtripeq(Int32)
6567
@test roundtripeq(Dict(1 => "hello", 2 => "goodbye")) && roundtripeq(Dict())
@@ -119,6 +121,7 @@ const PyInt = pyversion < v"3" ? Int : Clonglong
119121
array2py2arrayeq(x) = PyCall.py2array(Float64,PyCall.array2py(x)) == x
120122
@test array2py2arrayeq(rand(3))
121123
@test array2py2arrayeq(rand(3,4))
124+
@test array2py2arrayeq(rand(3,4)')
122125
@test array2py2arrayeq(rand(3,4,5))
123126

124127
@test roundtripeq(2:10) && roundtripeq(10:-1:2)

0 commit comments

Comments
 (0)