Skip to content

Commit 9d7500c

Browse files
committed
Fix indexing bugs in GPU kernels
1 parent 74094e3 commit 9d7500c

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

ext/DynamicExpressionsCUDAExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Base.similar(x::FakeCuArray, dims::Integer...) = FakeCuArray(similar(x.a, dims..
1515
Base.getindex(x::FakeCuArray, i::Int...) = getindex(x.a, i...)
1616
Base.setindex!(x::FakeCuArray, v, i::Int...) = setindex!(x.a, v, i...)
1717
Base.size(x::FakeCuArray) = size(x.a)
18+
Base.Array(x::FakeCuArray) = Array(x.a)
1819

1920
const MaybeCuArray{T,N} = Union{CuArray{T,2},FakeCuArray{T,N}}
2021

@@ -41,13 +42,16 @@ function eval_tree_array(
4142
num_launches = maximum(execution_order)
4243
num_elem = size(gcX, 2)
4344

44-
## Floating point arrays:
45+
## The following array is our "workspace" for
46+
## the GPU kernel, with size equal to the number of rows
47+
## in the input data by the number of nodes in the tree.
48+
## It has one extra row to store the constant values.
4549
gworkspace = if gpu_workspace === nothing
46-
similar(gcX, num_elem, num_nodes + 1)
50+
similar(gcX, num_elem + 1, num_nodes)
4751
else
4852
gpu_workspace
4953
end
50-
gval = @view gworkspace[:, end]
54+
gval = @view gworkspace[end, :]
5155
copyto!(gval, val)
5256

5357
## Index arrays (much faster to have `@view` here)
@@ -79,7 +83,7 @@ function eval_tree_array(
7983
)
8084
#! format: on
8185

82-
out = ntuple(i -> @view(gworkspace[:, roots[i]]), Val(M + 1))
86+
out = ntuple(i -> @view(gworkspace[begin:end-1, roots[i]]), Val(M + 1))
8387
is_good = ntuple(
8488
i -> true, # Up to user to find NaNs
8589
Val(M + 1),

src/EvaluateEquation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ end
100100
function eval_tree_array(
101101
trees::Tuple{N,Vararg{N,M}}, cX::AbstractMatrix{T}, operators::OperatorEnum; kws...
102102
) where {T<:Number,N<:AbstractExpressionNode{T},M}
103-
outs = ntuple(i -> eval_tree_array(trees[i], cX, operators; kws...)[1], Val(M + 1))
103+
outs = ntuple(i -> eval_tree_array(trees[i], cX, operators; kws...), Val(M + 1))
104104
return ntuple(i -> first(outs[i]), Val(M + 1)), ntuple(i -> last(outs[i]), Val(M + 1))
105105
end
106106

0 commit comments

Comments
 (0)