Skip to content

Adds an abstract type to NamedArrayPartition #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: arraypart_zero
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition, NamedArrayPartition
export ArrayPartition, NamedArrayPartition, AbstractNamedArrayPartition

end # module
4 changes: 3 additions & 1 deletion src/array_partition.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
abstract type AbstractArrayPartition{T} <: AbstractVector{T} end

"""
```julia
ArrayPartition(x::AbstractArray...)
Expand All @@ -23,7 +25,7 @@ A = ArrayPartition(y, z)

we would have `A.x[1]==y` and `A.x[2]==z`. Broadcasting like `f.(A)` is efficient.
"""
struct ArrayPartition{T, S <: Tuple} <: AbstractVector{T}
struct ArrayPartition{T, S <: Tuple} <: AbstractArrayPartition{T}
x::S
end

Expand Down
148 changes: 92 additions & 56 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
AbstractNamedArrayPartition{T, A, NT}

An abstract type above that of `NamedArrayPartition` that can be used to subtype a
new and seperately named 'NamedArrayPartition'-like structure. This can be done
by defining your new type as:

```julia
struct foo{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT}
array_partition::A
names_to_indices::NT
end
```

where `foo` is your custom name and then all funcitonalities of NamedArrayPartitions will be inherited.
"""
abstract type AbstractNamedArrayPartition{T, A, NT} <: AbstractArrayPartition{T} end

"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)
Expand All @@ -6,137 +24,155 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
must have the same element type.
"""
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractVector{T}
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT}
array_partition::A
names_to_indices::NT
end
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
function NamedArrayPartition(x::NamedTuple)
(::Type{T})(; kwargs...) where {T<:AbstractNamedArrayPartition} = T(NamedTuple(kwargs))
function (::Type{T})(x::NamedTuple) where {T<:AbstractNamedArrayPartition}
names_to_indices = NamedTuple(Pair(symbol, index)
for (index, symbol) in enumerate(keys(x)))

# enforce homogeneity of eltypes
@assert all(eltype.(values(x)) .== eltype(first(x)))
T = eltype(first(x))
R = eltype(first(x))
S = typeof(values(x))
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
return T(ArrayPartition{R, S}(values(x)), names_to_indices)
end

function named_partition_constructor(X::T) where {T<:AbstractNamedArrayPartition}
getfield(parentmodule(T), nameof(T))
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
ArrayPartition(x::AbstractNamedArrayPartition) = getfield(x, :array_partition)

function Base.similar(A::NamedArrayPartition)
NamedArrayPartition(
# With new type structure this function does the same as Base.similar(x::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one commented?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I couldn't manage in either the old or the new method to call the less specific function as providing an NamedArrayPartition will have the type structure X{N, A, NT} and therefore a later version would have instead been called. In the original code the two functions are:

function Base.similar(A::NamedArrayPartition)
    NamedArrayPartition(
        similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
end

function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT}
    NamedArrayPartition{T, S, NT}(
        similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

But in all cases I tested the second always overwrites the first as any NamedArrayPartition has the structure of the second

#= function Base.similar(A::T) where {T<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
end
end =#

# return ArrayPartition when possible, otherwise next best thing of the correct size
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
NamedArrayPartition(
function Base.similar(A::T, dims::NTuple{N, Int}) where {T<:AbstractNamedArrayPartition, N}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices))
end

# similar array partition of common type
@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T}
NamedArrayPartition(
@inline function Base.similar(A::S, ::Type{T}) where {S<:AbstractNamedArrayPartition, T}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices))
end

# return ArrayPartition when possible, otherwise next best thing of the correct size
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
NamedArrayPartition(
function Base.similar(A::S, ::Type{T}, dims::NTuple{N, Int}) where {T, N, S<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices))
end

# similar array partition with different types
function Base.similar(
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
NamedArrayPartition(
A::U, ::Type{T}, ::Type{S}, R::DataType...) where {T, S, U<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
end

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
Base.Array(x::AbstractNamedArrayPartition) = Array(ArrayPartition(x))

function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN}
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
function Base.zero(x::R) where {R <: AbstractNamedArrayPartition}
R(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
end
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
Base.zero(A::AbstractNamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors

Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
function Base.getproperty(x::NamedArrayPartition, s::Symbol)
Base.propertynames(x::AbstractNamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
function Base.getproperty(x::AbstractNamedArrayPartition, s::Symbol)
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
end

# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
@inline function Base.setproperty!(x::AbstractNamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v
end

# print out NamedArrayPartition as a NamedTuple
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
function Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition)
Base.summary(x::AbstractNamedArrayPartition) = string(typeof(x), " with arrays:")
function Base.show(io::IO, m::MIME"text/plain", x::AbstractNamedArrayPartition)
show(
io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))
end

Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
Base.size(x::AbstractNamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::AbstractNamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::AbstractNamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)

Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
function Base.map(f, x::NamedArrayPartition)
NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
Base.setindex!(x::AbstractNamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
function Base.map(f, x::T) where {T<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(x)
Tconstr(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
end
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))
Base.mapreduce(f, op, x::AbstractNamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::AbstractNamedArrayPartition) = filter(f, ArrayPartition(x))

function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT}
NamedArrayPartition{T, S, NT}(
similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end
function Base.similar(x::AbstractNamedArrayPartition{T, A, NT}) where {T, A, NT}
# Safely extract the concrete type parameters

Tconstr = named_partition_constructor(x)
return Tconstr{T, A, NT}(
similar(getfield(x, :array_partition)),
getfield(x, :names_to_indices)
)
end
# broadcasting
function Base.BroadcastStyle(::Type{<:NamedArrayPartition})
Broadcast.ArrayStyle{NamedArrayPartition}()
function Base.BroadcastStyle(::Type{T}) where{T<:AbstractNamedArrayPartition}
Broadcast.ArrayStyle{T}()
end
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
::Type{ElType}) where {ElType}
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}},
::Type{ElType}) where {ElType, T<:AbstractNamedArrayPartition}
x = find_NamedArrayPartition(bc)
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
Tconstr = named_partition_constructor(x)
return Tconstr(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
function Base.BroadcastStyle(
::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
::Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
Broadcast.DefaultArrayStyle{1}()
end

# hook into ArrayPartition broadcasting routines
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = Broadcast.Broadcasted(
@inline RecursiveArrayTools.npartitions(x::AbstractNamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}}, i) = Broadcast.Broadcasted(
bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)
@inline RecursiveArrayTools.unpack(x::AbstractNamedArrayPartition, i) = unpack(ArrayPartition(x), i)

function Base.copy(A::NamedArrayPartition{T, S, NT}) where {T, S, NT}
NamedArrayPartition{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
function Base.copy(A::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
Tconstr = named_partition_constructor(A)
Tconstr{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
end

@inline NamedArrayPartition(f::F, N, names_to_indices) where {F <: Function} = NamedArrayPartition(
@inline (::Type{T})(f::F, N, names_to_indices) where {F <: Function, T<:AbstractNamedArrayPartition} = T(
ArrayPartition(ntuple(f, Val(N))), names_to_indices)

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}}) where {T<:AbstractNamedArrayPartition}
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))
end
x = find_NamedArrayPartition(bc)
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
Tconstr = named_partition_constructor(x)
Tconstr(f, N, getfield(x, :names_to_indices))
end

@inline function Base.copyto!(dest::NamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
@inline function Base.copyto!(dest::AbstractNamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}})
N = npartitions(dest, bc)
@inline function f(i)
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
Expand All @@ -146,7 +182,7 @@ end
end

#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
function ArrayInterface.zeromatrix(A::AbstractNamedArrayPartition)
B = ArrayPartition(A)
x = reduce(vcat,vec.(B.x))
x .* x' .* false
Expand All @@ -159,5 +195,5 @@ function find_NamedArrayPartition(args::Tuple)
end
find_NamedArrayPartition(x) = x
find_NamedArrayPartition(::Tuple{}) = nothing
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
find_NamedArrayPartition(x::AbstractNamedArrayPartition, rest) = x
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)
2 changes: 1 addition & 1 deletion test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test
using RecursiveArrayTools, Test, ArrayInterface

@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
Expand Down
Loading