Skip to content

[Breaking] Allow value and partials to have distinct types #463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.12"
version = "0.11.0"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down
26 changes: 13 additions & 13 deletions src/apiutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
@generated function dualize(::Type{T}, x::StaticArray) where T
N = length(x)
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N})
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N,eltype(x)})
return quote
chunk = Chunk{$N}()
$(Expr(:meta, :inline))
Expand Down Expand Up @@ -53,38 +53,38 @@ end
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x,
seed::Partials{N,P} = zero(Partials{N,P})) where {T,V,N,P}
for i in eachindex(duals)
duals[i] = Dual{T,V,N}(x[i], seed)
duals[i] = Dual{T,V,N,P}(x[i], seed)
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x,
seeds::NTuple{N,Partials{N,P}}) where {T,V,N,P}
for i in 1:N
duals[i] = Dual{T,V,N}(x[i], seeds[i])
duals[i] = Dual{T,V,N,P}(x[i], seeds[i])
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, index,
seed::Partials{N,P} = zero(Partials{N,P})) where {T,V,N,P}
offset = index - 1
for i in 1:N
j = i + offset
duals[j] = Dual{T,V,N}(x[j], seed)
duals[j] = Dual{T,V,N,P}(x[j], seed)
end
return duals
end

function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
function seed!(duals::AbstractArray{Dual{T,V,N,P}}, x, index,
seeds::NTuple{N,Partials{N,P}}, chunksize = N) where {T,V,N,P}
offset = index - 1
for i in 1:chunksize
j = i + offset
duals[j] = Dual{T,V,N}(x[j], seeds[i])
duals[j] = Dual{T,V,N,P}(x[j], seeds[i])
end
return duals
end
14 changes: 7 additions & 7 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function DerivativeConfig(f::F,
y::AbstractArray{Y},
x::X,
tag::T = Tag(f, X)) where {F,X<:Real,Y<:Real,T}
duals = similar(y, Dual{T,Y,1})
duals = similar(y, Dual{T,Y,1,Y})
return DerivativeConfig{T,typeof(duals)}(duals)
end

Expand Down Expand Up @@ -119,7 +119,7 @@ function GradientConfig(f::F,
::Chunk{N} = Chunk(x),
::T = Tag(f, V)) where {F,V,N,T}
seeds = construct_seeds(Partials{N,V})
duals = similar(x, Dual{T,V,N})
duals = similar(x, Dual{T,V,N,V})
return GradientConfig{T,V,N,typeof(duals)}(seeds, duals)
end

Expand Down Expand Up @@ -156,7 +156,7 @@ function JacobianConfig(f::F,
::Chunk{N} = Chunk(x),
::T = Tag(f, V)) where {F,V,N,T}
seeds = construct_seeds(Partials{N,V})
duals = similar(x, Dual{T,V,N})
duals = similar(x, Dual{T,V,N,V})
return JacobianConfig{T,V,N,typeof(duals)}(seeds, duals)
end

Expand All @@ -182,8 +182,8 @@ function JacobianConfig(f::F,
::Chunk{N} = Chunk(x),
::T = Tag(f, X)) where {F,Y,X,N,T}
seeds = construct_seeds(Partials{N,X})
yduals = similar(y, Dual{T,Y,N})
xduals = similar(x, Dual{T,X,N})
yduals = similar(y, Dual{T,Y,N,Y})
xduals = similar(x, Dual{T,X,N,X})
duals = (yduals, xduals)
return JacobianConfig{T,X,N,typeof(duals)}(seeds, duals)
end
Expand All @@ -197,7 +197,7 @@ Base.eltype(::Type{JacobianConfig{T,V,N,D}}) where {T,V,N,D} = Dual{T,V,N}

struct HessianConfig{T,V,N,DG,DJ} <: AbstractConfig{N}
jacobian_config::JacobianConfig{T,V,N,DJ}
gradient_config::GradientConfig{T,Dual{T,V,N},N,DG}
gradient_config::GradientConfig{T,Dual{T,V,N,V},N,DG}
end

"""
Expand Down Expand Up @@ -254,4 +254,4 @@ end

checktag(::HessianConfig{T},f,x) where {T} = checktag(T,f,x)
Base.eltype(::Type{HessianConfig{T,V,N,DG,DJ}}) where {T,V,N,DG,DJ} =
Dual{T,Dual{T,V,N},N}
Dual{T,Dual{T,V,N,V},N,Dual{T,V,N,V}}
8 changes: 4 additions & 4 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`.
"""
@inline function derivative(f::F, x::R) where {F,R<:Real}
T = typeof(Tag(f, R))
return extract_derivative(T, f(Dual{T}(x, one(x))))
return extract_derivative(T, f(Dual{T}(x, oneunit(x))))
end

"""
Expand All @@ -27,7 +27,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
CHK && checktag(T, f!, x)
ydual = cfg.duals
seed!(ydual, y)
f!(ydual, Dual{T}(x, one(x)))
f!(ydual, Dual{T}(x, oneunit(x)))
map!(value, y, ydual)
return extract_derivative(T, ydual)
end
Expand All @@ -43,7 +43,7 @@ This method assumes that `isa(f(x), Union{Real,AbstractArray})`.
@inline function derivative!(result::Union{AbstractArray,DiffResult},
f::F, x::R) where {F,R<:Real}
T = typeof(Tag(f, R))
ydual = f(Dual{T}(x, one(x)))
ydual = f(Dual{T}(x, oneunit(x)))
result = extract_value!(T, result, ydual)
result = extract_derivative!(T, result, ydual)
return result
Expand All @@ -63,7 +63,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
CHK && checktag(T, f!, x)
ydual = cfg.duals
seed!(ydual, y)
f!(ydual, Dual{T}(x, one(x)))
f!(ydual, Dual{T}(x, oneunit(x)))
result = extract_value!(T, result, y, ydual)
result = extract_derivative!(T, result, ydual)
return result
Expand Down
96 changes: 61 additions & 35 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ Dual. By default, only `<:Real` types are allowed.
can_dual(::Type{<:Real}) = true
can_dual(::Type) = false

struct Dual{T,V,N} <: Real
struct Dual{T,V,N,P} <: Real
value::V
partials::Partials{N,V}
function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N}
partials::Partials{N,P}
function Dual{T, V, N, P}(value::V, partials::Partials{N, P}) where {T, V, N, P}
can_dual(V) || throw_cannot_dual(V)
new{T, V, N}(value, partials)
can_dual(P) || throw_cannot_dual(P)
new{T, V, N, P}(value, partials)
end
end
@inline Dual{T,V,N}(value, partials::Partials{N,P}) where {T,V,N,P} = Dual{T,V,N,P}(value, partials)

##############
# Exceptions #
Expand Down Expand Up @@ -52,9 +54,10 @@ tag can be extracted, so it should be used in the _innermost_ function.
# Constructors #
################

@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N}(value, partials)

@inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A,B}
@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V<:Real} = Dual{T,V,N,V}(value, partials) # ambiguity resolution
@inline Dual{T}(value::V, partials::Partials{N,V}) where {T,N,V} = Dual{T,V,N,V}(value, partials)
@inline Dual{T}(value::V, partials::Partials{N,P}) where {T,N,V,P} = Dual{T,V,N,P}(value, partials)
@inline function Dual{T}(value::A, partials::Partials{N,B}) where {T,N,A<:Real,B<:Real}
C = promote_type(A, B)
return Dual{T}(convert(C, value), convert(Partials{N,C}, partials))
end
Expand All @@ -68,8 +71,9 @@ end
@inline Dual(args...) = Dual{Nothing}(args...)

# we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
@inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x
@inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x)
@inline Dual{T,V,N,P}(x::Dual{T,V,N,P}) where {T,V,N,P} = x
@inline Dual{T,V,N,P}(x) where {T,V,N,P} = convert(Dual{T,V,N,P}, x)
@inline Dual{T,V,N,P}(x::Number) where {T,V,N,P} = convert(Dual{T,V,N,P}, x)
@inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x)
@inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x)

Expand Down Expand Up @@ -109,15 +113,21 @@ end


@inline npartials(::Dual{T,V,N}) where {T,V,N} = N
@inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N
@inline npartials(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = N

@inline order(::Type{V}) where {V} = 0
@inline order(::Type{Dual{T,V,N}}) where {T,V,N} = 1 + order(V)
@inline order(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = 1 + order(V)

@inline valtype(::V) where {V} = V
@inline valtype(::Type{V}) where {V} = V
@inline valtype(::Dual{T,V,N}) where {T,V,N} = V
@inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V
@inline valtype(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = V

@inline partialtype(::V) where {V} = V
@inline partialtype(::Type{V}) where {V} = V
@inline partialtype(::Dual{T,V,N,P}) where {T,V,N,P} = P
@inline partialtype(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = P

@inline tagtype(::V) where {V} = Nothing
@inline tagtype(::Type{V}) where {V} = Nothing
Expand Down Expand Up @@ -282,10 +292,10 @@ Base.round(d::Dual) = round(value(d))
Base.hash(d::Dual) = hash(value(d))
Base.hash(d::Dual, hsh::UInt) = hash(value(d), hsh)

function Base.read(io::IO, ::Type{Dual{T,V,N}}) where {T,V,N}
function Base.read(io::IO, ::Type{Dual{T,V,N,P}}) where {T,V,N,P}
value = read(io, V)
partials = read(io, Partials{N,V})
return Dual{T,V,N}(value, partials)
partials = read(io, Partials{N,P})
return Dual{T,V,N,P}(value, partials)
end

function Base.write(io::IO, d::Dual)
Expand All @@ -294,18 +304,24 @@ function Base.write(io::IO, d::Dual)
end

@inline Base.zero(d::Dual) = zero(typeof(d))
@inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(zero(V), zero(Partials{N,V}))
@inline Base.zero(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(zero(V), zero(Partials{N,P}))
@inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = zero(Dual{T,V,N,V})

@inline Base.one(d::Dual) = one(typeof(d))
@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(one(V), zero(Partials{N,V}))
@inline Base.one(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(one(V), zero(Partials{N,P}))
@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = one(Dual{T,V,N,V})

@inline Random.rand(rng::AbstractRNG, d::Dual) = rand(rng, value(d))
@inline Random.rand(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(V), zero(Partials{N,V}))
@inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(rand(rng, V), zero(Partials{N,V}))
@inline Random.randn(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(V), zero(Partials{N,V}))
@inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randn(rng, V), zero(Partials{N,V}))
@inline Random.randexp(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(V), zero(Partials{N,V}))
@inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(randexp(rng, V), zero(Partials{N,V}))
@inline Random.rand(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(rand(V), zero(Partials{N,P}))
@inline Random.rand(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(rand(rng, V), zero(Partials{N,P}))
@inline Random.randn(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randn(V), zero(Partials{N,P}))
@inline Random.randn(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randn(rng, V), zero(Partials{N,P}))
@inline Random.randexp(::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randexp(V), zero(Partials{N,P}))
@inline Random.randexp(rng::AbstractRNG, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T}(randexp(rng, V), zero(Partials{N,P}))

@inline Base.zero(::Type{Partials{N,Dual{T,V,M}}}) where {N,T,V,M} = zero(Partials{N,Dual{T,V,M,V}})
@inline Base.one(::Type{Partials{N,Dual{T,V,M}}}) where {N,T,V,M} = one(Partials{N,Dual{T,V,M,V}})


# Predicates #
#------------#
Expand All @@ -331,35 +347,45 @@ end
# Promotion/Conversion #
########################

Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2}
Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1,P1}},
::Type{Dual{T2,V2,N2,P2}}) where {T1,V1,N1,P1,T2,V2,N2,P2}
# V1 and V2 might themselves be Dual types
if T2 ≺ T1
Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1}
Dual{T1,promote_type(V1,Dual{T2,V2,N2,P2}),N1,P1}
else
Dual{T2,promote_type(V2,Dual{T1,V1,N1}),N2}
Dual{T2,promote_type(V2,Dual{T1,V1,N1,P1}),N2,P2}
end
end

function Base.promote_rule(::Type{Dual{T,A,N,PA}},
::Type{Dual{T,B,N,PB}}) where {T,A,B,PA,PB,N}
return Dual{T,promote_type(A, B),N,promote_type(PA, PB)}
end
function Base.promote_rule(::Type{Dual{T,A,N}},
::Type{Dual{T,B,N}}) where {T,A,B,N}
return Dual{T,promote_type(A, B),N}
return Dual{T,promote_type(A, B),N,promote_type(A, B)}
end

for R in (Irrational, Real, BigFloat, Bool)
if isconcretetype(R) # issue #322
@eval begin
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N}
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N,P}}) where {T,V,N,P} = Dual{T,promote_type($R, V),N,promote_type($R, P)}
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N,promote_type($R, V)}
end
else
@eval begin
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N}
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N,P}}) where {R<:$R,T,V,N,P} = Dual{T,promote_type(R, V),N,promote_type(R, P)}
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N,promote_type(R, V)}
end
end
end

Base.convert(::Type{Partials{N,Dual{T,V,M}}}, partials::Partials) where {N,T,V,M} =
convert(Partials{N,Dual{T,V,M,V}}, partials)

Base.convert(::Type{Dual{T,V,N,P}}, d::Dual{T}) where {T,V,N,P} = Dual{T}(convert(V, value(d)), convert(Partials{N,P}, partials(d)))
Base.convert(::Type{Dual{T,V,N,P}}, x) where {T,V,N,P} = Dual{T}(convert(V, x), zero(Partials{N,P}))
Base.convert(::Type{Dual{T,V,N,P}}, x::Number) where {T,V,N,P} = Dual{T}(convert(V, x), zero(Partials{N,P}))
Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d)))
Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Expand Down Expand Up @@ -621,10 +647,10 @@ function Base.show(io::IO, d::Dual{T,V,N}) where {T,V,N}
print(io, ")")
end

function Base.typemin(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
ForwardDiff.Dual{T,V,N}(typemin(V))
function Base.typemin(::Type{ForwardDiff.Dual{T,V,N,P}}) where {T,V,N,P}
ForwardDiff.Dual{T,V,N,P}(typemin(V))
end

function Base.typemax(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
ForwardDiff.Dual{T,V,N}(typemax(V))
function Base.typemax(::Type{ForwardDiff.Dual{T,V,N,P}}) where {T,V,N,P}
ForwardDiff.Dual{T,V,N,P}(typemax(V))
end
12 changes: 6 additions & 6 deletions src/partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end
##############################

@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
ex = Expr(:tuple, [ifelse(i === j, :(one(V)), :(zero(V))) for j in 1:N]...)
ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...)
return :(Partials($(ex)))
end

Expand Down Expand Up @@ -92,18 +92,18 @@ end

if NANSAFE_MODE_ENABLED
@inline function Base.:*(partials::Partials, x::Real)
x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
x = ifelse(!isfinite(x) && iszero(partials), oneunit(x), x)
return Partials(scale_tuple(partials.values, x))
end

@inline function Base.:/(partials::Partials, x::Real)
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
x = ifelse(x == zero(x) && iszero(partials), oneunit(x), x)
return Partials(div_tuple_by_scalar(partials.values, x))
end

@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
x_a = ifelse(!isfinite(x_a) && iszero(a), oneunit(x_a), x_a)
x_b = ifelse(!isfinite(x_b) && iszero(b), oneunit(x_b), x_b)
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
end
else
Expand Down Expand Up @@ -184,7 +184,7 @@ end
@generated function one_tuple(::Type{NTuple{N,V}}) where {N,V}
ex = tupexpr(i -> :(z), N)
return quote
z = one(V)
z = oneunit(V)
return $ex
end
end
Expand Down
Loading