1
1
module DynamicExpressionsCUDAExt
2
2
3
3
# TODO : Switch to KernelAbstractions.jl (once they hit v1.0)
4
- using CUDA: @cuda , CuArray, blockDim, blockIdx, threadIdx
4
+ using CUDA: @cuda , CuArray, blockDim, blockIdx, threadIdx, CG
5
5
using DynamicExpressions: OperatorEnum, AbstractExpressionNode
6
6
using DynamicExpressions. EvaluateEquationModule: get_nbin, get_nuna
7
7
using DynamicExpressions. AsArrayModule: as_array
@@ -78,7 +78,7 @@ function eval_tree_array(
78
78
gidx_r = @view gbuffer[7 , :]
79
79
gconstant = @view gbuffer[8 , :]
80
80
81
- num_threads = 256
81
+ num_threads = 1024
82
82
num_blocks = nextpow (2 , ceil (Int, num_elem * num_nodes / num_threads))
83
83
84
84
# ! format: off
@@ -113,25 +113,23 @@ function _launch_gpu_kernel!(
113
113
(nuna > 10 || nbin > 10 ) &&
114
114
error (" Too many operators. Kernels are only compiled up to 10." )
115
115
gpu_kernel! = create_gpu_kernel (operators, Val (nuna), Val (nbin))
116
- for launch in one (I): I (num_launches)
117
- # ! format: off
118
- if buffer isa CuArray
119
- @cuda threads= num_threads blocks= num_blocks gpu_kernel! (
116
+ # ! format: off
117
+ if buffer isa CuArray
118
+ @cuda cooperative= true threads= num_threads blocks= num_blocks gpu_kernel! (
119
+ buffer,
120
+ num_launches, num_elem, num_nodes, execution_order,
121
+ cX, idx_self, idx_l, idx_r,
122
+ degree, constant, val, feature, op
123
+ )
124
+ else
125
+ Threads. @threads for i in 1 : (num_threads * num_blocks * num_launches)
126
+ gpu_kernel! (
120
127
buffer,
121
- launch , num_elem, num_nodes, execution_order,
128
+ num_launches , num_elem, num_nodes, execution_order,
122
129
cX, idx_self, idx_l, idx_r,
123
- degree, constant, val, feature, op
130
+ degree, constant, val, feature, op,
131
+ i
124
132
)
125
- else
126
- Threads. @threads for i in 1 : (num_threads * num_blocks)
127
- gpu_kernel! (
128
- buffer,
129
- launch, num_elem, num_nodes, execution_order,
130
- cX, idx_self, idx_l, idx_r,
131
- degree, constant, val, feature, op,
132
- i
133
- )
134
- end
135
133
end
136
134
# ! format: on
137
135
end
@@ -151,7 +149,7 @@ for nuna in 0:10, nbin in 0:10
151
149
# Storage:
152
150
buffer,
153
151
# Thread info:
154
- launch :: Integer , num_elem:: Integer , num_nodes:: Integer , execution_order:: AbstractArray ,
152
+ num_launches :: Integer , num_elem:: Integer , num_nodes:: Integer , execution_order:: AbstractArray ,
155
153
# Input data and tree
156
154
cX:: AbstractArray , idx_self:: AbstractArray , idx_l:: AbstractArray , idx_r:: AbstractArray ,
157
155
degree:: AbstractArray , constant:: AbstractArray , val:: AbstractArray , feature:: AbstractArray , op:: AbstractArray ,
@@ -163,46 +161,56 @@ for nuna in 0:10, nbin in 0:10
163
161
if i > num_elem * num_nodes
164
162
return nothing
165
163
end
166
-
164
+ #
167
165
node = (i - 1 ) % num_nodes + 1
168
166
elem = (i - node) ÷ num_nodes + 1
167
+ grid_group = CG. this_grid ()
169
168
170
- if execution_order[node] != launch
171
- return nothing
172
- end
169
+ for launch in 1 : num_launches
170
+ if launch > 1
171
+ # TODO : Investigate whether synchronizing within
172
+ # group instead of whole grid is better.
173
+ CG. sync (grid_group)
174
+ end
175
+ if execution_order[node] != launch
176
+ continue
177
+ end
173
178
174
- cur_degree = degree[node]
175
- cur_idx = idx_self[node]
176
- if cur_degree == 0
177
- if constant[node] == 1
178
- cur_val = val[node]
179
- buffer[elem, cur_idx] = cur_val
179
+ cur_degree = degree[node]
180
+ cur_idx = idx_self[node]
181
+ if cur_degree == 0
182
+ if constant[node] == 1
183
+ cur_val = val[node]
184
+ buffer[elem, cur_idx] = cur_val
185
+ else
186
+ cur_feature = feature[node]
187
+ buffer[elem, cur_idx] = cX[cur_feature, elem]
188
+ end
180
189
else
181
- cur_feature = feature[node]
182
- buffer[elem, cur_idx] = cX[cur_feature, elem]
183
- end
184
- else
185
- if cur_degree == 1 && $ nuna > 0
186
- cur_op = op[node]
187
- l_idx = idx_l[node]
188
- Base. Cartesian. @nif (
189
- $ nuna,
190
- i -> i == cur_op,
191
- i -> let op = operators. unaops[i]
192
- buffer[elem, cur_idx] = op (buffer[elem, l_idx])
193
- end
194
- )
195
- elseif $ nbin > 0 # Note this check is to avoid type inference issues when binops is empty
196
- cur_op = op[node]
197
- l_idx = idx_l[node]
198
- r_idx = idx_r[node]
199
- Base. Cartesian. @nif (
200
- $ nbin,
201
- i -> i == cur_op,
202
- i -> let op = operators. binops[i]
203
- buffer[elem, cur_idx] = op (buffer[elem, l_idx], buffer[elem, r_idx])
204
- end
205
- )
190
+ if cur_degree == 1 && $ nuna > 0
191
+ cur_op = op[node]
192
+ l_idx = idx_l[node]
193
+ Base. Cartesian. @nif (
194
+ $ nuna,
195
+ i -> i == cur_op,
196
+ i -> let op = operators. unaops[i]
197
+ buffer[elem, cur_idx] = op (buffer[elem, l_idx])
198
+ end
199
+ )
200
+ elseif $ nbin > 0 # Note this check is to avoid type inference issues when binops is empty
201
+ cur_op = op[node]
202
+ l_idx = idx_l[node]
203
+ r_idx = idx_r[node]
204
+ Base. Cartesian. @nif (
205
+ $ nbin,
206
+ i -> i == cur_op,
207
+ i -> let op = operators. binops[i]
208
+ buffer[elem, cur_idx] = op (
209
+ buffer[elem, l_idx], buffer[elem, r_idx]
210
+ )
211
+ end
212
+ )
213
+ end
206
214
end
207
215
end
208
216
return nothing
0 commit comments