Skip to content

Commit 4cbee4a

Browse files
committed
fix isinplace inference and add inference tests
1 parent 2e1a62c commit 4cbee4a

File tree

4 files changed

+152
-145
lines changed

4 files changed

+152
-145
lines changed

src/problems/optimization_problems.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ function OptimizationProblem(
131131
OptimizationProblem{isinplace(f)}(f, args...; kwargs...)
132132
end
133133
function OptimizationProblem(f, args...; kwargs...)
134-
isinplace(f, 2, has_two_dispatches = false)
135-
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
134+
OptimizationProblem(OptimizationFunction(f), args...; kwargs...)
136135
end
137136

138137
function OptimizationFunction(

src/scimlfunctions.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4199,7 +4199,10 @@ IntervalNonlinearFunction(f::IntervalNonlinearFunction; kwargs...) = f
41994199
struct NoAD <: AbstractADType end
42004200

42014201
(f::OptimizationFunction)(args...) = f.f(args...)
4202-
OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; kwargs...)
4202+
function OptimizationFunction(f, args...; kwargs...)
4203+
isinplace(f, 2, outofplace_param_number=2)
4204+
OptimizationFunction{true}(f, args...; kwargs...)
4205+
end
42034206

42044207
function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
42054208
grad = nothing, fg = nothing, hess = nothing, hv = nothing, fgh = nothing,
@@ -4251,8 +4254,9 @@ end
42514254
(f::MultiObjectiveOptimizationFunction)(args...) = f.f(args...)
42524255

42534256
# Convenience constructor
4254-
function MultiObjectiveOptimizationFunction(args...; kwargs...)
4255-
MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
4257+
function MultiObjectiveOptimizationFunction(f, args...; kwargs...)
4258+
isinplace(f, 2, outofplace_param_number=2)
4259+
MultiObjectiveOptimizationFunction{true}(f, args...; kwargs...)
42564260
end
42574261

42584262
# Constructor with keyword arguments
@@ -4339,15 +4343,17 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
43394343
if iip_f
43404344
jac = update_coefficients! #(J,u,p,t)
43414345
else
4342-
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
4346+
jac_prototype_copy = deepcopy(jac_prototype)
4347+
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
43434348
end
43444349
end
43454350

43464351
if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
43474352
if iip_bc
43484353
bcjac = update_coefficients! #(J,u,p,t)
43494354
else
4350-
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
4355+
bcjac_prototype_copy = deepcopy(bcjac_prototype)
4356+
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
43514357
end
43524358
end
43534359

@@ -4512,15 +4518,17 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc;
45124518
if iip_f
45134519
jac = update_coefficients! #(J,u,p,t)
45144520
else
4515-
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
4521+
jac_prototype_copy = deepcopy(jac_prototype)
4522+
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
45164523
end
45174524
end
45184525

45194526
if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
45204527
if iip_bc
45214528
bcjac = update_coefficients! #(J,u,p,t)
45224529
else
4523-
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
4530+
bcjac_prototype_copy = deepcopy(jac_prototype)
4531+
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
45244532
end
45254533
end
45264534

test/aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
# for method_ambiguity in ambs
3030
# @show method_ambiguity
3131
# end
32-
@warn "Number of method ambiguities: $(length(ambs))"
32+
!isempty(ambs) &&@warn "Number of method ambiguities: $(length(ambs))"
3333
@test length(ambs) 8
3434
end
3535

0 commit comments

Comments
 (0)