Commit 9f49619 (parent bfa9148)

hack: try implementing cooperative group

1 file changed: ext/DynamicExpressionsCUDAExt.jl (+62, -54)
@@ -1,7 +1,7 @@
 module DynamicExpressionsCUDAExt

 # TODO: Switch to KernelAbstractions.jl (once they hit v1.0)
-using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx
+using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx, CG
 using DynamicExpressions: OperatorEnum, AbstractExpressionNode
 using DynamicExpressions.EvaluateEquationModule: get_nbin, get_nuna
 using DynamicExpressions.AsArrayModule: as_array
@@ -78,7 +78,7 @@ function eval_tree_array(
     gidx_r = @view gbuffer[7, :]
     gconstant = @view gbuffer[8, :]

-    num_threads = 256
+    num_threads = 1024
     num_blocks = nextpow(2, ceil(Int, num_elem * num_nodes / num_threads))

     #! format: off
@@ -113,25 +113,23 @@ function _launch_gpu_kernel!(
     (nuna > 10 || nbin > 10) &&
         error("Too many operators. Kernels are only compiled up to 10.")
     gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
-    for launch in one(I):I(num_launches)
-        #! format: off
-        if buffer isa CuArray
-            @cuda threads=num_threads blocks=num_blocks gpu_kernel!(
+    #! format: off
+    if buffer isa CuArray
+        @cuda cooperative=true threads=num_threads blocks=num_blocks gpu_kernel!(
+            buffer,
+            num_launches, num_elem, num_nodes, execution_order,
+            cX, idx_self, idx_l, idx_r,
+            degree, constant, val, feature, op
+        )
+    else
+        Threads.@threads for i in 1:(num_threads * num_blocks * num_launches)
+            gpu_kernel!(
                 buffer,
-                launch, num_elem, num_nodes, execution_order,
+                num_launches, num_elem, num_nodes, execution_order,
                 cX, idx_self, idx_l, idx_r,
-                degree, constant, val, feature, op
+                degree, constant, val, feature, op,
+                i
             )
-        else
-            Threads.@threads for i in 1:(num_threads * num_blocks)
-                gpu_kernel!(
-                    buffer,
-                    launch, num_elem, num_nodes, execution_order,
-                    cX, idx_self, idx_l, idx_r,
-                    degree, constant, val, feature, op,
-                    i
-                )
-            end
         end
         #! format: on
     end
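The launch-site change above is the crux of the commit: `cooperative=true` allows threads in different blocks to synchronize through `CG`, but such a launch only succeeds when every block of the grid can be resident on the device simultaneously. As a hedged sketch (not part of this commit; the helper name is hypothetical, and the occupancy/attribute queries are assumptions based on CUDA.jl's documented driver wrappers), one might size such a launch like this:

using CUDA

# Hypothetical helper: pick a grid size that a cooperative launch accepts.
# Cooperative launches fail if more blocks are requested than can be
# co-resident, so we take the limit reported by the occupancy API.
function cooperative_launch_size(kernel)
    # Assumed attribute name, mirroring the driver's
    # CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH:
    @assert CUDA.attribute(device(), CUDA.DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH) == 1
    config = CUDA.launch_configuration(kernel.fun)
    return (threads=config.threads, blocks=config.blocks)
end

By contrast, the hunk above keeps the existing `nextpow`-based `num_blocks`, so a cooperative launch could still exceed the co-residency limit on small devices.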
@@ -151,7 +149,7 @@ for nuna in 0:10, nbin in 0:10
         # Storage:
         buffer,
         # Thread info:
-        launch::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
+        num_launches::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
         # Input data and tree
         cX::AbstractArray, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
         degree::AbstractArray, constant::AbstractArray, val::AbstractArray, feature::AbstractArray, op::AbstractArray,
@@ -163,46 +161,56 @@ for nuna in 0:10, nbin in 0:10
         if i > num_elem * num_nodes
             return nothing
         end
-
+        #
         node = (i - 1) % num_nodes + 1
         elem = (i - node) ÷ num_nodes + 1
+        grid_group = CG.this_grid()

-        if execution_order[node] != launch
-            return nothing
-        end
+        for launch in 1:num_launches
+            if launch > 1
+                # TODO: Investigate whether synchronizing within
+                # group instead of whole grid is better.
+                CG.sync(grid_group)
+            end
+            if execution_order[node] != launch
+                continue
+            end

-        cur_degree = degree[node]
-        cur_idx = idx_self[node]
-        if cur_degree == 0
-            if constant[node] == 1
-                cur_val = val[node]
-                buffer[elem, cur_idx] = cur_val
+            cur_degree = degree[node]
+            cur_idx = idx_self[node]
+            if cur_degree == 0
+                if constant[node] == 1
+                    cur_val = val[node]
+                    buffer[elem, cur_idx] = cur_val
+                else
+                    cur_feature = feature[node]
+                    buffer[elem, cur_idx] = cX[cur_feature, elem]
+                end
             else
-                cur_feature = feature[node]
-                buffer[elem, cur_idx] = cX[cur_feature, elem]
-            end
-        else
-            if cur_degree == 1 && $nuna > 0
-                cur_op = op[node]
-                l_idx = idx_l[node]
-                Base.Cartesian.@nif(
-                    $nuna,
-                    i -> i == cur_op,
-                    i -> let op = operators.unaops[i]
-                        buffer[elem, cur_idx] = op(buffer[elem, l_idx])
-                    end
-                )
-            elseif $nbin > 0 # Note this check is to avoid type inference issues when binops is empty
-                cur_op = op[node]
-                l_idx = idx_l[node]
-                r_idx = idx_r[node]
-                Base.Cartesian.@nif(
-                    $nbin,
-                    i -> i == cur_op,
-                    i -> let op = operators.binops[i]
-                        buffer[elem, cur_idx] = op(buffer[elem, l_idx], buffer[elem, r_idx])
-                    end
-                )
+                if cur_degree == 1 && $nuna > 0
+                    cur_op = op[node]
+                    l_idx = idx_l[node]
+                    Base.Cartesian.@nif(
+                        $nuna,
+                        i -> i == cur_op,
+                        i -> let op = operators.unaops[i]
+                            buffer[elem, cur_idx] = op(buffer[elem, l_idx])
+                        end
+                    )
+                elseif $nbin > 0 # Note this check is to avoid type inference issues when binops is empty
+                    cur_op = op[node]
+                    l_idx = idx_l[node]
+                    r_idx = idx_r[node]
+                    Base.Cartesian.@nif(
+                        $nbin,
+                        i -> i == cur_op,
+                        i -> let op = operators.binops[i]
+                            buffer[elem, cur_idx] = op(
+                                buffer[elem, l_idx], buffer[elem, r_idx]
+                            )
+                        end
+                    )
+                end
             end
         end
         return nothing
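For context, here is a minimal self-contained sketch of the pattern the new kernel body follows: a single kernel iterates over dependent phases internally, with `CG.sync` acting as a grid-wide barrier between phases, replacing the old one-`@cuda`-call-per-launch scheme. The kernel, data, and launch sizes below are illustrative assumptions, not code from the package:

using CUDA
using CUDA: CG

# Illustrative kernel: each phase depends on the previous phase having
# completed across the *entire* grid, not just the current block.
function phased_kernel!(out, nphases)
    grid = CG.this_grid()  # handle to the whole cooperative grid
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    for phase in 1:nphases
        if i <= length(out)
            out[i] += phase  # phase-local work
        end
        CG.sync(grid)  # grid-wide barrier; every thread must reach this
    end
    return nothing
end

out = CUDA.zeros(Float32, 4096)
# Compile first, then launch cooperatively with a modest block count so
# that all blocks can plausibly be resident at once.
kernel = @cuda launch=false phased_kernel!(out, 3)
threads = 256
blocks = cld(length(out), threads)
kernel(out, 3; threads, blocks, cooperative=true)

Note that a grid-wide sync must be reached by every thread in the grid, which is why the rewritten kernel body uses `continue` instead of the old early `return nothing` when a node is not scheduled for the current launch.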
