Skip to content

Commit ecdc172

Browse files
Merge pull request #1022 from oscardssmith/os/fix-function-inference
fix function inference
2 parents 2e1a62c + 4b195f5 commit ecdc172

File tree

4 files changed

+28
-21
lines changed

4 files changed

+28
-21
lines changed

src/problems/optimization_problems.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ function OptimizationProblem(
131131
OptimizationProblem{isinplace(f)}(f, args...; kwargs...)
132132
end
133133
function OptimizationProblem(f, args...; kwargs...)
134-
isinplace(f, 2, has_two_dispatches = false)
135-
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
134+
OptimizationProblem(OptimizationFunction(f), args...; kwargs...)
136135
end
137136

138137
function OptimizationFunction(

src/scimlfunctions.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ dt: the time step
948948
949949
```julia
950950
ImplicitDiscreteFunction{iip,specialize}(f;
951-
analytic = __has_analytic(f) ? f.analytic : nothing,
951+
analytic = __has_analytic(f) ? f.analytic : nothing,
952952
resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing)
953953
```
954954
@@ -2107,7 +2107,7 @@ A representation of a ODE function `f` with inputs, defined by:
21072107
```math
21082108
\frac{dx}{dt} = f(x, u, p, t)
21092109
```
2110-
where `x` are the states of the system and `u` are the inputs (which may represent
2110+
where `x` are the states of the system and `u` are the inputs (which may represent
21112111
different things in different contexts, such as control variables in optimal control).
21122112
21132113
Includes all of its related functions, such as the Jacobian of `f`, its gradient
@@ -2134,7 +2134,7 @@ ODEInputFunction{iip, specialize}(f;
21342134
sys = __has_sys(f) ? f.sys : nothing)
21352135
```
21362136
2137-
`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
2137+
`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
21382138
See the section on `iip` for more details on in-place vs out-of-place handling.
21392139
21402140
- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
@@ -4199,7 +4199,10 @@ IntervalNonlinearFunction(f::IntervalNonlinearFunction; kwargs...) = f
41994199
struct NoAD <: AbstractADType end
42004200

42014201
(f::OptimizationFunction)(args...) = f.f(args...)
4202-
OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; kwargs...)
4202+
function OptimizationFunction(f, args...; kwargs...)
4203+
isinplace(f, 2, outofplace_param_number=2)
4204+
OptimizationFunction{true}(f, args...; kwargs...)
4205+
end
42034206

42044207
function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
42054208
grad = nothing, fg = nothing, hess = nothing, hv = nothing, fgh = nothing,
@@ -4251,8 +4254,9 @@ end
42514254
(f::MultiObjectiveOptimizationFunction)(args...) = f.f(args...)
42524255

42534256
# Convenience constructor
4254-
function MultiObjectiveOptimizationFunction(args...; kwargs...)
4255-
MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
4257+
function MultiObjectiveOptimizationFunction(f, args...; kwargs...)
4258+
isinplace(f, 3)
4259+
MultiObjectiveOptimizationFunction{true}(f, args...; kwargs...)
42564260
end
42574261

42584262
# Constructor with keyword arguments
@@ -4339,15 +4343,17 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
43394343
if iip_f
43404344
jac = update_coefficients! #(J,u,p,t)
43414345
else
4342-
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
4346+
jac_prototype_copy = deepcopy(jac_prototype)
4347+
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
43434348
end
43444349
end
43454350

43464351
if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
43474352
if iip_bc
43484353
bcjac = update_coefficients! #(J,u,p,t)
43494354
else
4350-
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
4355+
bcjac_prototype_copy = deepcopy(bcjac_prototype)
4356+
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
43514357
end
43524358
end
43534359

@@ -4512,15 +4518,17 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc;
45124518
if iip_f
45134519
jac = update_coefficients! #(J,u,p,t)
45144520
else
4515-
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
4521+
jac_prototype_copy = deepcopy(jac_prototype)
4522+
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
45164523
end
45174524
end
45184525

45194526
if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
45204527
if iip_bc
45214528
bcjac = update_coefficients! #(J,u,p,t)
45224529
else
4523-
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
4530+
bcjac_prototype_copy = deepcopy(jac_prototype)
4531+
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
45244532
end
45254533
end
45264534

test/aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
# for method_ambiguity in ambs
3030
# @show method_ambiguity
3131
# end
32-
@warn "Number of method ambiguities: $(length(ambs))"
32+
!isempty(ambs) &&@warn "Number of method ambiguities: $(length(ambs))"
3333
@test length(ambs) 8
3434
end
3535

test/function_building_error_messages.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ ofboth(u, p, t) = u
5454
ofboth(du, u, p, t) = du .= u
5555

5656
ODEFunction(ofboth)
57-
ODEFunction{true}(ofboth)
58-
ODEFunction{false}(ofboth)
57+
@inferred ODEFunction{true}(ofboth)
58+
@inferred ODEFunction{false}(ofboth)
5959

6060
jac(u, t) = [1.0]
6161
@test_throws SciMLBase.TooFewArgumentsError ODEFunction(fiip, jac = jac)
@@ -428,8 +428,8 @@ nfboth(u, p) = u
428428
nfboth(du, u, p) = du .= u
429429

430430
NonlinearFunction(nfboth)
431-
NonlinearFunction{true}(nfboth)
432-
NonlinearFunction{false}(nfboth)
431+
@inferred NonlinearFunction{true}(nfboth)
432+
@inferred NonlinearFunction{false}(nfboth)
433433

434434
njac(u) = [1.0]
435435
@test_throws SciMLBase.TooFewArgumentsError NonlinearFunction(nfiip, jac = njac)
@@ -520,8 +520,8 @@ bcfboth(u, p, t) = u
520520
bcfboth(du, u, p, t) = du .= u
521521

522522
BVPFunction(bfboth, bcfboth)
523-
BVPFunction{true}(bfboth, bcfboth)
524-
BVPFunction{false}(bfboth, bcfboth)
523+
@inferred BVPFunction{true}(bfboth, bcfboth)
524+
@inferred BVPFunction{false}(bfboth, bcfboth)
525525

526526
bjac(u, t) = [1.0]
527527
bcjac(u, t) = [1.0]
@@ -663,8 +663,8 @@ dbcfboth(du, u, p, t) = u
663663
dbcfboth(res, du, u, p, t) = res .= du .- u
664664

665665
DynamicalBVPFunction(dbfboth, dbcfboth)
666-
DynamicalBVPFunction{true}(dbfboth, dbcfboth)
667-
DynamicalBVPFunction{false}(dbfboth, dbcfboth)
666+
@inferred DynamicalBVPFunction{true}(dbfboth, dbcfboth)
667+
@inferred DynamicalBVPFunction{false}(dbfboth, dbcfboth)
668668

669669
dbjac(du, u, t) = [1.0]
670670
dbcjac(du, u, t) = [1.0]

0 commit comments

Comments
 (0)