Skip to content

Commit 8d89343

Browse files
mtk in weakdeps and secondorder in DI
1 parent 905c453 commit 8d89343

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,18 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515

1616
[weakdeps]
1717
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
18+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1819

1920
[extensions]
2021
OptimizationDIExt = "DifferentiationInterface"
22+
OptimizationMTKExt = "ModelingToolkit"
2123

2224
[compat]
2325
ADTypes = "0.2.5"
2426
ArrayInterface = "7.6"
2527
DocStringExtensions = "0.9"
2628
LinearAlgebra = "1.9, 1.10"
29+
ModelingToolkit = "9"
2730
Reexport = "1.2"
2831
Requires = "1"
2932
SciMLBase = "2"

ext/OptimizationDIExt.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module OptimizationDIExt
44
import OptimizationBase, OptimizationBase.ArrayInterface
55
import OptimizationBase.SciMLBase: OptimizationFunction
66
import OptimizationBase.LinearAlgebra: I
7+
import DifferentiationInterface
78
import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_jacobian, gradient!!, hessian!!, jacobian!!
89
using ADTypes
910

@@ -22,7 +23,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
2223
hess_sparsity = f.hess_prototype
2324
hess_colors = f.hess_colorvec
2425
if f.hess === nothing
25-
extras_hess = prepare_hessian(_f, adtype, x)
26+
extras_hess = prepare_hessian(_f, DifferentiationInterface.SecondOrder(adtype), x) #placeholder logic, can be made much better
2627
function hess(res, θ, args...)
2728
hessian!!(_f, res, adtype, θ, extras_hess)
2829
end
@@ -32,10 +33,6 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
3233

3334
if f.hv === nothing
3435
hv = function (H, θ, v, args...)
35-
# _θ = ForwardDiff.Dual.(θ, v)
36-
# res = similar(_θ)
37-
# grad(res, _θ, args...)
38-
# H .= getindex.(ForwardDiff.partials.(res), 1)
3936
res = zeros(length(θ), length(θ))
4037
hess(res, θ, args...)
4138
H .= res * v
@@ -66,7 +63,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
6663
conshess_colors = f.cons_hess_colorvec
6764
if cons !== nothing && f.cons_h === nothing
6865
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
69-
extras_cons_hess = prepare_hessian.(fncs, Ref(adtype), Ref(x))
66+
extras_cons_hess = prepare_hessian.(fncs, Ref(DifferentiationInterface.SecondOrder(adtype)), Ref(x))
7067

7168
function cons_h(H, θ)
7269
for i in 1:num_cons
@@ -110,9 +107,9 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
110107
hess_sparsity = f.hess_prototype
111108
hess_colors = f.hess_colorvec
112109
if f.hess === nothing
113-
extras_hess = prepare_hessian(_f, adtype, x)
110+
extras_hess = prepare_hessian(_f, DifferentiationInterface.SecondOrder(adtype), x) #placeholder logic, can be made much better
114111
function hess(res, θ, args...)
115-
hessian!!(_f, res, adtype, θ, extras_hess)
112+
hessian!!(_f, res, DifferentiationInterface.SecondOrder(adtype), θ, extras_hess)
116113
end
117114
else
118115
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -154,7 +151,7 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
154151
conshess_colors = f.cons_hess_colorvec
155152
if cons !== nothing && f.cons_h === nothing
156153
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
157-
extras_cons_hess = prepare_hessian.(fncs, Ref(adtype), Ref(x))
154+
extras_cons_hess = prepare_hessian.(fncs, Ref(DifferentiationInterface.SecondOrder(adtype)), Ref(x))
158155

159156
function cons_h(H, θ)
160157
for i in 1:num_cons

0 commit comments

Comments
 (0)