Skip to content

fix: handle derivatives of time-dependent parameters #3493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 24, 2025
1 change: 1 addition & 0 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
D(eq.lhs)
end
rhs = ModelingToolkit.expand_derivatives(D(eq.rhs))
rhs = fast_substitute(rhs, state.param_derivative_map)
substitution_dict = Dict(x.lhs => x.rhs
for x in out_eqs if x !== nothing && x.lhs isa Symbolic)
sub_rhs = substitute(rhs, substitution_dict)
Expand Down
18 changes: 17 additions & 1 deletion src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,23 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)

sys = ts.sys
eq = equations(ts)[ieq]
eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true)
eq = 0 ~ fast_substitute(
ModelingToolkit.derivative(
eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map)

vs = ModelingToolkit.vars(eq.rhs)
for v in vs
# parameters with unknown derivatives have a value of `nothing` in the map,
# so use `missing` as the default.
get(ts.param_derivative_map, v, missing) === nothing || continue
_original_eq = equations(ts)[ieq]
error("""
Encountered derivative of discrete variable `$(only(arguments(v)))` when \
differentiating equation `$(_original_eq)`. This may indicate a model error or a \
missing equation of the form `$v ~ ...` that defines this derivative.
""")
end

push!(equations(ts), eq)
# Analyze the new equation and update the graph/solvable_graph
# First, copy the previous incidence and add the derivative terms.
Expand Down
31 changes: 30 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
fullvars::Vector
structure::SystemStructure
extra_eqs::Vector
param_derivative_map::Dict{BasicSymbolic, Any}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
Expand Down Expand Up @@ -253,6 +254,12 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

function is_time_dependent_parameter(p, iv)
return iv !== nothing && isparameter(p) && iscall(p) &&
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
end

function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
Expand All @@ -264,6 +271,7 @@ function TearingState(sys; quick_cancel = false, check = true)
var2idx = Dict{Any, Int}()
symbolic_incidence = []
fullvars = []
param_derivative_map = Dict{BasicSymbolic, Any}()
var_counter = Ref(0)
var_types = VariableType[]
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
Expand All @@ -276,11 +284,23 @@ function TearingState(sys; quick_cancel = false, check = true)

vars = OrderedSet()
varsvec = []
eqs_to_retain = trues(length(eqs))
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
# we want to store `nothing` in the map because that means `fast_substitute`
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
# the differentiated expression and error.
param_derivative_map[eq′.lhs] = coalesce(eq′.rhs, nothing)
eqs_to_retain[i] = false
# change the equation if the RHS is `missing` so the rest of this loop works
eq′ = eq′.lhs ~ coalesce(eq′.rhs, 0.0)
end
if _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
Expand All @@ -295,6 +315,12 @@ function TearingState(sys; quick_cancel = false, check = true)
any(isequal(_var), ivs) && continue
if isparameter(_var) ||
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
if is_time_dependent_parameter(_var, iv) &&
!haskey(param_derivative_map, Differential(iv)(_var))
# Parameter derivatives default to zero - they stay constant
# between callbacks
param_derivative_map[Differential(iv)(_var)] = 0.0
end
continue
end
v = scalarize(v)
Expand Down Expand Up @@ -351,6 +377,9 @@ function TearingState(sys; quick_cancel = false, check = true)
eqs[i] = eqs[i].lhs ~ rhs
end
end
eqs = eqs[eqs_to_retain]
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

### Handle discrete variables
lowest_shift = Dict()
Expand Down Expand Up @@ -438,7 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true)
ts = TearingState(sys, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
Any[])
Any[], param_derivative_map)
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down
2 changes: 1 addition & 1 deletion test/state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ModelingToolkit, OrdinaryDiffEq, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

sts = @variables x1(t) x2(t) x3(t) x4(t)
params = @parameters u1(t) u2(t) u3(t) u4(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hehe

params = @parameters u1 u2 u3 u4
eqs = [x1 + x2 + u1 ~ 0
x1 + x2 + x3 + u2 ~ 0
x1 + D(x3) + x4 + u3 ~ 0
Expand Down
109 changes: 109 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SparseArrays
using UnPack
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
using Symbolics: unwrap
using DataInterpolations
const ST = StructuralTransformations

# Define some variables
Expand Down Expand Up @@ -282,3 +283,111 @@ end
@test length(mapping) == 3
end
end

@testset "Issue#3480: Derivatives of time-dependent parameters" begin
@component function FilteredInput(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@component function FilteredInputExplicit(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t)[1:1] = [x0]
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
D(k[1]) ~ 1.0
dx ~ (k[1] - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@component function FilteredInputErr(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T
D(k) ~ missing]
return ODESystem(eqs, t, vars, params; systems, name)
end

@named sys = FilteredInputErr()
@test_throws ["derivative of discrete variable", "k(t)"] structural_simplify(sys)

@mtkbuild sys = FilteredInput()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test !(D(sys.k) in vs)

@mtkbuild sys = FilteredInputExplicit()
obsfn1 = ModelingToolkit.build_explicit_observed_function(sys, sys.ddx)
obsfn2 = ModelingToolkit.build_explicit_observed_function(sys, sys.dx)
u = [1.0]
p = MTKParameters(sys, [sys.k => [2.0], sys.T => 3.0])
@test obsfn1(u, p, 0.0) ≈ (1 - obsfn2(u, p, 0.0)) / 3.0

@testset "Called parameter still has derivative" begin
@component function FilteredInput2(; name, x0 = 0, T = 0.1)
ts = collect(0.0:0.1:10.0)
spline = LinearInterpolation(ts .^ 2, ts)
params = @parameters begin
(k::LinearInterpolation)(..) = spline
T = T
end
vars = @variables begin
x(t) = k(t)
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k(t) - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@mtkbuild sys = FilteredInput2()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test D(sys.k(t)) in vs
end
end
Loading