Skip to content

Commit f6155ba

Browse files
committed
Fix unsafe sin in CUDA kernel tests
1 parent 9d7500c commit f6155ba

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

ext/DynamicExpressionsCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function eval_tree_array(
8383
)
8484
#! format: on
8585

86-
out = ntuple(i -> @view(gworkspace[begin:end-1, roots[i]]), Val(M + 1))
86+
out = ntuple(i -> @view(gworkspace[begin:(end - 1), roots[i]]), Val(M + 1))
8787
is_good = ntuple(
8888
i -> true, # Up to user to find NaNs
8989
Val(M + 1),

test/test_cuda.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,31 @@ const FakeCuArray = ext.FakeCuArray
77

88
include("tree_gen_utils.jl")
99

10+
safe_sin(x) = isfinite(x) ? sin(x) : convert(eltype(x), NaN)
11+
safe_cos(x) = isfinite(x) ? cos(x) : convert(eltype(x), NaN)
12+
1013
let
11-
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin]);
14+
operators = OperatorEnum(;
15+
binary_operators=[+, -, *, /], unary_operators=[safe_sin, safe_cos]
16+
)
1217
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
1318

1419
for T in (Float32, Float64, ComplexF64), num_trees in (1, 2, 3), seed in 0:10
1520
Random.seed!(seed)
1621
num_rows = rand(10:30)
1722
nodes_per = rand(10:25, num_trees)
18-
trees = ntuple(i -> gen_random_tree_fixed_size(nodes_per[i], operators, 3, T), num_trees)
19-
@show trees
23+
trees = ntuple(
24+
i -> gen_random_tree_fixed_size(nodes_per[i], operators, 3, T), num_trees
25+
)
2026
X = randn(T, 3, num_rows)
2127
y, completed = eval_tree_array(trees, X, operators)
2228
gpu_y, gpu_completed = eval_tree_array(trees, FakeCuArray(X), operators)
2329
gpu_y = Array.(gpu_y)
2430

2531
for i in eachindex(completed, gpu_completed)
26-
@test ((completed[i] && gpu_completed[i]) && (y[i] gpu_y[i])) || (!completed[i] && !gpu_completed[i])
32+
if completed[i]
33+
@test y[i] gpu_y[i]
34+
end
2735
end
2836
end
2937
end

0 commit comments

Comments
 (0)