Skip to content

Commit bfdd286

Browse files
authored
Fix Zygote accum for CA in GPU broadcasting (#221)
1 parent 140b899 commit bfdd286

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.1"
4+
version = "0.15.2"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -21,6 +21,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2121
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2222
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2323
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
24+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2425

2526
[extensions]
2627
ComponentArraysAdaptExt = "Adapt"
@@ -30,6 +31,7 @@ ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
3031
ComponentArraysReverseDiffExt = "ReverseDiff"
3132
ComponentArraysSciMLBaseExt = "SciMLBase"
3233
ComponentArraysTrackerExt = "Tracker"
34+
ComponentArraysZygoteExt = "Zygote"
3335

3436
[compat]
3537
Adapt = "3"
@@ -46,6 +48,7 @@ SciMLBase = "1"
4648
StaticArraysCore = "1"
4749
StaticArrayInterface = "1"
4850
Tracker = "0.2"
51+
Zygote = "0.6"
4952
julia = "1.6"
5053

5154
[extras]
@@ -58,3 +61,4 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
5861
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
5962
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6063
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
64+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/ComponentArraysZygoteExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ComponentArraysZygoteExt
2+
3+
using ComponentArrays, Zygote
4+
5+
# For most cases this work. However, if the ComponentArray contains ROCArray, it fails to
6+
# compile the broadcast operation on AMDGPU. This will most likely be fixed with proper
7+
# broadcast mechanics in AMDGPU.jl but we can work around that in a harmless fashion for
8+
# now.
9+
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
10+
return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
11+
end
12+
13+
end

0 commit comments

Comments
 (0)