
Commit c1e579a: Start switching to KernelAbstractions.jl

1 parent: 676ad86

4 files changed: +36 −58 lines

Project.toml (+5 −5)

```diff
@@ -14,15 +14,15 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
 [weakdeps]
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
 DynamicExpressionsBumperExt = "Bumper"
-DynamicExpressionsCUDAExt = "CUDA"
+DynamicExpressionsKernelAbstractionsExt = "KernelAbstractions"
 DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
 DynamicExpressionsOptimExt = "Optim"
 DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
@@ -31,7 +31,7 @@ DynamicExpressionsZygoteExt = "Zygote"
 [compat]
 Aqua = "0.7"
 Bumper = "0.6"
-CUDA = "4, 5"
+KernelAbstractions = "0.9"
 Compat = "3.37, 4"
 Enzyme = "^0.11.12"
 LoopVectorization = "0.12"
@@ -47,9 +47,9 @@ julia = "1.6"
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
@@ -61,4 +61,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "SafeTestsets", "Aqua", "Bumper", "CUDA", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
+test = ["Test", "SafeTestsets", "Aqua", "Bumper", "KernelAbstractions", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
```

ext/DynamicExpressionsCUDAExt.jl renamed to ext/DynamicExpressionsKernelAbstractionsExt.jl (+26 −52)

```diff
@@ -1,45 +1,31 @@
-module DynamicExpressionsCUDAExt
+module DynamicExpressionsKernelAbstractionsExt
 
-using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx
+using KernelAbstractions: @index, @kernel, @Const, get_backend
 using DynamicExpressions: OperatorEnum, AbstractExpressionNode
 using DynamicExpressions.EvaluateEquationModule: get_nbin, get_nuna
 using DynamicExpressions.AsArrayModule: as_array
 
-import DynamicExpressions.EvaluateEquationModule: eval_tree_array
+import DynamicExpressions.ExtensionInterfaceModule: gpu_eval_tree_array
 
-# array type for exclusively testing purposes
-struct FakeCuArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
-    a::A
-end
-Base.similar(x::FakeCuArray, dims::Integer...) = FakeCuArray(similar(x.a, dims...))
-Base.getindex(x::FakeCuArray, i::Int...) = getindex(x.a, i...)
-Base.setindex!(x::FakeCuArray, v, i::Int...) = setindex!(x.a, v, i...)
-Base.size(x::FakeCuArray) = size(x.a)
-
-const MaybeCuArray{T,N} = Union{CuArray{T,N},FakeCuArray{T,N}}
-
-to_device(a, ::CuArray) = CuArray(a)
-to_device(a, ::FakeCuArray) = FakeCuArray(a)
-
-function eval_tree_array(
-    tree::AbstractExpressionNode{T}, gcX::MaybeCuArray{T,2}, operators::OperatorEnum; kws...
+function gpu_eval_tree_array(
+    tree::AbstractExpressionNode{T}, gcX, operators::OperatorEnum; kws...
 ) where {T<:Number}
-    (outs, is_good) = eval_tree_array((tree,), gcX, operators; kws...)
+    (outs, is_good) = gpu_eval_tree_array((tree,), gcX, operators; kws...)
     return (only(outs), only(is_good))
 end
 
-function eval_tree_array(
+function gpu_eval_tree_array(
     trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}},
-    gcX::MaybeCuArray{T,2},
+    gcX,
     operators::OperatorEnum;
+    backend=get_backend(gcX),
     buffer=nothing,
     gpu_workspace=nothing,
     gpu_buffer=nothing,
     roots=nothing,
     num_nodes=nothing,
     num_launches=nothing,
     update_buffers::Val{_update_buffers}=Val(true),
-    kws...,
 ) where {T<:Number,N<:AbstractExpressionNode{T},_update_buffers}
     if _update_buffers
         (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, trees; buffer)
@@ -82,6 +68,7 @@ function eval_tree_array(
 
     #! format: off
     _launch_gpu_kernel!(
+        backend,
        num_threads, num_blocks, num_launches, gworkspace,
        # Thread info:
        num_elem, num_nodes, gexecution_order,
@@ -99,6 +86,7 @@ end
 
 #! format: off
 function _launch_gpu_kernel!(
+    backend,
    num_threads, num_blocks, num_launches::Integer, buffer::AbstractArray{T,2},
    # Thread info:
    num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray{I},
@@ -114,24 +102,12 @@ function _launch_gpu_kernel!(
     gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
     for launch in one(I):I(num_launches)
         #! format: off
-        if buffer isa CuArray
-            @cuda threads=num_threads blocks=num_blocks gpu_kernel!(
-                buffer,
-                launch, num_elem, num_nodes, execution_order,
-                cX, idx_self, idx_l, idx_r,
-                degree, constant, val, feature, op
-            )
-        else
-            Threads.@threads for i in 1:(num_threads * num_blocks)
-                gpu_kernel!(
-                    buffer,
-                    launch, num_elem, num_nodes, execution_order,
-                    cX, idx_self, idx_l, idx_r,
-                    degree, constant, val, feature, op,
-                    i
-                )
-            end
-        end
+        gpu_kernel!(backend, num_threads * num_blocks)(
+            buffer,
+            launch, num_elem, num_nodes, execution_order,
+            cX, idx_self, idx_l, idx_r,
+            degree, constant, val, feature, op
+        )
         #! format: on
     end
     return nothing
@@ -146,19 +122,17 @@ end
 for nuna in 0:10, nbin in 0:10
     @eval function create_gpu_kernel(operators::OperatorEnum, ::Val{$nuna}, ::Val{$nbin})
         #! format: off
-        function (
+        @kernel function k(
            # Storage:
            buffer,
            # Thread info:
-            launch::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
+            @Const(launch)::Integer, @Const(num_elem)::Integer, @Const(num_nodes)::Integer, @Const(execution_order)::AbstractArray{I},
            # Input data and tree
-            cX::AbstractArray, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
-            degree::AbstractArray, constant::AbstractArray, val::AbstractArray, feature::AbstractArray, op::AbstractArray,
-            # Override for unittesting:
-            i=nothing,
+            @Const(cX)::AbstractArray, @Const(idx_self)::AbstractArray, @Const(idx_l)::AbstractArray, @Const(idx_r)::AbstractArray,
+            @Const(degree)::AbstractArray, @Const(constant)::AbstractArray, @Const(val)::AbstractArray, @Const(feature)::AbstractArray, @Const(op)::AbstractArray,
        )
        #! format: on
-            i = i === nothing ? (blockIdx().x - 1) * blockDim().x + threadIdx().x : i
+            i = @index(Global, Linear)
            if i > num_elem * num_nodes
                return nothing
            end
@@ -186,8 +160,8 @@ for nuna in 0:10, nbin in 0:10
                l_idx = idx_l[node]
                Base.Cartesian.@nif(
                    $nuna,
-                    i -> i == cur_op,
-                    i -> let op = operators.unaops[i]
+                    j -> j == cur_op,
+                    j -> let op = operators.unaops[j]
                        buffer[elem, cur_idx] = op(buffer[elem, l_idx])
                    end
                )
@@ -197,8 +171,8 @@
                r_idx = idx_r[node]
                Base.Cartesian.@nif(
                    $nbin,
-                    i -> i == cur_op,
-                    i -> let op = operators.binops[i]
+                    j -> j == cur_op,
+                    j -> let op = operators.binops[j]
                        buffer[elem, cur_idx] = op(buffer[elem, l_idx], buffer[elem, r_idx])
                    end
                )
```
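
To make the new entry point concrete, here is a hypothetical usage sketch against a CUDA device array. The tree-building calls follow the DynamicExpressions README; the operator set and data shapes are placeholders, and CUDA.jl must be loaded so that the `backend=get_backend(gcX)` default resolves a `CuArray` to the CUDA backend:

```julia
using CUDA, KernelAbstractions
using DynamicExpressions

# Define operators and a small expression tree: cos(3.1 * x1) - x1
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node{Float32}(; feature=1)
tree = cos(x1 * 3.1f0) - x1

# Features-by-samples matrix on the GPU. The new `backend` keyword defaults
# to `get_backend(gcX)`, so a CuArray automatically selects the CUDA backend.
X = CUDA.rand(Float32, 2, 1000)
(out, completed) = gpu_eval_tree_array(tree, X, operators)
```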

src/DynamicExpressions.jl (+2 −1)

```diff
@@ -48,7 +48,8 @@ import .EquationModule: constructorof, preserve_sharing
     eval_diff_tree_array, eval_grad_tree_array
 @reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
 @reexport import .EvaluationHelpersModule
-@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
+@reexport import .ExtensionInterfaceModule:
+    node_to_symbolic, symbolic_to_node, gpu_eval_tree_array
 @reexport import .RandomModule: NodeSampler
 @reexport import .AsArrayModule: as_array
 
```
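
Because the stub is `@reexport`ed here, downstream code gets `gpu_eval_tree_array` in scope from a plain `using DynamicExpressions`; a quick sanity check, assuming only the package itself:

```julia
using DynamicExpressions

# The reexported name points at the ExtensionInterfaceModule stub until
# a GPU extension adds callable methods to it.
@assert gpu_eval_tree_array ===
    DynamicExpressions.ExtensionInterfaceModule.gpu_eval_tree_array
```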

src/ExtensionInterface.jl (+3)

```diff
@@ -14,6 +14,9 @@ end
 function bumper_eval_tree_array(args...)
     return error("Please load the Bumper.jl package to use this feature.")
 end
+function gpu_eval_tree_array(args...)
+    return error("Please load a GPU backend such as CUDA.jl to use this feature.")
+end
 function bumper_kern1! end
 function bumper_kern2! end
 
```
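Without any GPU package loaded, the new fallback fires; a minimal sketch of the expected failure mode:

```julia
using DynamicExpressions

# With no GPU extension loaded, any call hits the `args...` fallback:
try
    gpu_eval_tree_array()
catch err
    # ErrorException("Please load a GPU backend such as CUDA.jl to use this feature.")
    println(sprint(showerror, err))
end
```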
