Skip to content

Commit 7a82fb2

Browse files
Merge pull request #3573 from vyudu/ss_simplify
[v10] refactor: change inputs/outputs handling in `structural_simplify`
2 parents 0b67a45 + 7a234f1 commit 7a82fb2

15 files changed

+153
-136
lines changed

src/inputoutput.jl

+18-28
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
163163
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166-
disturbance_inputs = nothing;
166+
disturbance_inputs = disturbances(sys);
167167
implicit_dae = false,
168168
simplify = false,
169169
)
@@ -179,9 +179,6 @@ The return values also include the chosen state-realization (the remaining unkno
179179
180180
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`.
181181
182-
!!! note "Un-simplified system"
183-
This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally.
184-
185182
# Example
186183
187184
```
@@ -200,16 +197,18 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
200197
simplify = false,
201198
eval_expression = false,
202199
eval_module = @__MODULE__,
200+
check_simplified = true,
203201
kwargs...)
202+
# Remove this when the ControlFunction gets merged.
203+
if check_simplified && !iscomplete(sys)
204+
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating the control function.")
205+
end
204206
isempty(inputs) && @warn("No unbound inputs were found in system.")
205-
206207
if disturbance_inputs !== nothing
207208
# add to inputs for the purposes of io processing
208209
inputs = [inputs; disturbance_inputs]
209210
end
210211

211-
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
212-
213212
dvs = unknowns(sys)
214213
ps = parameters(sys; initial_parameters = true)
215214
ps = setdiff(ps, inputs)
@@ -257,8 +256,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
257256
(; f, dvs, ps, io_sys = sys)
258257
end
259258

260-
function inputs_to_parameters!(state::TransformationState, io)
261-
check_bound = io === nothing
259+
"""
260+
Turn input variables into parameters of the system.
261+
"""
262+
function inputs_to_parameters!(state::TransformationState, inputsyms)
263+
check_bound = inputsyms === nothing
262264
@unpack structure, fullvars, sys = state
263265
@unpack var_to_diff, graph, solvable_graph = structure
264266
@assert solvable_graph === nothing
@@ -287,7 +289,7 @@ function inputs_to_parameters!(state::TransformationState, io)
287289
push!(new_fullvars, v)
288290
end
289291
end
290-
ninputs == 0 && return (state, 1:0)
292+
ninputs == 0 && return state
291293

292294
nvars = ndsts(graph) - ninputs
293295
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -316,24 +318,11 @@ function inputs_to_parameters!(state::TransformationState, io)
316318
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
317319
ps = parameters(sys)
318320

319-
if io !== nothing
320-
inputs, = io
321-
# Change order of new parameters to correspond to user-provided order in argument `inputs`
322-
d = Dict{Any, Int}()
323-
for (i, inp) in enumerate(new_parameters)
324-
d[inp] = i
325-
end
326-
permutation = [d[i] for i in inputs]
327-
new_parameters = new_parameters[permutation]
328-
end
329-
330321
@set! sys.ps = [ps; new_parameters]
331-
332322
@set! state.sys = sys
333323
@set! state.fullvars = new_fullvars
334324
@set! state.structure = structure
335-
base_params = length(ps)
336-
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
325+
return state
337326
end
338327

339328
"""
@@ -359,7 +348,7 @@ function get_disturbance_system(dist::DisturbanceModel{<:ODESystem})
359348
end
360349

361350
"""
362-
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
351+
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[])
363352
364353
Add a model of an unmeasured disturbance to `sys`. The disturbance model is an instance of [`DisturbanceModel`](@ref).
365354
@@ -408,13 +397,13 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
408397
409398
`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
410399
"""
411-
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
400+
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwargs...)
412401
t = get_iv(sys)
413402
@variables d(t)=0 [disturbance = true]
414403
@variables u(t)=0 [input = true] # New system input
415404
dsys = get_disturbance_system(dist)
416405

417-
if inputs === nothing
406+
if isempty(inputs)
418407
all_inputs = [u]
419408
else
420409
i = findfirst(isequal(dist.input), inputs)
@@ -429,8 +418,9 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kw
429418
dist.input ~ u + dsys.output.u[1]]
430419
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
431420
augmented_sys = extend(augmented_sys, sys)
421+
ssys = structural_simplify(augmented_sys, inputs = all_inputs, disturbance_inputs = [d])
432422

433-
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
423+
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(ssys, all_inputs,
434424
[d]; kwargs...)
435425
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
436426
end

src/linearization.jl

+46-20
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,8 @@ function linearization_function(sys::AbstractSystem, inputs,
5858
outputs = mapreduce(vcat, outputs; init = []) do var
5959
symbolic_type(var) == ArraySymbolic() ? collect(var) : [var]
6060
end
61-
ssys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs;
62-
simplify,
63-
kwargs...)
61+
ssys = structural_simplify(sys; inputs, outputs, simplify, kwargs...)
62+
diff_idxs, alge_idxs = eq_idxs(ssys)
6463
if zero_dummy_der
6564
dummyder = setdiff(unknowns(ssys), unknowns(sys))
6665
defs = Dict(x => 0.0 for x in dummyder)
@@ -87,9 +86,9 @@ function linearization_function(sys::AbstractSystem, inputs,
8786

8887
p = parameter_values(prob)
8988
t0 = current_time(prob)
90-
inputvals = [p[idx] for idx in input_idxs]
89+
inputvals = [prob.ps[i] for i in inputs]
9190

92-
hp_fun = let fun = h, setter = setp_oop(sys, input_idxs)
91+
hp_fun = let fun = h, setter = setp_oop(sys, inputs)
9392
function hpf(du, input, u, p, t)
9493
p = setter(p, input)
9594
fun(du, u, p, t)
@@ -113,7 +112,7 @@ function linearization_function(sys::AbstractSystem, inputs,
113112
# observed function is a `GeneratedFunctionWrapper` with iip component
114113
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff,
115114
prob.u0, DI.Constant(p), DI.Constant(t0))
116-
pf_fun = let fun = prob.f, setter = setp_oop(sys, input_idxs)
115+
pf_fun = let fun = prob.f, setter = setp_oop(sys, inputs)
117116
function pff(du, input, u, p, t)
118117
p = setter(p, input)
119118
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
@@ -127,12 +126,23 @@ function linearization_function(sys::AbstractSystem, inputs,
127126
end
128127

129128
lin_fun = LinearizationFunction(
130-
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)),
129+
diff_idxs, alge_idxs, inputs, length(unknowns(sys)),
131130
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
132131
hp_jac, initializealg, initialization_kwargs)
133132
return lin_fun, sys
134133
end
135134

135+
"""
136+
Return the set of indexes of differential equations and algebraic equations in the simplified system.
137+
"""
138+
function eq_idxs(sys::AbstractSystem)
139+
eqs = equations(sys)
140+
alge_idxs = findall(!isdiffeq, eqs)
141+
diff_idxs = setdiff(1:length(eqs), alge_idxs)
142+
143+
diff_idxs, alge_idxs
144+
end
145+
136146
"""
137147
$(TYPEDEF)
138148
@@ -192,7 +202,7 @@ A callable struct which linearizes a system.
192202
$(TYPEDFIELDS)
193203
"""
194204
struct LinearizationFunction{
195-
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, II, P <: ODEProblem,
205+
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, I, P <: ODEProblem,
196206
H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
197207
"""
198208
The indexes of differential equations in the linearized system.
@@ -206,7 +216,7 @@ struct LinearizationFunction{
206216
The indexes of parameters in the linearized system which represent
207217
input variables.
208218
"""
209-
input_idxs::II
219+
inputs::I
210220
"""
211221
The number of unknowns in the linearized system.
212222
"""
@@ -281,6 +291,7 @@ function (linfun::LinearizationFunction)(u, p, t)
281291
end
282292

283293
fun = linfun.prob.f
294+
input_vals = [linfun.prob.ps[i] for i in linfun.inputs]
284295
if u !== nothing # Handle systems without unknowns
285296
linfun.num_states == length(u) ||
286297
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
@@ -294,15 +305,15 @@ function (linfun::LinearizationFunction)(u, p, t)
294305
end
295306
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
296307
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
297-
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs],
308+
fg_u = linfun.pf_jac(input_vals,
298309
DI.Constant(u), DI.Constant(p), DI.Constant(t))
299310
else
300311
linfun.num_states == 0 ||
301312
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
302313
fg_xz = zeros(0, 0)
303-
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
314+
h_xz = fg_u = zeros(0, length(linfun.inputs))
304315
end
305-
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs],
316+
h_u = linfun.hp_jac(input_vals,
306317
DI.Constant(u), DI.Constant(p), DI.Constant(t))
307318
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
308319
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
@@ -482,9 +493,8 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
482493
outputs; simplify = false, allow_input_derivatives = false,
483494
eval_expression = false, eval_module = @__MODULE__,
484495
kwargs...)
485-
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(
486-
sys, inputs, outputs; simplify,
487-
kwargs...)
496+
sys = structural_simplify(sys; inputs, outputs, simplify, kwargs...)
497+
diff_idxs, alge_idxs = eq_idxs(sys)
488498
sts = unknowns(sys)
489499
t = get_iv(sys)
490500
ps = parameters(sys; initial_parameters = true)
@@ -545,10 +555,14 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
545555
(; A, B, C, D, f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u), sys
546556
end
547557

548-
function markio!(state, orig_inputs, inputs, outputs; check = true)
558+
"""
559+
Modify the variable metadata of system variables to indicate which ones are inputs, outputs, and disturbances. Needed for `inputs`, `outputs`, `disturbances`, `unbound_inputs`, `unbound_outputs` to return the proper subsets.
560+
"""
561+
function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true)
549562
fullvars = get_fullvars(state)
550563
inputset = Dict{Any, Bool}(i => false for i in inputs)
551564
outputset = Dict{Any, Bool}(o => false for o in outputs)
565+
disturbanceset = Dict{Any, Bool}(d => false for d in disturbances)
552566
for (i, v) in enumerate(fullvars)
553567
if v in keys(inputset)
554568
if v in keys(outputset)
@@ -570,6 +584,13 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
570584
v = setio(v, false, false)
571585
fullvars[i] = v
572586
end
587+
588+
if v in keys(disturbanceset)
589+
v = setio(v, true, false)
590+
v = setdisturbance(v, true)
591+
disturbanceset[v] = true
592+
fullvars[i] = v
593+
end
573594
end
574595
if check
575596
ikeys = keys(filter(!last, inputset))
@@ -578,11 +599,16 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
578599
"Some specified inputs were not found in system. The following variables were not found ",
579600
ikeys)
580601
end
602+
dkeys = keys(filter(!last, disturbanceset))
603+
if !isempty(dkeys)
604+
error(
605+
"Specified disturbance inputs were not found in system. The following variables were not found ",
606+
ikeys)
607+
end
608+
(all(values(outputset)) || error(
609+
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
610+
outputset))
581611
end
582-
check && (all(values(outputset)) ||
583-
error(
584-
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
585-
outputset))
586612
state, orig_inputs
587613
end
588614

src/systems/abstractsystem.jl

-15
Original file line numberDiff line numberDiff line change
@@ -2484,21 +2484,6 @@ function eliminate_constants(sys::AbstractSystem)
24842484
return sys
24852485
end
24862486

2487-
function io_preprocessing(sys::AbstractSystem, inputs,
2488-
outputs; simplify = false, kwargs...)
2489-
sys, input_idxs = structural_simplify(sys, (inputs, outputs); simplify, kwargs...)
2490-
2491-
eqs = equations(sys)
2492-
alg_start_idx = findfirst(!isdiffeq, eqs)
2493-
if alg_start_idx === nothing
2494-
alg_start_idx = length(eqs) + 1
2495-
end
2496-
diff_idxs = 1:(alg_start_idx - 1)
2497-
alge_idxs = alg_start_idx:length(eqs)
2498-
2499-
sys, diff_idxs, alge_idxs, input_idxs
2500-
end
2501-
25022487
@latexrecipe function f(sys::AbstractSystem)
25032488
return latexify(equations(sys))
25042489
end

src/systems/analysis_points.jl

-1
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,6 @@ function linearization_function(sys::AbstractSystem,
958958
end
959959

960960
sys = handle_loop_openings(sys, map(AnalysisPoint, collect(loop_openings)))
961-
962961
return linearization_function(system_modifier(sys), input_vars, output_vars; kwargs...)
963962
end
964963

src/systems/clock_inference.jl

+13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
struct ClockInference{S}
2+
"""Tearing state."""
23
ts::S
4+
"""The time domain (discrete clock, continuous) of each equation."""
35
eq_domain::Vector{TimeDomain}
6+
"""The output time domain (discrete clock, continuous) of each variable."""
47
var_domain::Vector{TimeDomain}
8+
"""The set of variables with concrete domains."""
59
inferred::BitSet
610
end
711

@@ -67,6 +71,9 @@ function substitute_sample_time(ex, dt)
6771
end
6872
end
6973

74+
"""
75+
Update the equation-to-time domain mapping by inferring the time domain from the variables.
76+
"""
7077
function infer_clocks!(ci::ClockInference)
7178
@unpack ts, eq_domain, var_domain, inferred = ci
7279
@unpack var_to_diff, graph = ts.structure
@@ -132,6 +139,9 @@ function is_time_domain_conversion(v)
132139
input_timedomain(o) != output_timedomain(o)
133140
end
134141

142+
"""
143+
For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain.
144+
"""
135145
function split_system(ci::ClockInference{S}) where {S}
136146
@unpack ts, eq_domain, var_domain, inferred = ci
137147
fullvars = get_fullvars(ts)
@@ -143,11 +153,14 @@ function split_system(ci::ClockInference{S}) where {S}
143153
cid_to_eq = Vector{Int}[]
144154
var_to_cid = Vector{Int}(undef, ndsts(graph))
145155
cid_to_var = Vector{Int}[]
156+
# cid_counter = number of clocks
146157
cid_counter = Ref(0)
147158
for (i, d) in enumerate(eq_domain)
148159
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
149160
continuous_id = continuous_id
150161

162+
# Fill the clock_to_id dict as you go,
163+
# ContinuousClock() => 1, ...
151164
get!(clock_to_id, d) do
152165
cid = (cid_counter[] += 1)
153166
push!(id_to_clock, d)

0 commit comments

Comments
 (0)