Skip to content

Commit 04a7c58

Browse files
Merge pull request #2446 from oscardssmith/os/fix-oop-matrix-u
`_reshape` and `_vec` appropriately in more places
2 parents fc1a214 + 2bc1ff3 commit 04a7c58

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -758,16 +758,17 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order},
758758
α₀ = 1
759759
β₀ = inv((1 - κ) * γₖ[k])
760760
if u isa Number
761-
u₀ = sum(D[1:k]) + uprev
761+
u₀ = sum(view(D, 1:k)) + uprev
762762
ϕ = zero(u)
763763
for i in 1:k
764764
ϕ += γₖ[i] * D[i]
765765
end
766766
else
767-
u₀ = reshape(sum(D[:, 1:k], dims = 2) .+ uprev, size(u))
767+
u₀ = _reshape(sum(view(D, :, 1:k), dims = 2), axes(u)) .+ uprev
768768
ϕ = zero(u)
769769
for i in 1:k
770-
ϕ = @.. ϕ + γₖ[i] * D[:, i]
770+
D_row = _reshape(view(D, :, i), axes(u))
771+
ϕ = @.. ϕ + γₖ[i] * D_row
771772
end
772773
end
773774
markfirststage!(nlsolver)
@@ -802,14 +803,14 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order},
802803
end
803804
integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t)
804805
if k > 1
805-
@views atmpm1 = calculate_residuals(D[:, k], uprev, u, integrator.opts.abstol,
806-
integrator.opts.reltol,
807-
integrator.opts.internalnorm, t)
806+
@views atmpm1 = calculate_residuals(_reshape(view(D, :, k), axes(u)),
807+
uprev, u, integrator.opts.abstol,
808+
integrator.opts.reltol, integrator.opts.internalnorm, t)
808809
cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t)
809810
end
810811
if k < max_order
811-
@views atmpp1 = calculate_residuals(D[:, k + 2], uprev, u, abstol, reltol,
812-
internalnorm, t)
812+
@views atmpp1 = calculate_residuals(_reshape(view(D, :, k + 2), axes(u)),
813+
uprev, u, abstol, reltol, internalnorm, t)
813814
cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t)
814815
end
815816
end
@@ -925,13 +926,13 @@ function perform_step!(integrator, cache::QNDFCache{max_order},
925926
integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t)
926927
if k > 1
927928
@views calculate_residuals!(
928-
atmpm1, reshape(D[:, k], size(u)), uprev, u, abstol,
929+
atmpm1, _reshape(D[:, k], axes(u)), uprev, u, abstol,
929930
reltol, internalnorm, t)
930931
cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t)
931932
end
932933
if k < max_order
933934
@views calculate_residuals!(
934-
atmpp1, reshape(D[:, k + 2], size(u)), uprev, u, abstol,
935+
atmpp1, _reshape(D[:, k + 2], axes(u)), uprev, u, abstol,
935936
reltol, internalnorm, t)
936937
cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t)
937938
end
@@ -1112,7 +1113,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
11121113
end
11131114
tmp = -uprev * bdf_coeffs[k, 2]
11141115
for i in 1:(k - 1)
1115-
@views tmp = @.. tmp - u_corrector[:, i] * bdf_coeffs[k, i + 2]
1116+
tmp = @.. tmp - $(_reshape(view(u_corrector, :, i), axes(u))) * bdf_coeffs[k, i + 2]
11161117
end
11171118
end
11181119

@@ -1169,7 +1170,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
11691170
terk *= abs(dt^(k))
11701171
else
11711172
for i in 2:(k + 1)
1172-
@views terk = @.. terk + fd_weights[i, k + 1] * u_history[:, i - 1]
1173+
terk = @.. terk + fd_weights[i, k + 1] * $(_reshape(view(u_history, :, i - 1), axes(u)))
11731174
end
11741175
terk *= abs(dt^(k))
11751176
end

lib/OrdinaryDiffEqBDF/src/controllers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ function choose_order!(alg::FBDF, integrator,
217217
terk_tmp = similar(u)
218218
@.. terk_tmp = fd_weights[k - 2, 1] * _vec(u)
219219
for i in 2:(k - 2)
220-
@.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1]
220+
@.. terk_tmp += fd_weights[i, k - 2] * $(_reshape(view(u_history, :, i - 1), axes(u)))
221221
end
222222
@.. terk_tmp *= abs(dt^(k - 2))
223223
end

lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ end
169169

170170
rhs1 = @. fw1 - αdt * Mw1 + βdt * Mw2
171171
rhs2 = @. fw2 - βdt * Mw1 - αdt * Mw2
172-
dw12 = LU1 \ (@. rhs1 + rhs2 * im)
172+
dw12 = _reshape(LU1 \ _vec(@. rhs1 + rhs2 * im), axes(u))
173173
integrator.stats.nsolve += 1
174174
dw1 = real(dw12)
175175
dw2 = imag(dw12)
@@ -450,8 +450,8 @@ end
450450
rhs1 = @.. broadcast=false fw1-γdt * Mw1
451451
rhs2 = @.. broadcast=false fw2 - αdt * Mw2+βdt * Mw3
452452
rhs3 = @.. broadcast=false fw3 - βdt * Mw2-αdt * Mw3
453-
dw1 = LU1 \ rhs1
454-
dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im)
453+
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
454+
dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u))
455455
integrator.stats.nsolve += 2
456456
dw2 = real(dw23)
457457
dw3 = imag(dw23)
@@ -507,7 +507,10 @@ end
507507
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3
508508
mass_matrix != I && (tmp = mass_matrix * tmp)
509509
utilde = @.. broadcast=false integrator.fsalfirst+tmp
510-
alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1)
510+
if alg.smooth_est
511+
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
512+
integrator.stats.nsolve += 1
513+
end
511514
# RadauIIA5 needs a transformed rtol and atol see
512515
# https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
513516
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
@@ -899,9 +902,9 @@ end
899902
rhs3 = @.. broadcast=false fw3 - β1dt * Mw2-α1dt * Mw3
900903
rhs4 = @.. broadcast=false fw4 - α2dt * Mw4+β2dt * Mw5
901904
rhs5 = @.. broadcast=false fw5 - β2dt * Mw4-α2dt * Mw5
902-
dw1 = LU1 \ rhs1
903-
dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im)
904-
dw45 = LU3 \ (@.. broadcast=false rhs4+rhs5 * im)
905+
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
906+
dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u))
907+
dw45 = _reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), axes(u))
905908
integrator.stats.nsolve += 3
906909
dw2 = real(dw23)
907910
dw3 = imag(dw23)
@@ -969,7 +972,10 @@ end
969972
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3+e4dt*z4+e5dt*z5
970973
mass_matrix != I && (tmp = mass_matrix * tmp)
971974
utilde = @.. broadcast=false integrator.fsalfirst+tmp
972-
alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1)
975+
if alg.smooth_est
976+
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
977+
integrator.stats.nsolve += 1
978+
end
973979
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
974980
integrator.EEst = internalnorm(atmp, t)
975981

0 commit comments

Comments
 (0)