Skip to content

Commit a041e6e

Browse files
timholytkelman
authored andcommitted
Only reuse memory in mapslices when it's safe (fixes #18524) (#18570)
1 parent d478f96 commit a041e6e

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

base/abstractarray.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,6 +1734,7 @@ function mapslices(f, A::AbstractArray, dims::AbstractVector)
17341734

17351735
Aslice = A[idx...]
17361736
r1 = f(Aslice)
1737+
safe_for_reuse = isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number)
17371738

17381739
# determine result size and allocate
17391740
Rsize = copy(dimsA)
@@ -1758,15 +1759,31 @@ function mapslices(f, A::AbstractArray, dims::AbstractVector)
17581759

17591760
isfirst = true
17601761
nidx = length(otherdims)
1761-
for I in CartesianRange(itershape)
1762-
if isfirst
1763-
isfirst = false
1764-
else
1765-
for i in 1:nidx
1766-
idx[otherdims[i]] = ridx[otherdims[i]] = I.I[i]
1762+
if safe_for_reuse
1763+
# when f returns an array, R[ridx...] = f(Aslice) line copies elements,
1764+
# so we can reuse Aslice
1765+
for I in CartesianRange(itershape)
1766+
if isfirst
1767+
isfirst = false # skip the first element, we already handled it
1768+
else
1769+
for i in 1:nidx
1770+
idx[otherdims[i]] = ridx[otherdims[i]] = I.I[i]
1771+
end
1772+
_unsafe_getindex!(Aslice, A, idx...)
1773+
R[ridx...] = f(Aslice)
1774+
end
1775+
end
1776+
else
1777+
# we can't guarantee safety (#18524), so allocate new storage for each slice
1778+
for I in CartesianRange(itershape)
1779+
if isfirst
1780+
isfirst = false
1781+
else
1782+
for i in 1:nidx
1783+
idx[otherdims[i]] = ridx[otherdims[i]] = I.I[i]
1784+
end
1785+
R[ridx...] = f(A[idx...])
17671786
end
1768-
_unsafe_getindex!(Aslice, A, idx...)
1769-
R[ridx...] = f(Aslice)
17701787
end
17711788
end
17721789

test/arrayops.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,17 @@ end
953953
@test lexless(asc[:,1],asc[:,2])
954954
@test lexless(asc[:,2],asc[:,3])
955955

956+
# mutating functions
957+
o = ones(3, 4)
958+
m = mapslices(x->fill!(x, 0), o, 2)
959+
@test m == zeros(3, 4)
960+
@test o == ones(3, 4)
961+
962+
# issue #18524
963+
m = mapslices(x->tuple(x), [1 2; 3 4], 1)
964+
@test m[1,1] == ([1,3],)
965+
@test m[1,2] == ([2,4],)
966+
956967
asr = sortrows(a, rev=true)
957968
@test lexless(asr[2,:],asr[1,:])
958969
@test lexless(asr[3,:],asr[2,:])

0 commit comments

Comments
 (0)