Skip to content

feat: add cost and coalesce to ODESystem #3531

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 8 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,9 @@ for prop in [:eqs
:tstops
:index_cache
:is_scalar_noise
:isscheduled]
:isscheduled
:costs
:consolidate]
fname_get = Symbol(:get_, prop)
fname_has = Symbol(:has_, prop)
@eval begin
Expand Down
33 changes: 27 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ end
```julia
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
allow_cost = false,
version = nothing, tgrad = false,
jac = false,
checkbounds = false, sparse = false,
Expand Down Expand Up @@ -730,6 +731,7 @@ end
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
allow_cost = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
Expand All @@ -745,6 +747,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
Consider a BVProblem instead.")
end

if !isempty(get_costs(sys)) && !allow_cost
error("ODEProblem will not optimize solutions of ODESystems that have associated cost functions.
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
Expand Down Expand Up @@ -796,21 +804,19 @@ If an ODESystem without `constraints` is specified, it will be treated as an ini

```julia
@parameters g t_c = 0.5
@variables x(..) y(t) [state_priority = 10] λ(t)
@variables x(..) y(t) λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
cstr = [x(0.5) ~ 1]
@named cstrs = ConstraintsSystem(cstr, t)
@mtkbuild pend = ODESystem(eqs, t)
@mtkbuild pend = ODESystem(eqs, t; constraints = cstrs)

tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
```

If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
Expand Down Expand Up @@ -839,6 +845,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
guesses = Dict(),
allow_cost = false,
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
Expand All @@ -852,6 +859,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

if !isempty(get_costs(sys)) && !allow_cost
error("BVProblem will not optimize solutions of ODESystems that have associated cost functions.
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
end

has_alg_eqs(sys) &&
error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

Expand Down Expand Up @@ -924,7 +937,7 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
exprs = vcat(init_conds, cons)
_p = reorder_parameters(sys, ps)

build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
build_function_wrapper(sys, exprs, sol, _p..., iv; output_type = Array, kwargs...)
end

"""
Expand All @@ -948,11 +961,19 @@ end

function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
parammap = DiffEqBase.NullParameters();
allow_cost = false,
warn_initialize_determined = true,
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
end

if !isempty(get_costs(sys)) && !allow_cost
error("DAEProblem will not optimize solutions of ODESystems that have associated cost functions.
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
end

f, du0, u0, p = process_SciMLProblem(DAEFunction{iip}, sys, u0map, parammap;
implicit_dae = true, du0map = du0map, check_length,
t = tspan !== nothing ? tspan[1] : tspan,
Expand Down
112 changes: 97 additions & 15 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
observed::Vector{Equation}
"""System of constraints that must be satisfied by the solution to the system."""
constraintsystem::Union{Nothing, ConstraintsSystem}
"""A set of expressions defining the costs of the system for optimal control."""
costs::Vector
"""Takes the cost vector and returns a scalar for optimization."""
consolidate::Union{Nothing, Function}
"""
Time-derivative matrix. Note: this field will not be defined until
[`calculate_tgrad`](@ref) is called on the system.
Expand Down Expand Up @@ -205,7 +209,8 @@ struct ODESystem <: AbstractODESystem
parent::Any

function ODESystem(
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls,
observed, constraints, costs, consolidate, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
Expand All @@ -229,7 +234,7 @@ struct ODESystem <: AbstractODESystem
check_units(u, deqs)
end
new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
ctrls, observed, constraints, tgrad, jac,
ctrls, observed, constraints, costs, consolidate, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, assertions, metadata,
Expand All @@ -243,6 +248,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
constraintsystem = nothing,
costs = Num[],
consolidate = nothing,
systems = ODESystem[],
tspan = nothing,
name = nothing,
Expand Down Expand Up @@ -323,22 +330,27 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
cons = get_constraintsystem(sys)
cons !== nothing && push!(conssystems, cons)
end
@show conssystems
@set! constraintsystem.systems = conssystems
end
costs = wrap.(costs)

if length(costs) > 1 && isnothing(consolidate)
error("Must specify a consolidation function for the costs vector.")
end

assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)

ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed,
constraintsystem, costs, consolidate, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies, assertions,
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
Expand Down Expand Up @@ -384,8 +396,16 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
end
end

if !isempty(costs)
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
for p in costps
!in(p, new_ps) && push!(new_ps, p)
end
end
costs = wrap.(costs)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
collect(new_ps); constraintsystem, kwargs...)
collect(new_ps); constraintsystem, costs, kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand All @@ -400,7 +420,9 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
_eq_unordered(continuous_events(sys1), continuous_events(sys2)) &&
_eq_unordered(discrete_events(sys1), discrete_events(sys2)) &&
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) &&
isequal(get_constraintsystem(sys1), get_constraintsystem(sys2)) &&
_eq_unordered(get_costs(sys1), get_costs(sys2))
end

function flatten(sys::ODESystem, noeqs = false)
Expand Down Expand Up @@ -734,22 +756,53 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
return nothing
end

# Validate that all the variables in the BVP constraints are well-formed states or parameters.
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
"""
Build the constraint system for the ODESystem.
"""
function process_constraint_system(
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
constraintps = OrderedSet()

for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
end

# Validate the states.
for var in constraintsts
validate_vars_and_find_ps!(constraintsts, constraintps, sts, iv)

ConstraintsSystem(
constraints, collect(constraintsts), collect(constraintps); name = consname)
end

"""
Process the costs for the constraint system.
"""
function process_costs(costs::Vector, sts, ps, iv)
coststs = OrderedSet()
costps = OrderedSet()
for cost in costs
collect_vars!(coststs, costps, cost, iv)
end

validate_vars_and_find_ps!(coststs, costps, sts, iv)
coststs, costps
end

"""
Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
well-formed states or parameters.
- Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
- Callable/delay parameters should be parameters of the system

Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
parameter of the system.
"""
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
sts = sysvars

for var in auxvars
if !iscall(var)
occursin(iv, var) && (var ∈ sts ||
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
Expand All @@ -764,13 +817,42 @@ function process_constraint_system(
arg isa AbstractFloat ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && push!(constraintps, arg)
isparameter(arg) && push!(auxps, arg)
else
var ∈ sts &&
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
end
end
end

ConstraintsSystem(
constraints, collect(constraintsts), collect(constraintps); name = consname)
"""
Generate a function that takes a solution object and computes the cost function obtained by coalescing the costs vector.
"""
function generate_cost_function(sys::ODESystem, kwargs...)
costs = get_costs(sys)
consolidate = get_consolidate(sys)
iv = get_iv(sys)

ps = parameters(sys; initial_parameters = false)
sts = unknowns(sys)
np = length(ps)
ns = length(sts)
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns]
for st in vars(costs)
x = operation(st)
t = only(arguments(st))
idx = stidxmap[x(iv)]

costs = map(c -> Symbolics.fast_substitute(c, Dict(x(t) => sol(t)[idx])), costs)
end

_p = reorder_parameters(sys, ps)
fs = build_function_wrapper(sys, costs, sol, _p..., t; output_type = Array, kwargs...)
vc_oop, vc_iip = eval_or_rgf.(fs)

cost(sol, p, t) = consolidate(vc_oop(sol, p, t))
return cost
end
Loading
Loading