Skip to content

Commit 45bc702

Browse files
authored
Vector Tests (rust-lang#357)
* fix bugs
* add test
* fix test
1 parent c98cc98 commit 45bc702

File tree

3 files changed

+57
-9
lines changed

3 files changed

+57
-9
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1207,8 +1207,9 @@ class AdjointGenerator
12071207

12081208
Value *prediff =
12091209
gutils->isConstantValue(orig_vector)
1210-
? diffe(orig_vector, Builder2)
1211-
: ConstantVector::getNullValue(orig_vector->getType());
1210+
? ConstantVector::getNullValue(orig_vector->getType())
1211+
: diffe(orig_vector, Builder2);
1212+
12121213
auto dindex = Builder2.CreateInsertElement(
12131214
prediff, diff_inserted, gutils->getNewFromOriginal(orig_index));
12141215
setDiffe(&IEI, dindex, Builder2);
@@ -1275,7 +1276,6 @@ class AdjointGenerator
12751276

12761277
Value *orig_vector1 = SVI.getOperand(0);
12771278
Value *orig_vector2 = SVI.getOperand(1);
1278-
Value *orig_mask = SVI.getOperand(0);
12791279

12801280
auto diffe_vector1 =
12811281
gutils->isConstantValue(orig_vector1)
@@ -1286,8 +1286,13 @@ class AdjointGenerator
12861286
? ConstantVector::getNullValue(orig_vector2->getType())
12871287
: diffe(orig_vector2, Builder2);
12881288

1289-
auto diffe = Builder2.CreateShuffleVector(
1290-
diffe_vector1, diffe_vector2, gutils->getNewFromOriginal(orig_mask));
1289+
#if LLVM_VERSION_MAJOR >= 11
1290+
auto diffe = Builder2.CreateShuffleVector(diffe_vector1, diffe_vector2,
1291+
SVI.getShuffleMaskForBitcode());
1292+
#else
1293+
auto diffe = Builder2.CreateShuffleVector(diffe_vector1, diffe_vector2,
1294+
SVI.getOperand(2));
1295+
#endif
12911296

12921297
setDiffe(&SVI, diffe, Builder2);
12931298
return;
@@ -1465,8 +1470,8 @@ class AdjointGenerator
14651470

14661471
Value *prediff =
14671472
gutils->isConstantValue(orig_agg)
1468-
? diffe(orig_agg, Builder2)
1469-
: ConstantAggregate::getNullValue(orig_agg->getType());
1473+
? ConstantAggregate::getNullValue(orig_agg->getType())
1474+
: diffe(orig_agg, Builder2);
14701475
auto dindex =
14711476
Builder2.CreateInsertValue(prediff, diff_inserted, IVI.getIndices());
14721477
setDiffe(&IVI, dindex, Builder2);

enzyme/test/Enzyme/ForwardMode/square_array.ll

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ declare { double, double } @__enzyme_fwddiff(i8*, ...)
2727
; CHECK-NEXT: %2 = fmul fast double %1, %x
2828
; CHECK-NEXT: %3 = fmul fast double %"x'", %mul
2929
; CHECK-NEXT: %4 = fadd fast double %2, %3
30-
; CHECK-NEXT: %5 = insertvalue { double, double } zeroinitializer, double %4, 1
31-
; CHECK-NEXT: ret { double, double } %5
30+
; CHECK-NEXT: %5 = insertvalue { double, double } zeroinitializer, double %1, 0
31+
; CHECK-NEXT: %6 = insertvalue { double, double } %5, double %4, 1
32+
; CHECK-NEXT: ret { double, double } %6
3233
; CHECK-NEXT: }
Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -instcombine -S | FileCheck %s
2+
3+
declare {float, float, float} @__enzyme_fwddiff({float, float, float} (<4 x float>)*, <4 x float>, <4 x float>)
4+
5+
define {float, float, float} @square(<4 x float> %x) {
6+
entry:
7+
%vec = insertelement <4 x float> %x, float 1.0, i32 3
8+
%sq = fmul <4 x float> %x, %x
9+
%cb = fmul <4 x float> %sq, %x
10+
%id = shufflevector <4 x float> %sq, <4 x float> %cb, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
11+
%res1 = extractelement <4 x float> %id, i32 1
12+
%res2 = extractelement <4 x float> %id, i32 2
13+
%res3 = extractelement <4 x float> %id, i32 3
14+
%agg1 = insertvalue {float, float, float} undef, float %res1, 0
15+
%agg2 = insertvalue {float, float, float} %agg1, float %res2, 1
16+
%agg3 = insertvalue {float, float, float} %agg2, float %res3, 2
17+
ret {float, float, float} %agg3
18+
}
19+
20+
define {float, float, float} @dsquare(<4 x float> %x) {
21+
entry:
22+
%call = tail call {float, float, float} @__enzyme_fwddiff({float, float, float} (<4 x float>)* @square, <4 x float> %x, <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>)
23+
ret {float, float, float} %call
24+
}
25+
26+
27+
; CHECK: define internal { float, float, float } @fwddiffesquare(<4 x float> %x, <4 x float> %"x'")
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %sq = fmul <4 x float> %x, %x
30+
; CHECK-NEXT: %0 = fmul fast <4 x float> %"x'", %x
31+
; CHECK-NEXT: %1 = fadd fast <4 x float> %0, %0
32+
; CHECK-NEXT: %2 = fmul fast <4 x float> %1, %x
33+
; CHECK-NEXT: %3 = fmul fast <4 x float> %sq, %"x'"
34+
; CHECK-NEXT: %4 = fadd fast <4 x float> %2, %3
35+
; CHECK-NEXT: %5 = extractelement <4 x float> %1, i32 1
36+
; CHECK-NEXT: %6 = extractelement <4 x float> %4, i32 0
37+
; CHECK-NEXT: %7 = extractelement <4 x float> %4, i32 1
38+
; CHECK-NEXT: %8 = insertvalue { float, float, float } zeroinitializer, float %5, 0
39+
; CHECK-NEXT: %9 = insertvalue { float, float, float } %8, float %6, 1
40+
; CHECK-NEXT: %10 = insertvalue { float, float, float } %9, float %7, 2
41+
; CHECK-NEXT: ret { float, float, float } %10
42+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)