Skip to content

fix: copy initials to u0 if u0 not provided to remake #3572

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 10 commits into from
Apr 29, 2025
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DelayDiffEq = "5.50"
DiffEqBase = "6.165.1"
DiffEqBase = "6.170.1"
DiffEqCallbacks = "2.16, 3, 4"
DiffEqNoiseProcess = "5"
DiffRules = "0.1, 1.0"
Expand Down
73 changes: 54 additions & 19 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,26 +589,41 @@ function SciMLBase.remake_initialization_data(
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
end

function promote_u0_p(u0, p::MTKParameters, t0)
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)

tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)

return u0, p
end

function promote_u0_p(u0, p::AbstractArray, t0)
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
end

function SciMLBase.late_binding_update_u0_p(
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
supports_initialization(sys) || return newu0, newp
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)

initdata = prob.f.initialization_data
meta = initdata === nothing ? nothing : initdata.metadata

newu0, newp = promote_u0_p(newu0, newp, t0)

# non-symbolic u0 updates initials...
if !(eltype(u0) <: Pair)
# if `p` is not provided or is symbolic
p === missing || eltype(p) <: Pair || return newu0, newp
(newu0 === nothing || isempty(newu0)) && return newu0, newp
initdata = prob.f.initialization_data
initdata === nothing && return newu0, newp
meta = initdata.metadata
meta isa InitializationMetadata || return newu0, newp
newp = p === missing ? copy(newp) : newp
initials, repack, alias = SciMLStructures.canonicalize(
SciMLStructures.Initials(), newp)
if eltype(initials) != eltype(newu0)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)
end

if length(newu0) != length(prob.u0)
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
end
Expand All @@ -617,17 +632,6 @@ function SciMLBase.late_binding_update_u0_p(
end

newp = p === missing ? copy(newp) : newp
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
if eltype(tunables) != eltype(newu0)
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
newp = repack(tunables)
end
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
if eltype(initials) != eltype(newu0)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)
end

allsyms = all_symbols(sys)
for (k, v) in u0
Expand All @@ -646,6 +650,37 @@ function SciMLBase.late_binding_update_u0_p(
return newu0, newp
end

function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
supports_initialization(sys) || return prob
initdata = prob.f.initialization_data
initdata isa SciMLBase.OverrideInitData || return prob
meta = initdata.metadata
meta isa InitializationMetadata || return prob
meta.get_updated_u0 === nothing && return prob

u0 = state_values(prob)
u0 === nothing && return prob

p = parameter_values(prob)
t0 = is_time_dependent(prob) ? current_time(prob) : nothing

if p isa MTKParameters
buffer = p.initials
else
buffer = p
end

u0 = DiffEqBase.promote_u0(u0, buffer, t0)

if ArrayInterface.ismutable(u0)
T = typeof(u0)
else
T = StaticArrays.similar_type(u0)
end

return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)))
end

"""
$(TYPEDSIGNATURES)

Expand Down
59 changes: 57 additions & 2 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ properly.

$(TYPEDFIELDS)
"""
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU}
"""
The `u0map` used to construct the initialization.
"""
Expand All @@ -796,12 +796,62 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
"""
oop_reconstruct_u0_p::R
"""
A function which takes `(prob, initializeprob)` and return the `u0` to use for the problem.
"""
get_updated_u0::GUU
"""
A function which takes the `u0` of the problem and sets
`Initial.(unknowns(sys))`.
"""
set_initial_unknowns!::SIU
end

"""
$(TYPEDEF)

A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
Returns the value to use for the `u0` of the problem.

# Fields

$(TYPEDFIELDS)
"""
struct GetUpdatedU0{GG, GIU}
"""
Mask with length `length(unknowns(sys))` denoting indices of variables which should
take the guess value from `initializeprob`.
"""
guessvars::BitVector
"""
Function which returns the values of variables in `initializeprob` for which
`guessvars` is `true`, in the order they occur in `unknowns(sys)`.
"""
get_guessvars::GG
"""
Function which returns `Initial.(unknowns(sys))` as a `Vector`.
"""
get_initial_unknowns::GIU
end

function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
dvs = unknowns(sys)
eqs = equations(sys)
guessvars = trues(length(dvs))
for (i, var) in enumerate(dvs)
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
end
get_guessvars = getu(initsys, dvs[guessvars])
get_initial_unknowns = getu(sys, Initial.(dvs))
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
end

function (guu::GetUpdatedU0)(prob, initprob)
buffer = guu.get_initial_unknowns(prob)
algebuf = view(buffer, guu.guessvars)
copyto!(algebuf, guu.get_guessvars(initprob))
return buffer
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -840,10 +890,15 @@ function maybe_build_initialization_problem(
end
initializeprob = remake(initializeprob; p = initp)

get_initial_unknowns = if is_time_dependent(sys)
GetUpdatedU0(sys, initializeprob.f.sys, op)
else
nothing
end
meta = InitializationMetadata(
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
setp(sys, Initial.(unknowns(sys))))
get_initial_unknowns, setp(sys, Initial.(unknowns(sys))))

if is_time_dependent(sys)
all_init_syms = Set(all_symbols(initializeprob))
Expand Down
51 changes: 51 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1512,3 +1512,54 @@ end
@inferred remake(prob; u0 = 2 .* prob.u0, p = prob.p)
@inferred solve(prob)
end

@testset "Issue#3570, #3552: `Initial`s/guesses are copied to `u0` during `solve`/`init`" begin
@parameters g
@variables x(t) [state_priority = 10] y(t) λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

prob = ODEProblem(
pend, [x => (√2 / 2), D(x) => 0.0], (0.0, 1.5),
[g => 1], guesses = [λ => 1, y => √2 / 2])
sol = solve(prob)

@testset "Guesses of initialization problem copied to algebraic variables" begin
prob.f.initialization_data.initializeprob[λ] = 1.0
prob2 = DiffEqBase.get_updated_symbolic_problem(
pend, prob; u0 = prob.u0, p = prob.p)
@test prob2[λ] ≈ 1.0
end

@testset "Initial values for algebraic variables are retained" begin
prob2 = ODEProblem(
pend, [x => (√2 / 2), D(y) => 0.0], (0.0, 1.5),
[g => 1], guesses = [λ => 1, y => √2 / 2])
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
prob3 = DiffEqBase.get_updated_symbolic_problem(
pend, prob2; u0 = prob2.u0, p = prob2.p)
@test prob3[D(y)] ≈ 0.0
end

@testset "`setsym_oop`" begin
setter = setsym_oop(prob, [Initial(x)])
(u0, p) = setter(prob, [0.8])
new_prob = remake(prob; u0, p, initializealg = BrownFullBasicInit())
new_sol = solve(new_prob)
@test new_sol[x, 1] ≈ 0.8
integ = init(new_prob)
@test integ[x] ≈ 0.8
end

@testset "`setsym`" begin
@test prob.ps[Initial(x)] ≈ √2 / 2
prob.ps[Initial(x)] = 0.8
sol = solve(prob; initializealg = BrownFullBasicInit())
@test sol[x, 1] ≈ 0.8
integ = init(prob; initializealg = BrownFullBasicInit())
@test integ[x] ≈ 0.8
end
end
Loading