Skip to content

Commit 3de0cde

Browse files
Merge pull request #32 from thomvet/allocations
Dispatches get_tmp on wrapper type of cache,
2 parents addfd17 + 86f9832 commit 3de0cde

File tree

4 files changed

+98
-67
lines changed

4 files changed

+98
-67
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PreallocationTools"
22
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "0.4.0"
4+
version = "0.5.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/PreallocationTools.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,28 @@ function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
4646
if nelem > length(dc.dual_du)
4747
enlargedualcache!(dc, nelem)
4848
end
49-
ArrayInterfaceCore.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
49+
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
5050
end
5151

5252
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
5353
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
5454
if nelem > length(dc.dual_du)
5555
enlargedualcache!(dc, nelem)
5656
end
57-
ArrayInterfaceCore.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
57+
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
5858
end
5959

6060
get_tmp(dc::DiffCache, u::Number) = dc.du
6161
get_tmp(dc::DiffCache, u::AbstractArray) = dc.du
6262

63+
function _restructure(normal_cache::Array, duals)
64+
reshape(duals, size(normal_cache)...)
65+
end
66+
67+
function _restructure(normal_cache::AbstractArray, duals)
68+
ArrayInterfaceCore.restructure(normal_cache, duals)
69+
end
70+
6371
function enlargedualcache!(dc, nelem) #warning comes only once per dualcache.
6472
chunksize = div(nelem, length(dc.du)) - 1
6573
@warn "The supplied dualcache was too small and was enlarged. This incurrs allocations

test/core_dispatch.jl

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,96 @@
1-
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, LabelledArrays,
1+
using LinearAlgebra, Test, PreallocationTools, ForwardDiff, LabelledArrays,
22
RecursiveArrayTools
33

4-
#Base Array tests
4+
function test(u0, dual, chunk_size)
5+
cache = PreallocationTools.dualcache(u0, chunk_size)
6+
allocs_normal1 = @allocated get_tmp(cache, u0)
7+
allocs_normal2 = @allocated get_tmp(cache, first(u0))
8+
allocs_dual1 = @allocated get_tmp(cache, dual)
9+
allocs_dual2 = @allocated get_tmp(cache, first(dual))
10+
result_normal1 = get_tmp(cache, u0)
11+
result_normal2 = get_tmp(cache, first(u0))
12+
result_dual1 = get_tmp(cache, dual)
13+
result_dual2 = get_tmp(cache, first(dual))
14+
return allocs_normal1, allocs_normal2, allocs_dual1, allocs_dual2, result_normal1,
15+
result_normal2, result_dual1,
16+
result_dual2
17+
end
18+
19+
#Setup Base Array tests
520
chunk_size = 5
6-
u0_B = ones(5, 5)
7-
dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
8-
chunk_size}, 2, 2)
9-
cache_B = dualcache(u0_B, chunk_size)
10-
tmp_du_BA = get_tmp(cache_B, u0_B)
11-
tmp_dual_du_BA = get_tmp(cache_B, dual_B)
12-
tmp_du_BN = get_tmp(cache_B, u0_B[1])
13-
tmp_dual_du_BN = get_tmp(cache_B, dual_B[1])
14-
@test size(tmp_du_BA) == size(u0_B)
15-
@test typeof(tmp_du_BA) == typeof(u0_B)
16-
@test eltype(tmp_du_BA) == eltype(u0_B)
17-
@test size(tmp_dual_du_BA) == size(u0_B)
18-
@test typeof(tmp_dual_du_BA) == typeof(dual_B)
19-
@test eltype(tmp_dual_du_BA) == eltype(dual_B)
20-
@test size(tmp_du_BN) == size(u0_B)
21-
@test typeof(tmp_du_BN) == typeof(u0_B)
22-
@test eltype(tmp_du_BN) == eltype(u0_B)
23-
@test size(tmp_dual_du_BN) == size(u0_B)
24-
@test typeof(tmp_dual_du_BN) == typeof(dual_B)
25-
@test eltype(tmp_dual_du_BN) == eltype(dual_B)
21+
u0 = ones(5, 5)
22+
dual = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
23+
chunk_size}, 5, 5)
24+
results = test(u0, dual, chunk_size)
25+
#allocation tests
26+
@test results[1] == 0
27+
@test results[2] == 0
28+
@test results[3] == 0
29+
@test results[4] == 0
30+
#size tests
31+
@test size(results[5]) == size(u0)
32+
@test size(results[6]) == size(u0)
33+
@test size(results[7]) == size(u0)
34+
@test size(results[8]) == size(u0)
35+
#type tests
36+
@test typeof(results[5]) == typeof(u0)
37+
@test typeof(results[6]) == typeof(u0)
38+
@test_broken typeof(results[7]) == typeof(dual)
39+
@test_broken typeof(results[8]) == typeof(dual)
40+
#eltype tests
41+
@test eltype(results[5]) == eltype(u0)
42+
@test eltype(results[7]) == eltype(dual)
2643

2744
#LArray tests
2845
chunk_size = 4
29-
u0_L = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
30-
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
46+
u0 = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
47+
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
3148
chunk_size})
32-
dual_L = LArray((2, 2); a = zerodual, b = zerodual, c = zerodual, d = zerodual)
33-
cache_L = dualcache(u0_L, chunk_size)
34-
tmp_du_LA = get_tmp(cache_L, u0_L)
35-
tmp_dual_du_LA = get_tmp(cache_L, dual_L)
36-
tmp_du_LN = get_tmp(cache_L, u0_L[1])
37-
tmp_dual_du_LN = get_tmp(cache_L, dual_L[1])
38-
@test size(tmp_du_LA) == size(u0_L)
39-
@test typeof(tmp_du_LA) == typeof(u0_L)
40-
@test eltype(tmp_du_LA) == eltype(u0_L)
41-
@test size(tmp_dual_du_LA) == size(u0_L)
42-
@test typeof(tmp_dual_du_LA) == typeof(dual_L)
43-
@test eltype(tmp_dual_du_LA) == eltype(dual_L)
44-
@test size(tmp_du_LN) == size(u0_L)
45-
@test typeof(tmp_du_LN) == typeof(u0_L)
46-
@test eltype(tmp_du_LN) == eltype(u0_L)
47-
@test size(tmp_dual_du_LN) == size(u0_L)
48-
@test typeof(tmp_dual_du_LN) == typeof(dual_L)
49-
@test eltype(tmp_dual_du_LN) == eltype(dual_L)
49+
dual = LArray((2, 2); a = zerodual, b = zerodual, c = zerodual, d = zerodual)
50+
results = test(u0, dual, chunk_size)
51+
#allocation tests
52+
@test results[1] == 0
53+
@test results[2] == 0
54+
@test_broken results[3] == 0
55+
@test_broken results[4] == 0
56+
#size tests
57+
@test size(results[5]) == size(u0)
58+
@test size(results[6]) == size(u0)
59+
@test size(results[7]) == size(u0)
60+
@test size(results[8]) == size(u0)
61+
#type tests
62+
@test typeof(results[5]) == typeof(u0)
63+
@test typeof(results[6]) == typeof(u0)
64+
@test typeof(results[7]) == typeof(dual)
65+
@test typeof(results[8]) == typeof(dual)
66+
#eltype tests
67+
@test eltype(results[5]) == eltype(u0)
68+
@test eltype(results[7]) == eltype(dual)
5069

5170
#ArrayPartition tests
52-
u0_AP = ArrayPartition(ones(2, 2), ones(3, 3))
53-
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
71+
chunk_size = 2
72+
u0 = ArrayPartition(ones(2, 2), ones(3, 3))
73+
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
5474
chunk_size}, 2, 2)
55-
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
75+
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
5676
chunk_size}, 3, 3)
57-
dual_AP = ArrayPartition(dual_a, dual_b)
58-
cache_AP = dualcache(u0_AP, chunk_size)
59-
tmp_du_APA = get_tmp(cache_AP, u0_AP)
60-
tmp_dual_du_APA = get_tmp(cache_AP, dual_AP)
61-
tmp_du_APN = get_tmp(cache_AP, u0_AP[1])
62-
tmp_dual_du_APN = get_tmp(cache_AP, dual_AP[1])
63-
@test size(tmp_du_APA) == size(u0_AP)
64-
@test typeof(tmp_du_APA) == typeof(u0_AP)
65-
@test eltype(tmp_du_APA) == eltype(u0_AP)
66-
@test size(tmp_dual_du_APA) == size(u0_AP)
67-
@test typeof(tmp_dual_du_APA) == typeof(dual_AP)
68-
@test eltype(tmp_dual_du_APA) == eltype(dual_AP)
69-
@test size(tmp_du_APN) == size(u0_AP)
70-
@test typeof(tmp_du_APN) == typeof(u0_AP)
71-
@test eltype(tmp_du_APN) == eltype(u0_AP)
72-
@test size(tmp_dual_du_APN) == size(u0_AP)
73-
@test typeof(tmp_dual_du_APN) == typeof(dual_AP)
74-
@test eltype(tmp_dual_du_APN) == eltype(dual_AP)
77+
dual = ArrayPartition(dual_a, dual_b)
78+
results = test(u0, dual, chunk_size)
79+
#allocation tests
80+
@test results[1] == 0
81+
@test results[2] == 0
82+
@test_broken results[3] == 0
83+
@test_broken results[4] == 0
84+
#size tests
85+
@test size(results[5]) == size(u0)
86+
@test size(results[6]) == size(u0)
87+
@test size(results[7]) == size(u0)
88+
@test size(results[8]) == size(u0)
89+
#type tests
90+
@test typeof(results[5]) == typeof(u0)
91+
@test typeof(results[6]) == typeof(u0)
92+
@test typeof(results[7]) == typeof(dual)
93+
@test typeof(results[8]) == typeof(dual)
94+
#eltype tests
95+
@test eltype(results[5]) == eltype(u0)
96+
@test eltype(results[7]) == eltype(dual)

test/core_odes.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ sol = solve(prob, TRBDF2(chunk_size = chunk_size))
1818
@test sol.retcode == :Success
1919

2020
#with auto-detected chunk_size
21-
prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0), (ones(5, 5), dualcache(zeros(5, 5))))
21+
cache = dualcache(zeros(5, 5))
22+
prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0), (A, cache))
2223
sol = solve(prob, TRBDF2())
2324
@test sol.retcode == :Success
2425

0 commit comments

Comments
 (0)