Skip to content

Implement value_gradient_and_hessian #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ hvp!
prepare_hessian
hessian
hessian!
value_gradient_and_hessian
value_gradient_and_hessian!
```

## Utilities
Expand Down
18 changes: 9 additions & 9 deletions DifferentiationInterface/docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ These operators are computed using the input `x` and a "seed" `v`, which lives e

Several variants of each operator are defined.

| out-of-place | in-place | out-of-place + primal | in-place + primal |
| :-------------------------- | :--------------------------- | :----------------------------------------------- | :----------------------------------------------- |
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
| out-of-place | in-place | out-of-place + primal | in-place + primal |
| :-------------------------- | :--------------------------- | :----------------------------------------------- | :------------------------------------------------ |
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
| [`second_derivative`](@ref) | [`second_derivative!`](@ref) | [`value_derivative_and_second_derivative`](@ref) | [`value_derivative_and_second_derivative!`](@ref) |
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA |
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
| [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) |
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |

## Mutation and signatures

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ using DifferentiationInterface:
JacobianExtras,
PullbackExtras,
PushforwardExtras,
SecondDerivativeExtras
SecondDerivativeExtras,
maybe_dense_ad
using FastDifferentiation:
derivative,
hessian,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
## Pushforward

struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1!} <: PushforwardExtras
y_prototype::Y
jvp_exe::E1
jvp_exe!::E2
jvp_exe!::E1!
end

function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx)
Expand Down Expand Up @@ -70,9 +70,9 @@ end

## Pullback

struct FastDifferentiationOneArgPullbackExtras{E1,E2} <: PullbackExtras
struct FastDifferentiationOneArgPullbackExtras{E1,E1!} <: PullbackExtras
vjp_exe::E1
vjp_exe!::E2
vjp_exe!::E1!
end

function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, dy)
Expand Down Expand Up @@ -133,10 +133,10 @@ end

## Derivative

struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras
struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1!} <: DerivativeExtras
y_prototype::Y
der_exe::E1
der_exe!::E2
der_exe!::E1!
end

function DI.prepare_derivative(f, ::AutoFastDifferentiation, x)
Expand Down Expand Up @@ -190,13 +190,12 @@ end

## Gradient

struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
struct FastDifferentiationOneArgGradientExtras{E1,E1!} <: GradientExtras
jac_exe::E1
jac_exe!::E2
jac_exe!::E1!
end

function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x)
y_prototype = f(x)
x_var = make_variables(:x, size(x)...)
y_var = f(x_var)

Expand Down Expand Up @@ -241,10 +240,10 @@ end

## Jacobian

struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2} <: JacobianExtras
struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1!} <: JacobianExtras
y_prototype::Y
jac_exe::E1
jac_exe!::E2
jac_exe!::E1!
end

function DI.prepare_jacobian(
Expand Down Expand Up @@ -307,34 +306,29 @@ end

## Second derivative

struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1!,E2,E2!} <:
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D,E2,E2!} <:
SecondDerivativeExtras
y_prototype::Y
der_exe::E1
der_exe!::E1!
derivative_extras::D
der2_exe::E2
der2_exe!::E2!
end

function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
function DI.prepare_second_derivative(f, backend::AutoFastDifferentiation, x)
y_prototype = f(x)
x_var = only(make_variables(:x))
y_var = f(x_var)

x_vec_var = monovec(x_var)
y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var)

der_vec_var = derivative(y_vec_var, x_var)
der2_vec_var = derivative(y_vec_var, x_var, x_var)

der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
der_exe! = make_function(der_vec_var, x_vec_var; in_place=true)

der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true)

derivative_extras = DI.prepare_derivative(f, backend, x)
return FastDifferentiationAllocatingSecondDerivativeExtras(
y_prototype, der_exe, der_exe!, der2_exe, der2_exe!
y_prototype, derivative_extras, der2_exe, der2_exe!
)
end

Expand Down Expand Up @@ -364,20 +358,13 @@ end

function DI.value_derivative_and_second_derivative(
f,
::AutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
)
y = f(x)
if extras.y_prototype isa Number
der = only(extras.der_exe(monovec(x)))
der2 = only(extras.der2_exe(monovec(x)))
return y, der, der2
else
der = reshape(extras.der_exe(monovec(x)), size(extras.y_prototype))
der2 = reshape(extras.der2_exe(monovec(x)), size(extras.y_prototype))
return y, der, der2
end
y, der = DI.value_and_derivative(f, backend, x, extras.derivative_extras)
der2 = DI.second_derivative(f, backend, x, extras)
return y, der, der2
end

function DI.value_derivative_and_second_derivative!(
Expand All @@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!(
x,
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
)
y = f(x)
extras.der_exe!(vec(der), monovec(x))
extras.der2_exe!(vec(der2), monovec(x))
y, _ = DI.value_and_derivative!(f, der, backend, x, extras.derivative_extras)
DI.second_derivative!(f, der2, backend, x, extras)
return y, der, der2
end

## HVP

struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras
hvp_exe::E1
hvp_exe!::E2
struct FastDifferentiationHVPExtras{E2,E2!} <: HVPExtras
hvp_exe::E2
hvp_exe!::E2!
end

function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v)
Expand Down Expand Up @@ -428,24 +414,30 @@ end

## Hessian

struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras
hess_exe::E1
hess_exe!::E2
struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras
gradient_extras::G
hess_exe::E2
hess_exe!::E2!
end

function DI.prepare_hessian(
f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x
)
x_vec_var = make_variables(:x, size(x)...)
y_vec_var = f(x_vec_var)
x_var = make_variables(:x, size(x)...)
y_var = f(x_var)

x_vec_var = vec(x_var)

hess_var = if backend isa AutoSparse
sparse_hessian(y_vec_var, vec(x_vec_var))
sparse_hessian(y_var, x_vec_var)
else
hessian(y_vec_var, vec(x_vec_var))
hessian(y_var, x_vec_var)
end
hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false)
hess_exe! = make_function(hess_var, vec(x_vec_var); in_place=true)
return FastDifferentiationHessianExtras(hess_exe, hess_exe!)
hess_exe = make_function(hess_var, x_vec_var; in_place=false)
hess_exe! = make_function(hess_var, x_vec_var; in_place=true)

gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x)
return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!)
end

function DI.hessian(
Expand All @@ -467,3 +459,29 @@ function DI.hessian!(
extras.hess_exe!(hess, vec(x))
return hess
end

# Return the tuple `(y, grad, hess)` for `f` at `x`.
# The value and gradient are obtained from `DI.value_and_gradient`, called with
# `maybe_dense_ad(backend)` and the gradient extras stored inside `extras`.
# The Hessian is obtained from `DI.hessian` with the unmodified `backend` and `extras`.
function DI.value_gradient_and_hessian(
    f,
    backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
    x,
    extras::FastDifferentiationHessianExtras,
)
    grad_backend = maybe_dense_ad(backend)
    grad_extras = extras.gradient_extras
    primal, g = DI.value_and_gradient(f, grad_backend, x, grad_extras)
    H = DI.hessian(f, backend, x, extras)
    return (primal, g, H)
end

# Variant taking caller-provided `grad` and `hess`, which are forwarded to the in-place
# operators and returned as-is in the result tuple `(y, grad, hess)`.
# The value is the first output of `DI.value_and_gradient!`, called with
# `maybe_dense_ad(backend)` and the gradient extras stored inside `extras`.
# `DI.hessian!` is then called with the unmodified `backend` and `extras`.
function DI.value_gradient_and_hessian!(
    f,
    grad,
    hess,
    backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
    x,
    extras::FastDifferentiationHessianExtras,
)
    grad_backend = maybe_dense_ad(backend)
    grad_extras = extras.gradient_extras
    primal, _ = DI.value_and_gradient!(f, grad, grad_backend, x, grad_extras)
    DI.hessian!(f, hess, backend, x, extras)
    return (primal, grad, hess)
end
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
## Pushforward

struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
struct FastDifferentiationTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras
jvp_exe::E1
jvp_exe!::E2
jvp_exe!::E1!
end

function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx)
Expand Down Expand Up @@ -80,9 +80,9 @@ end

## Pullback

struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras
struct FastDifferentiationTwoArgPullbackExtras{E1,E1!} <: PullbackExtras
vjp_exe::E1
vjp_exe!::E2
vjp_exe!::E1!
end

function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, dy)
Expand Down Expand Up @@ -156,9 +156,9 @@ end

## Derivative

struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
struct FastDifferentiationTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras
der_exe::E1
der_exe!::E2
der_exe!::E1!
end

function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x)
Expand Down Expand Up @@ -216,9 +216,9 @@ end

## Jacobian

struct FastDifferentiationTwoArgJacobianExtras{E1,E2} <: JacobianExtras
struct FastDifferentiationTwoArgJacobianExtras{E1,E1!} <: JacobianExtras
jac_exe::E1
jac_exe!::E2
jac_exe!::E1!
end

function DI.prepare_jacobian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,39 @@ end

## Hessian

struct FiniteDiffHessianExtras{C} <: HessianExtras
cache::C
struct FiniteDiffHessianExtras{C1,C2} <: HessianExtras
gradient_cache::C1
hessian_cache::C2
end

function DI.prepare_hessian(f, backend::AutoFiniteDiff, x)
cache = HessianCache(x, fdhtype(backend))
return FiniteDiffHessianExtras(cache)
y = f(x)
df = zero(y) .* x
gradient_cache = GradientCache(df, x, fdtype(backend))
hessian_cache = HessianCache(x, fdhtype(backend))
return FiniteDiffHessianExtras(gradient_cache, hessian_cache)
end

# cache cannot be reused because of https://github.com/JuliaDiff/FiniteDiff.jl/issues/185

function DI.hessian(f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
return finite_difference_hessian(f, x, extras.cache)
return finite_difference_hessian(f, x, extras.hessian_cache)
end

function DI.hessian!(f, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
return finite_difference_hessian!(hess, f, x, extras.cache)
return finite_difference_hessian!(hess, f, x, extras.hessian_cache)
end

# Return `(f(x), grad, hess)` using the two caches held by `extras`:
# `gradient_cache` for `finite_difference_gradient` and `hessian_cache` for
# `finite_difference_hessian`. The primal value is computed last by calling `f(x)`.
function DI.value_gradient_and_hessian(
    f,
    backend::AutoFiniteDiff,
    x,
    extras::FiniteDiffHessianExtras,
)
    g = finite_difference_gradient(f, x, extras.gradient_cache)
    H = finite_difference_hessian(f, x, extras.hessian_cache)
    y = f(x)
    return (y, g, H)
end

# Variant taking caller-provided `grad` and `hess`: they are passed as the first argument
# to `finite_difference_gradient!` and `finite_difference_hessian!` respectively, together
# with the matching cache from `extras`, and returned as-is alongside `f(x)`.
# The primal value is computed last by calling `f(x)`.
function DI.value_gradient_and_hessian!(
    f,
    grad,
    hess,
    backend::AutoFiniteDiff,
    x,
    extras::FiniteDiffHessianExtras,
)
    finite_difference_gradient!(grad, f, x, extras.gradient_cache)
    finite_difference_hessian!(hess, f, x, extras.hessian_cache)
    y = f(x)
    return (y, grad, hess)
end
Loading
Loading