@@ -4,6 +4,7 @@ module OptimizationDIExt
4
4
import OptimizationBase, OptimizationBase. ArrayInterface
5
5
import OptimizationBase. SciMLBase: OptimizationFunction
6
6
import OptimizationBase. LinearAlgebra: I
7
+ import DifferentiationInterface
7
8
import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_jacobian, gradient!!, hessian!!, jacobian!!
8
9
using ADTypes
9
10
@@ -22,7 +23,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
22
23
hess_sparsity = f. hess_prototype
23
24
hess_colors = f. hess_colorvec
24
25
if f. hess === nothing
25
- extras_hess = prepare_hessian (_f, adtype, x)
26
+ extras_hess = prepare_hessian (_f, DifferentiationInterface . SecondOrder ( adtype) , x) # placeholder logic, can be made much better
26
27
function hess (res, θ, args... )
27
28
hessian!! (_f, res, adtype, θ, extras_hess)
28
29
end
@@ -32,10 +33,6 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
32
33
33
34
if f. hv === nothing
34
35
hv = function (H, θ, v, args... )
35
- # _θ = ForwardDiff.Dual.(θ, v)
36
- # res = similar(_θ)
37
- # grad(res, _θ, args...)
38
- # H .= getindex.(ForwardDiff.partials.(res), 1)
39
36
res = zeros (length (θ), length (θ))
40
37
hess (res, θ, args... )
41
38
H .= res * v
@@ -66,7 +63,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::ADTypes.AbstractADT
66
63
conshess_colors = f. cons_hess_colorvec
67
64
if cons != = nothing && f. cons_h === nothing
68
65
fncs = [(x) -> cons_oop (x)[i] for i in 1 : num_cons]
69
- extras_cons_hess = prepare_hessian .(fncs, Ref (adtype), Ref (x))
66
+ extras_cons_hess = prepare_hessian .(fncs, Ref (DifferentiationInterface . SecondOrder ( adtype) ), Ref (x))
70
67
71
68
function cons_h (H, θ)
72
69
for i in 1 : num_cons
@@ -110,9 +107,9 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
110
107
hess_sparsity = f. hess_prototype
111
108
hess_colors = f. hess_colorvec
112
109
if f. hess === nothing
113
- extras_hess = prepare_hessian (_f, adtype, x)
110
+ extras_hess = prepare_hessian (_f, DifferentiationInterface . SecondOrder ( adtype) , x) # placeholder logic, can be made much better
114
111
function hess (res, θ, args... )
115
- hessian!! (_f, res, adtype, θ, extras_hess)
112
+ hessian!! (_f, res, DifferentiationInterface . SecondOrder ( adtype) , θ, extras_hess)
116
113
end
117
114
else
118
115
hess = (H, θ, args... ) -> f. hess (H, θ, p, args... )
@@ -154,7 +151,7 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
154
151
conshess_colors = f. cons_hess_colorvec
155
152
if cons != = nothing && f. cons_h === nothing
156
153
fncs = [(x) -> cons_oop (x)[i] for i in 1 : num_cons]
157
- extras_cons_hess = prepare_hessian .(fncs, Ref (adtype), Ref (x))
154
+ extras_cons_hess = prepare_hessian .(fncs, Ref (DifferentiationInterface . SecondOrder ( adtype) ), Ref (x))
158
155
159
156
function cons_h (H, θ)
160
157
for i in 1 : num_cons
0 commit comments