Skip to content

Commit e81ac30

Browse files
committed
Document build_explicit_observed_function and allow user-defined array construction
1 parent b52bce7 commit e81ac30

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

src/systems/diffeqs/odesystem.jl

+31-21
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,32 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...
411411
"""
412412
$(SIGNATURES)
413413
414-
Build the observed function assuming the observed equations are all explicit,
415-
i.e. there are no cycles.
414+
Generates a function that computes the observed value(s) `ts` in the system `sys` assuming that there are no cycles in the equations.
415+
416+
The return value will be either:
417+
* a single function if the input is a scalar or if the input is a Vector but `return_inplace` is false
418+
* the out of place and in-place functions `(ip, oop)` if `return_inplace` is true and the input is a `Vector`
419+
420+
The function(s) will be:
421+
* `RuntimeGeneratedFunction`s by default,
422+
* A Julia `Expr` if `expression` is true,
423+
* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true
424+
425+
The signatures will be of the form `g(...)` with arguments:
426+
* `output` for in-place functions
427+
* `unknowns` if `params_only` is `false`
428+
* `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts`
429+
* `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted
430+
* `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t`
431+
For example, a function `g(op, unknowns, p, inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector, an array of inputs `inputs` is given, and `params_only` is false for a time-dependent system.
432+
433+
Options not otherwise specified are:
434+
* `output_type = Array` the type of the array generated by the out-of-place vector-valued function
435+
* `checkbounds = true` checks bounds if true when destructuring parameters
436+
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
437+
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
438+
* `drop_expr` is deprecated.
439+
* `mkarray`; only used if the output is an array (that is, `!isscalar(ts)`). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
416440
"""
417441
function build_explicit_observed_function(sys, ts;
418442
inputs = nothing,
@@ -426,7 +450,8 @@ function build_explicit_observed_function(sys, ts;
426450
return_inplace = false,
427451
param_only = false,
428452
op = Operator,
429-
throw = true)
453+
throw = true,
454+
mkarray = MakeArray)
430455
if (isscalar = symbolic_type(ts) !== NotSymbolic())
431456
ts = [ts]
432457
end
@@ -571,12 +596,11 @@ function build_explicit_observed_function(sys, ts;
571596
oop_mtkp_wrapper = mtkparams_wrapper
572597
end
573598

599+
output_expr = isscalar ? ts[1] : mkarray(ts, output_type)
574600
# Need to keep old method of building the function since it uses `output_type`,
575601
# which can't be provided to `build_function`
576-
oop_fn = Func(args, [],
577-
pre(Let(obsexprs,
578-
isscalar ? ts[1] : MakeArray(ts, output_type),
579-
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
602+
oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |>
603+
oop_mtkp_wrapper |> toexpr
580604
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
581605

582606
if !isscalar
@@ -597,20 +621,6 @@ function build_explicit_observed_function(sys, ts;
597621
end
598622
end
599623

600-
function populate_delays(delays::Set, obsexprs, histfn, sys, sym)
601-
_vars_util = vars(sym)
602-
for v in _vars_util
603-
v in delays && continue
604-
iscall(v) && issym(operation(v)) && (args = arguments(v); length(args) == 1) &&
605-
iscall(only(args)) || continue
606-
607-
idx = variable_index(sys, operation(v)(get_iv(sys)))
608-
idx === nothing && error("Delay term $v is not an unknown in the system")
609-
push!(delays, v)
610-
push!(obsexprs, v histfn(only(args))[idx])
611-
end
612-
end
613-
614624
function _eq_unordered(a, b)
615625
length(a) === length(b) || return false
616626
n = length(a)

0 commit comments

Comments
 (0)