Skip to content

Commit 08df9a3

Browse files
committed
Use vectors for vector input; tuple for tuple
1 parent 7ceacc0 commit 08df9a3

File tree

4 files changed

+40
-21
lines changed

4 files changed

+40
-21
lines changed

ext/DynamicExpressionsCUDAExt.jl

+5-8
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ function eval_tree_array(
2929
end
3030

3131
function eval_tree_array(
32-
trees::Tuple{N,Vararg{N,M}},
32+
trees::Union{NTuple{M,N} where M,AbstractVector{N}},
3333
gcX::MaybeCuArray{T,2},
3434
operators::OperatorEnum;
3535
buffer=nothing,
3636
gpu_workspace=nothing,
3737
gpu_buffer=nothing,
3838
kws...,
39-
) where {T<:Number,N<:AbstractExpressionNode{T},M}
40-
(; val, execution_order, roots, buffer, num_nodes) = as_array(Int32, trees...; buffer)
39+
) where {T<:Number,N<:AbstractExpressionNode{T}}
40+
(; val, execution_order, roots, buffer, num_nodes) = as_array(Int32, trees; buffer)
4141
num_launches = maximum(execution_order)
4242
num_elem = size(gcX, 2)
4343

@@ -82,11 +82,8 @@ function eval_tree_array(
8282
)
8383
#! format: on
8484

85-
out = ntuple(i -> @view(gworkspace[begin:(end - 1), roots[i]]), Val(M + 1))
86-
is_good = ntuple(
87-
i -> true, # Up to user to find NaNs
88-
Val(M + 1),
89-
)
85+
out = (r -> @view(gworkspace[begin:(end - 1), r])).(roots)
86+
is_good = (_ -> true).(trees)
9087

9188
return (out, is_good)
9289
end

src/AsArray.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@ using ..EquationModule: AbstractExpressionNode, tree_mapreduce, count_nodes
44

55
function as_array(
66
::Type{I},
7-
tree::N,
8-
additional_trees::Vararg{N,M};
7+
trees::Union{NTuple{M,N} where M,AbstractVector{N}};
98
buffer::Union{AbstractArray,Nothing}=nothing,
10-
) where {T,N<:AbstractExpressionNode{T},I,M}
11-
trees = (tree, additional_trees...)
9+
) where {T,N<:AbstractExpressionNode{T},I}
1210
each_num_nodes = (t -> count_nodes(t; break_sharing=Val(true))).(trees)
1311
num_nodes = sum(each_num_nodes)
1412

15-
roots = cumsum(tuple(one(I), each_num_nodes[1:(end - 1)]...))
13+
# Want `roots` to be tuple if `trees` is tuple and similar for vector
14+
roots = cumsum(
15+
if each_num_nodes isa Tuple
16+
tuple(one(I), each_num_nodes[1:(end - 1)]...)
17+
else
18+
vcat(one(I), each_num_nodes[1:(end - 1)])
19+
end,
20+
)
1621

1722
val = Array{T}(undef, num_nodes)
1823

src/EvaluateEquation.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,13 @@ function eval_tree_array(
9898
return eval_tree_array(tree, cX, operators; turbo, bumper)
9999
end
100100
function eval_tree_array(
101-
trees::Tuple{N,Vararg{N,M}}, cX::AbstractMatrix{T}, operators::OperatorEnum; kws...
102-
) where {T<:Number,N<:AbstractExpressionNode{T},M}
103-
outs = ntuple(i -> eval_tree_array(trees[i], cX, operators; kws...), Val(M + 1))
104-
return ntuple(i -> first(outs[i]), Val(M + 1)), ntuple(i -> last(outs[i]), Val(M + 1))
101+
trees::Union{NTuple{M,N} where M,AbstractArray{N}},
102+
cX::AbstractMatrix{T},
103+
operators::OperatorEnum;
104+
kws...,
105+
) where {T<:Number,N<:AbstractExpressionNode{T}}
106+
outs = (t -> eval_tree_array(t, cX, operators; kws...)).(trees)
107+
return first.(outs), last.(outs)
105108
end
106109

107110
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)

test/test_cuda.jl

+18-4
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,41 @@ let
2121

2222
nrow = rand(10:30)
2323
nnodes = rand(10:25, ntrees)
24+
use_tuple = rand(Bool)
2425

2526
buffer = rand(Bool) ? ones(Int32, 8, sum(nnodes)) : nothing
2627
gpu_buffer = rand(Bool) ? FakeCuArray(ones(Int32, 8, sum(nnodes))) : nothing
2728
gpu_workspace = rand(Bool) ? FakeCuArray(ones(T, nrow + 1, sum(nnodes))) : nothing
2829

2930
trees = ntuple(i -> gen_random_tree_fixed_size(nnodes[i], operators, 3, T), ntrees)
31+
trees = use_tuple ? trees : collect(trees)
3032
X = randn(T, 3, nrow)
3133
if ntrees > 1
32-
y, completed = eval_tree_array(trees, X, operators)
33-
gpu_y, gpu_completed = eval_tree_array(
34+
y, completed = @inferred eval_tree_array(trees, X, operators)
35+
gpu_y, gpu_completed = @inferred eval_tree_array(
3436
trees, FakeCuArray(X), operators; buffer, gpu_workspace, gpu_buffer
3537
)
3638

39+
# Should give same result either way
3740
for i in eachindex(completed, gpu_completed)
3841
if completed[i]
3942
@test y[i] ≈ gpu_y[i]
4043
end
4144
end
45+
46+
# Should return same type as input
47+
if use_tuple
48+
@test y isa Tuple
49+
@test gpu_y isa Tuple
50+
else
51+
@test y isa Vector
52+
@test gpu_y isa Vector
53+
end
4254
else
43-
y, completed = eval_tree_array(only(trees), X, operators)
44-
gpu_y, gpu_completed = eval_tree_array(only(trees), FakeCuArray(X), operators)
55+
y, completed = @inferred eval_tree_array(only(trees), X, operators)
56+
gpu_y, gpu_completed = @inferred eval_tree_array(
57+
only(trees), FakeCuArray(X), operators
58+
)
4559
if completed
4660
@test y ≈ gpu_y
4761
end

0 commit comments

Comments
 (0)