Skip to content

Commit b52bce7

Browse files
Merge pull request #3192 from AayushSabharwal/as/hc-cache-startsys
feat: cache start system and solver in HomotopyContinuation interface
2 parents d25a060 + d284978 commit b52bce7

File tree

6 files changed

+165
-34
lines changed

6 files changed

+165
-34
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2121
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2222
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2323
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
24+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2425
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
2526
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
2627
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
@@ -94,6 +95,7 @@ Distributions = "0.23, 0.24, 0.25"
9495
DocStringExtensions = "0.7, 0.8, 0.9"
9596
DomainSets = "0.6, 0.7"
9697
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
98+
EnumX = "1.0.4"
9799
ExprTools = "0.1.10"
98100
Expronicon = "0.8"
99101
FindFirstFunctions = "1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 126 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,104 @@ function contains_variable(x, wrt)
1515
any(y -> occursin(y, x), wrt)
1616
end
1717

18+
"""
19+
Possible reasons why a term is not polynomial
20+
"""
21+
MTK.EnumX.@enumx NonPolynomialReason begin
22+
NonIntegerExponent
23+
ExponentContainsUnknowns
24+
BaseNotPolynomial
25+
UnrecognizedOperation
26+
end
27+
28+
function display_reason(reason::NonPolynomialReason.T, sym)
29+
if reason == NonPolynomialReason.NonIntegerExponent
30+
pow = arguments(sym)[2]
31+
"In $sym: Exponent $pow is not an integer"
32+
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
33+
pow = arguments(sym)[2]
34+
"In $sym: Exponent $pow contains unknowns of the system"
35+
elseif reason == NonPolynomialReason.BaseNotPolynomial
36+
base = arguments(sym)[1]
37+
"In $sym: Base $base is not a polynomial in the unknowns"
38+
elseif reason == NonPolynomialReason.UnrecognizedOperation
39+
op = operation(sym)
40+
"""
41+
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
42+
`*, /, +, -, ^`.
43+
"""
44+
else
45+
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
46+
end
47+
end
48+
49+
mutable struct PolynomialData
50+
non_polynomial_terms::Vector{BasicSymbolic}
51+
reasons::Vector{NonPolynomialReason.T}
52+
has_parametric_exponent::Bool
53+
end
54+
55+
PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)
56+
57+
struct NotPolynomialError <: Exception
58+
eq::Equation
59+
data::PolynomialData
60+
end
61+
62+
function Base.showerror(io::IO, err::NotPolynomialError)
63+
println(io,
64+
"Equation $(err.eq) is not a polynomial in the unknowns for the following reasons:")
65+
for (term, reason) in zip(err.data.non_polynomial_terms, err.data.reasons)
66+
println(io, display_reason(reason, term))
67+
end
68+
end
69+
70+
function is_polynomial!(data, y, wrt)
71+
process_polynomial!(data, y, wrt)
72+
isempty(data.reasons)
73+
end
74+
1875
"""
1976
$(TYPEDSIGNATURES)
2077
21-
Check if `x` is polynomial with respect to the variables in `wrt`.
78+
Return information about the polynmial `x` with respect to variables in `wrt`,
79+
writing said information to `data`.
2280
"""
23-
function is_polynomial(x, wrt)
81+
function process_polynomial!(data::PolynomialData, x, wrt)
2482
x = unwrap(x)
2583
symbolic_type(x) == NotSymbolic() && return true
2684
iscall(x) || return true
2785
contains_variable(x, wrt) || return true
2886
any(isequal(x), wrt) && return true
2987

3088
if operation(x) in (*, +, -, /)
31-
return all(y -> is_polynomial(y, wrt), arguments(x))
89+
return all(y -> is_polynomial!(data, y, wrt), arguments(x))
3290
end
3391
if operation(x) == (^)
3492
b, p = arguments(x)
3593
is_pow_integer = symtype(p) <: Integer
3694
if !is_pow_integer
37-
if symbolic_type(p) == NotSymbolic()
38-
@warn "In $x: Exponent $p is not an integer"
39-
else
40-
@warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters."
41-
end
95+
push!(data.non_polynomial_terms, x)
96+
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
97+
end
98+
if symbolic_type(p) != NotSymbolic()
99+
data.has_parametric_exponent = true
42100
end
101+
43102
exponent_has_unknowns = contains_variable(p, wrt)
44103
if exponent_has_unknowns
45-
@warn "In $x: Exponent $p cannot contain unknowns of the system."
104+
push!(data.non_polynomial_terms, x)
105+
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
46106
end
47-
base_polynomial = is_polynomial(b, wrt)
107+
base_polynomial = is_polynomial!(data, b, wrt)
48108
if !base_polynomial
49-
@warn "In $x: Base is not a polynomial"
109+
push!(data.non_polynomial_terms, x)
110+
push!(data.reasons, NonPolynomialReason.BaseNotPolynomial)
50111
end
51112
return base_polynomial && !exponent_has_unknowns && is_pow_integer
52113
end
53-
@warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`"
114+
push!(data.non_polynomial_terms, x)
115+
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
54116
return false
55117
end
56118

@@ -179,21 +241,39 @@ Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial
179241
The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
180242
will contain the polynomial root closest to the point specified by `u0map` (if real roots
181243
exist for the system).
244+
245+
Keyword arguments:
246+
- `eval_expression`: Whether to `eval` the generated functions or use a `RuntimeGeneratedFunction`.
247+
- `eval_module`: The module to use for `eval`/`@RuntimeGeneratedFunction`
248+
- `warn_parametric_exponent`: Whether to warn if the system contains a parametric
249+
exponent preventing the homotopy from being cached.
250+
251+
All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
182252
"""
183253
function MTK.HomotopyContinuationProblem(
184254
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
185-
eval_module = ModelingToolkit, kwargs...)
255+
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
186256
if !iscomplete(sys)
187257
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
188258
end
189259

190260
dvs = unknowns(sys)
191-
eqs = equations(sys)
261+
# we need to consider `full_equations` because observed also should be
262+
# polynomials (if used in equations) and we don't know if observed is used
263+
# in denominator.
264+
# This is not the most efficient, and would be improved significantly with
265+
# CSE/hashconsing.
266+
eqs = full_equations(sys)
192267

193268
denoms = []
269+
has_parametric_exponents = false
194270
eqs2 = map(eqs) do eq
195-
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
196-
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
271+
data = PolynomialData()
272+
process_polynomial!(data, eq.lhs, dvs)
273+
process_polynomial!(data, eq.rhs, dvs)
274+
has_parametric_exponents |= data.has_parametric_exponent
275+
if !isempty(data.non_polynomial_terms)
276+
throw(NotPolynomialError(eq, data))
197277
end
198278
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
199279

@@ -212,6 +292,9 @@ function MTK.HomotopyContinuationProblem(
212292
end
213293

214294
sys2 = MTK.@set sys.eqs = eqs2
295+
# remove observed equations to avoid adding them in codegen
296+
MTK.@set! sys2.observed = Equation[]
297+
MTK.@set! sys2.substitutions = nothing
215298

216299
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
217300
jac = true, eval_expression, eval_module)
@@ -223,29 +306,49 @@ function MTK.HomotopyContinuationProblem(
223306

224307
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
225308

226-
return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
309+
if has_parametric_exponents
310+
if warn_parametric_exponent
311+
@warn """
312+
The system has parametric exponents, preventing caching of the homotopy. \
313+
This will cause `solve` to be slower. Pass `warn_parametric_exponent \
314+
= false` to turn off this warning
315+
"""
316+
end
317+
solver_and_starts = nothing
318+
else
319+
solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
320+
end
321+
return MTK.HomotopyContinuationProblem(
322+
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
227323
end
228324

229325
"""
230326
$(TYPEDSIGNATURES)
231327
232328
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
233-
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
234-
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
329+
uses `HomotopyContinuation.jl`. The original solution as returned by
235330
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
236331
`NonlinearSolution`.
237332
238-
All keyword arguments have their default values in HomotopyContinuation.jl, except
239-
`show_progress` which defaults to `false`.
333+
All keyword arguments except the ones listed below are forwarded to
334+
`HomotopyContinuation.solve`. Note that the solver and start solutions are precomputed,
335+
and only keyword arguments related to the solve process are valid. All keyword
336+
arguments have their default values in HomotopyContinuation.jl, except `show_progress`
337+
which defaults to `false`.
240338
241339
Extra keyword arguments:
242340
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
243341
the denominator to be below `denominator_abstol` will be discarded.
244342
"""
245343
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
246344
alg = nothing; show_progress = false, denominator_abstol = 1e-7, kwargs...)
247-
sol = HomotopyContinuation.solve(
248-
prob.homotopy_continuation_system; show_progress, kwargs...)
345+
if prob.solver_and_starts === nothing
346+
sol = HomotopyContinuation.solve(
347+
prob.homotopy_continuation_system; show_progress, kwargs...)
348+
else
349+
solver, starts = prob.solver_and_starts
350+
sol = HomotopyContinuation.solve(solver, starts; show_progress, kwargs...)
351+
end
249352
realsols = HomotopyContinuation.results(sol; only_real = true)
250353
if isempty(realsols)
251354
u = state_values(prob)

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
5757
import CommonSolve
58+
import EnumX
5859

5960
using RuntimeGeneratedFunctions
6061
using RuntimeGeneratedFunctions: drop_expr

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
690690
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
691691
create and solve.
692692
"""
693-
struct HomotopyContinuationProblem{uType, H, D, O} <:
693+
struct HomotopyContinuationProblem{uType, H, D, O, SS} <:
694694
SciMLBase.AbstractNonlinearProblem{uType, true}
695695
"""
696696
The initial values of states in the system. If there are multiple real roots of
@@ -716,6 +716,11 @@ struct HomotopyContinuationProblem{uType, H, D, O} <:
716716
A function which generates and returns observed expressions for the given system.
717717
"""
718718
obsfn::O
719+
"""
720+
The HomotopyContinuation.jl solver and start system, obtained through
721+
`HomotopyContinuation.solver_startsystems`.
722+
"""
723+
solver_and_starts::SS
719724
end
720725

721726
function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)

test/extensions/homotopy_continuation.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,32 @@ end
6161
@test sol.retcode == ReturnCode.ConvergenceFailure
6262
end
6363

64-
@testset "Parametric exponent" begin
64+
@testset "Parametric exponents" begin
6565
@variables x = 1.0
6666
@parameters n::Integer = 4
6767
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
68-
prob = HomotopyContinuationProblem(sys, [])
68+
prob = @test_warn ["parametric", "exponent"] HomotopyContinuationProblem(sys, [])
69+
@test prob.solver_and_starts === nothing
70+
@test_nowarn HomotopyContinuationProblem(sys, []; warn_parametric_exponent = false)
6971
sol = solve(prob; threading = false)
7072
@test SciMLBase.successful_retcode(sol)
7173
end
7274

7375
@testset "Polynomial check and warnings" begin
7476
@variables x = 1.0
75-
@parameters n = 4
76-
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
77-
@test_warn ["Exponent", "not an integer", "@parameters"] @test_throws "not a polynomial" HomotopyContinuationProblem(
78-
sys, [])
7977
@mtkbuild sys = NonlinearSystem([x^1.5 + x^2 - 1 ~ 0])
80-
@test_warn ["Exponent", "not an integer"] @test_throws "not a polynomial" HomotopyContinuationProblem(
78+
@test_throws ["Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
8179
sys, [])
8280
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
83-
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
81+
@test_throws ["Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
8482
sys, [])
8583
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
86-
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
84+
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
85+
sys, [])
86+
87+
@variables y = 2.0
88+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)])
89+
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
8790
sys, [])
8891
end
8992

@@ -131,4 +134,21 @@ end
131134
end
132135
end
133136
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8
137+
138+
@testset "Rational function in observed" begin
139+
@variables x=1 y=1
140+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
141+
prob = HomotopyContinuationProblem(sys, [])
142+
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
143+
@test_nowarn solve(prob; threading = false)
144+
end
145+
end
146+
147+
@testset "Non-polynomial observed not used in equations" begin
148+
@variables x=1 y
149+
@mtkbuild sys = NonlinearSystem([x^2 - 2 ~ 0, y ~ sin(x)])
150+
prob = HomotopyContinuationProblem(sys, [])
151+
sol = @test_nowarn solve(prob; threading = false)
152+
@test sol[x] 2.0
153+
@test sol[y] sin(2.0)
134154
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ end
109109

110110
if GROUP == "All" || GROUP == "Extensions"
111111
activate_extensions_env()
112-
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
113112
@safetestset "HomotopyContinuation Extension Test" include("extensions/homotopy_continuation.jl")
114113
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
115114
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
115+
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
116116
end
117117
end

0 commit comments

Comments
 (0)