Skip to content

Commit 6d4ea42

Browse files
authored
Write direct dispatches for axpy! & axpby! (#225)
1 parent df9bd66 commit 6d4ea42

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.3"
4+
version = "0.15.4"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/linear_algebra.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,18 @@ for op in [:*, :\, :/]
4242
end
4343
end
4444
end
45-
end
45+
end
46+
47+
# Common Accumulation Operations
48+
## Needed for CUDA to work properly
49+
function LinearAlgebra.axpy!::Number, x::ComponentArray, y::ComponentArray)
50+
getaxes(x) != getaxes(y) && throw(ArgumentError("Axes of `x` and `y` must match"))
51+
axpy!(α, getdata(x), getdata(y))
52+
return ComponentArray(y, getaxes(y))
53+
end
54+
55+
function LinearAlgebra.axpby!::Number, x::ComponentArray, β::Number, y::ComponentArray)
56+
getaxes(x) != getaxes(y) && throw(ArgumentError("Axes of `x` and `y` must match"))
57+
axpby!(α, getdata(x), β, getdata(y))
58+
return ComponentArray(y, getaxes(y))
59+
end

test/runtests.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123
# Issue #116
124124
# Part 2: Arrays of arrays
125125
@test_throws Exception ComponentVector(a = [[3], [4, 5]], b = 1)
126-
126+
127127
x = ComponentVector(a = [[3, 3], [4, 5]], b = 1)
128128
@test x.a[1] == [3, 3]
129129
@test x.b == 1
@@ -668,6 +668,28 @@ end
668668
@test ndims(dropdims(ones(1,1), dims=(1,2))) == 0
669669
end
670670

671+
@testset "axpy! / axpby!" begin
672+
y = ComponentArray(a = rand(4), b = rand(4))
673+
x = ComponentArray(a = rand(4), b = rand(4))
674+
ydata = copy(getdata(y))
675+
676+
axpy!(2, x, y)
677+
@test getdata(y) == 2 .* getdata(x) .+ ydata
678+
679+
x = ComponentArray(a = rand(4), c = rand(4))
680+
@test_throws ArgumentError axpy!(2, x, y)
681+
682+
y = ComponentArray(a = rand(4), b = rand(4))
683+
x = ComponentArray(a = rand(4), b = rand(4))
684+
ydata = copy(getdata(y))
685+
686+
axpby!(2, x, 3, y)
687+
@test getdata(y) == 2 .* getdata(x) .+ 3 .* ydata
688+
689+
x = ComponentArray(a = rand(4), c = rand(4))
690+
@test_throws ArgumentError axpby!(2, x, 3, y)
691+
end
692+
671693
@testset "Autodiff" begin
672694
include("autodiff_tests.jl")
673695
end

0 commit comments

Comments
 (0)