diff --git a/src/Blocks/Blocks.jl b/src/Blocks/Blocks.jl index cf0679a37..9bf4f8580 100644 --- a/src/Blocks/Blocks.jl +++ b/src/Blocks/Blocks.jl @@ -11,6 +11,7 @@ using ModelingToolkit: getdefault D = Differential(t) export RealInput, RealOutput, SISO +using Symbolics: Struct, StructElement, getelements, symstruct include("utils.jl") export Gain, Sum, MatrixGain, Feedback, Add, Add3, Product, Division diff --git a/src/Blocks/utils.jl b/src/Blocks/utils.jl index d6200d85b..5b224dd06 100644 --- a/src/Blocks/utils.jl +++ b/src/Blocks/utils.jl @@ -107,3 +107,40 @@ Base class for a multiple input multiple output (MIMO) continuous system block. ] return ODESystem(eqs, t, vcat(u..., y...), []; name = name, systems = [input, output]) end + + + +using Symbolics: Struct, symbolic_getproperty +@connector function StructInput(; structdef, name) + @variables u(t)::Struct [input = true] # Dummy default value due to bug in Symbolics + ODESystem(Equation[], t, [u], []; name) +end + +@connector function StructOutput(; structdef, name) + @variables u(t)::Struct [output = true] # Dummy default value due to bug in Symbolics + ODESystem(Equation[], t, [u], []; name) +end + +function _structelem2connector(elem::StructElement) + T = Symbolics.decodetyp(elem.typ) + if T <: Bool + return BoolOutput(; name = elem.name) + elseif T <: Real + return RealOutput(; name = elem.name) + end +end + +@component function BusSelect(;name, structdef, selected_fields) + @parameters t + nout = length(selected_fields) + inputbus = Blocks.StructInput(; structdef, name = Symbol("inputbus")) + @variables input(t) + + output_elements = filter(e->e.name in selected_fields, getelements(structdef)) + output_connectors = map(_structelem2connector, output_elements) + + eqs = [ + symbolic_getproperty(inputbus.u, field) ~ con.u for (field, con) in zip(selected_fields, output_connectors) + ] + return ODESystem(eqs, t; name = name, systems = [inputbus; output_connectors]) +end diff --git a/test/Blocks/sources.jl b/test/Blocks/sources.jl index b1df8f21b..27e09e680 100644 --- a/test/Blocks/sources.jl +++ b/test/Blocks/sources.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, ModelingToolkitStandardLibrary, OrdinaryDiffEq +using ModelingToolkit, ModelingToolkitStandardLibrary, OrdinaryDiffEq, Test using ModelingToolkitStandardLibrary.Blocks using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damped_sin, smooth_square, smooth_step, smooth_ramp, @@ -474,3 +474,58 @@ end @test sol[ddy][end]≈2 atol=1e-3 end end + +using Symbolics +using Symbolics: Struct, StructElement, getelements, symstruct +using Test +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkitStandardLibrary.Blocks: BusSelect +#using ModelingToolkitStandardLibrary.Blocks: structelem2connector + +# Test struct +struct BarStruct + speed::Float64 + isSpeedValid::Int +end + +bar = BarStruct(2.0, 1) +structdef = symstruct(BarStruct) +selected_fields = [:speed] + +@parameters bar_param::Struct +systems = @named begin + inputbus = Blocks.StructOutput(; structdef) + output = BusSelect(; structdef, selected_fields) +end +eqs = [inputbus.u ~ bar_param + connect(inputbus, output.inputbus)] +@named sys = ODESystem(eqs, t; systems) +sys = complete(sys) +ssys = structural_simplify(sys) +prob = ODEProblem(ssys, [], (0.0, 1.0), [sys.bar_param => bar], tofloat=false) +sol = solve(prob, Rodas4()) +@test sol(1.0, idxs = sys.output.speed.u) == 2.0 + +@mtkmodel BusSelectTest begin + @parameters bar_param::Struct + @components begin + inputbus = Blocks.StructOutput(; structdef) + output = BusSelect(; structdef, selected_fields) + end + @equations begin + inputbus.u ~ bar_param + connect(inputbus, output.inputbus) + end +end + +@named sys = BusSelectTest() +sys = complete(sys) +ssys = structural_simplify(sys) +@test_broken begin + prob = ODEProblem(ssys, [ + sys.bar_param => bar + ], (0.0, 1.0)) + sol = solve(prob, Rodas4()) + @test sol.retcode == ReturnCode.Success + @test sol(1.0, idxs = sys.output.speed.u) == 2.0 +end