Skip to content

feat: cache start system and solver in HomotopyContinuation interface #3192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
Expand Down Expand Up @@ -94,6 +95,7 @@ Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.7, 0.8, 0.9"
DomainSets = "0.6, 0.7"
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
Expand Down
149 changes: 126 additions & 23 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,104 @@ function contains_variable(x, wrt)
any(y -> occursin(y, x), wrt)
end

"""
Possible reasons why a term is not polynomial
"""
MTK.EnumX.@enumx NonPolynomialReason begin
NonIntegerExponent
ExponentContainsUnknowns
BaseNotPolynomial
UnrecognizedOperation
end

function display_reason(reason::NonPolynomialReason.T, sym)
if reason == NonPolynomialReason.NonIntegerExponent
pow = arguments(sym)[2]
"In $sym: Exponent $pow is not an integer"
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
pow = arguments(sym)[2]
"In $sym: Exponent $pow contains unknowns of the system"
elseif reason == NonPolynomialReason.BaseNotPolynomial
base = arguments(sym)[1]
"In $sym: Base $base is not a polynomial in the unknowns"
elseif reason == NonPolynomialReason.UnrecognizedOperation
op = operation(sym)
"""
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
`*, /, +, -, ^`.
"""
else
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
end
end

mutable struct PolynomialData
non_polynomial_terms::Vector{BasicSymbolic}
reasons::Vector{NonPolynomialReason.T}
has_parametric_exponent::Bool
end

PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)

struct NotPolynomialError <: Exception
eq::Equation
data::PolynomialData
end

function Base.showerror(io::IO, err::NotPolynomialError)
println(io,
"Equation $(err.eq) is not a polynomial in the unknowns for the following reasons:")
for (term, reason) in zip(err.data.non_polynomial_terms, err.data.reasons)
println(io, display_reason(reason, term))
end
end

function is_polynomial!(data, y, wrt)
process_polynomial!(data, y, wrt)
isempty(data.reasons)
end

"""
$(TYPEDSIGNATURES)

Check if `x` is polynomial with respect to the variables in `wrt`.
Return information about the polynmial `x` with respect to variables in `wrt`,
writing said information to `data`.
"""
function is_polynomial(x, wrt)
function process_polynomial!(data::PolynomialData, x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return true
iscall(x) || return true
contains_variable(x, wrt) || return true
any(isequal(x), wrt) && return true

if operation(x) in (*, +, -, /)
return all(y -> is_polynomial(y, wrt), arguments(x))
return all(y -> is_polynomial!(data, y, wrt), arguments(x))
end
if operation(x) == (^)
b, p = arguments(x)
is_pow_integer = symtype(p) <: Integer
if !is_pow_integer
if symbolic_type(p) == NotSymbolic()
@warn "In $x: Exponent $p is not an integer"
else
@warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters."
end
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
end
if symbolic_type(p) != NotSymbolic()
data.has_parametric_exponent = true
end

exponent_has_unknowns = contains_variable(p, wrt)
if exponent_has_unknowns
@warn "In $x: Exponent $p cannot contain unknowns of the system."
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
end
base_polynomial = is_polynomial(b, wrt)
base_polynomial = is_polynomial!(data, b, wrt)
if !base_polynomial
@warn "In $x: Base is not a polynomial"
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.BaseNotPolynomial)
end
return base_polynomial && !exponent_has_unknowns && is_pow_integer
end
@warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`"
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
return false
end

Expand Down Expand Up @@ -179,21 +241,39 @@ Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial
The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
will contain the polynomial root closest to the point specified by `u0map` (if real roots
exist for the system).

Keyword arguments:
- `eval_expression`: Whether to `eval` the generated functions or use a `RuntimeGeneratedFunction`.
- `eval_module`: The module to use for `eval`/`@RuntimeGeneratedFunction`
- `warn_parametric_exponent`: Whether to warn if the system contains a parametric
exponent preventing the homotopy from being cached.

All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
"""
function MTK.HomotopyContinuationProblem(
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
eval_module = ModelingToolkit, kwargs...)
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end

dvs = unknowns(sys)
eqs = equations(sys)
# we need to consider `full_equations` because observed also should be
# polynomials (if used in equations) and we don't know if observed is used
# in denominator.
# This is not the most efficient, and would be improved significantly with
# CSE/hashconsing.
eqs = full_equations(sys)

denoms = []
has_parametric_exponents = false
eqs2 = map(eqs) do eq
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
data = PolynomialData()
process_polynomial!(data, eq.lhs, dvs)
process_polynomial!(data, eq.rhs, dvs)
has_parametric_exponents |= data.has_parametric_exponent
if !isempty(data.non_polynomial_terms)
throw(NotPolynomialError(eq, data))
end
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)

Expand All @@ -212,6 +292,9 @@ function MTK.HomotopyContinuationProblem(
end

sys2 = MTK.@set sys.eqs = eqs2
# remove observed equations to avoid adding them in codegen
MTK.@set! sys2.observed = Equation[]
MTK.@set! sys2.substitutions = nothing

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
jac = true, eval_expression, eval_module)
Expand All @@ -223,29 +306,49 @@ function MTK.HomotopyContinuationProblem(

obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)

return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
if has_parametric_exponents
if warn_parametric_exponent
@warn """
The system has parametric exponents, preventing caching of the homotopy. \
This will cause `solve` to be slower. Pass `warn_parametric_exponent \
= false` to turn off this warning
"""
end
solver_and_starts = nothing
else
solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
end
return MTK.HomotopyContinuationProblem(
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
end

"""
$(TYPEDSIGNATURES)

Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
uses `HomotopyContinuation.jl`. The original solution as returned by
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
`NonlinearSolution`.

All keyword arguments have their default values in HomotopyContinuation.jl, except
`show_progress` which defaults to `false`.
All keyword arguments except the ones listed below are forwarded to
`HomotopyContinuation.solve`. Note that the solver and start solutions are precomputed,
and only keyword arguments related to the solve process are valid. All keyword
arguments have their default values in HomotopyContinuation.jl, except `show_progress`
which defaults to `false`.

Extra keyword arguments:
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
the denominator to be below `denominator_abstol` will be discarded.
"""
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
alg = nothing; show_progress = false, denominator_abstol = 1e-7, kwargs...)
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
if prob.solver_and_starts === nothing
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
else
solver, starts = prob.solver_and_starts
sol = HomotopyContinuation.solve(solver, starts; show_progress, kwargs...)
end
realsols = HomotopyContinuation.results(sol; only_real = true)
if isempty(realsols)
u = state_values(prob)
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using RecursiveArrayTools
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
import CommonSolve
import EnumX

using RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: drop_expr
Expand Down
7 changes: 6 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
create and solve.
"""
struct HomotopyContinuationProblem{uType, H, D, O} <:
struct HomotopyContinuationProblem{uType, H, D, O, SS} <:
SciMLBase.AbstractNonlinearProblem{uType, true}
"""
The initial values of states in the system. If there are multiple real roots of
Expand All @@ -716,6 +716,11 @@ struct HomotopyContinuationProblem{uType, H, D, O} <:
A function which generates and returns observed expressions for the given system.
"""
obsfn::O
"""
The HomotopyContinuation.jl solver and start system, obtained through
`HomotopyContinuation.solver_startsystems`.
"""
solver_and_starts::SS
end

function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)
Expand Down
38 changes: 29 additions & 9 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,32 @@ end
@test sol.retcode == ReturnCode.ConvergenceFailure
end

@testset "Parametric exponent" begin
@testset "Parametric exponents" begin
@variables x = 1.0
@parameters n::Integer = 4
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
prob = HomotopyContinuationProblem(sys, [])
prob = @test_warn ["parametric", "exponent"] HomotopyContinuationProblem(sys, [])
@test prob.solver_and_starts === nothing
@test_nowarn HomotopyContinuationProblem(sys, []; warn_parametric_exponent = false)
sol = solve(prob; threading = false)
@test SciMLBase.successful_retcode(sol)
end

@testset "Polynomial check and warnings" begin
@variables x = 1.0
@parameters n = 4
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
@test_warn ["Exponent", "not an integer", "@parameters"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([x^1.5 + x^2 - 1 ~ 0])
@test_warn ["Exponent", "not an integer"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
sys, [])

@variables y = 2.0
@mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)])
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
end

Expand Down Expand Up @@ -131,4 +134,21 @@ end
end
end
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8

@testset "Rational function in observed" begin
@variables x=1 y=1
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
@test_nowarn solve(prob; threading = false)
end
end

@testset "Non-polynomial observed not used in equations" begin
@variables x=1 y
@mtkbuild sys = NonlinearSystem([x^2 - 2 ~ 0, y ~ sin(x)])
prob = HomotopyContinuationProblem(sys, [])
sol = @test_nowarn solve(prob; threading = false)
@test sol[x] ≈ √2.0
@test sol[y] ≈ sin(√2.0)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ end

if GROUP == "All" || GROUP == "Extensions"
activate_extensions_env()
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "HomotopyContinuation Extension Test" include("extensions/homotopy_continuation.jl")
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
end
end
Loading