Skip to content

Commit 54459b9

Browse files
authored
Implement vector mode: aggregates and vectors (rust-lang#453)
1 parent 2ba00ed commit 54459b9

File tree

3 files changed

+220
-33
lines changed

3 files changed

+220
-33
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,13 +1205,18 @@ class AdjointGenerator
12051205
getForwardBuilder(Builder2);
12061206

12071207
Value *orig_vec = EEI.getVectorOperand();
1208+
Type *vecTy = gutils->getShadowType(orig_vec->getType());
12081209

12091210
auto vec_diffe = gutils->isConstantValue(orig_vec)
1210-
? ConstantVector::getNullValue(orig_vec->getType())
1211+
? Constant::getNullValue(vecTy)
12111212
: diffe(orig_vec, Builder2);
1212-
auto diffe =
1213-
Builder2.CreateExtractElement(vec_diffe, EEI.getIndexOperand());
12141213

1214+
auto rule = [&](Value *vec_diffe) {
1215+
return Builder2.CreateExtractElement(
1216+
vec_diffe, gutils->getNewFromOriginal(EEI.getIndexOperand()));
1217+
};
1218+
1219+
auto diffe = applyChainRule(EEI.getType(), Builder2, rule, vec_diffe);
12151220
setDiffe(&EEI, diffe, Builder2);
12161221
return;
12171222
}
@@ -1260,19 +1265,25 @@ class AdjointGenerator
12601265
Value *orig_inserted = IEI.getOperand(1);
12611266
Value *orig_index = IEI.getOperand(2);
12621267

1268+
Type *insertedTy = gutils->getShadowType(orig_inserted->getType());
1269+
Type *vectorTy = gutils->getShadowType(orig_vector->getType());
1270+
12631271
Value *diff_inserted = gutils->isConstantValue(orig_inserted)
1264-
? ConstantFP::get(orig_inserted->getType(), 0)
1272+
? Constant::getNullValue(insertedTy)
12651273
: diffe(orig_inserted, Builder2);
12661274

1267-
Value *prediff =
1268-
gutils->isConstantValue(orig_vector)
1269-
? ConstantVector::getNullValue(orig_vector->getType())
1270-
: diffe(orig_vector, Builder2);
1275+
Value *prediff = gutils->isConstantValue(orig_vector)
1276+
? Constant::getNullValue(vectorTy)
1277+
: diffe(orig_vector, Builder2);
12711278

1272-
auto dindex = Builder2.CreateInsertElement(
1273-
prediff, diff_inserted, gutils->getNewFromOriginal(orig_index));
1274-
setDiffe(&IEI, dindex, Builder2);
1279+
auto rule = [&](Value *diff_inserted, Value *prediff) {
1280+
return Builder2.CreateInsertElement(
1281+
prediff, diff_inserted, gutils->getNewFromOriginal(orig_index));
1282+
};
12751283

1284+
Value *dindex =
1285+
applyChainRule(IEI.getType(), Builder2, rule, diff_inserted, prediff);
1286+
setDiffe(&IEI, dindex, Builder2);
12761287
return;
12771288
}
12781289
case DerivativeMode::ReverseModeGradient:
@@ -1345,14 +1356,19 @@ class AdjointGenerator
13451356
? ConstantVector::getNullValue(orig_vector2->getType())
13461357
: diffe(orig_vector2, Builder2);
13471358

1359+
auto rule = [&](Value *diffe_vector1, Value *diffe_vector2) {
13481360
#if LLVM_VERSION_MAJOR >= 11
1349-
auto diffe = Builder2.CreateShuffleVector(diffe_vector1, diffe_vector2,
1350-
SVI.getShuffleMaskForBitcode());
1361+
auto diffe = Builder2.CreateShuffleVector(
1362+
diffe_vector1, diffe_vector2, SVI.getShuffleMaskForBitcode());
13511363
#else
1352-
auto diffe = Builder2.CreateShuffleVector(diffe_vector1, diffe_vector2,
1353-
SVI.getOperand(2));
1364+
auto diffe = Builder2.CreateShuffleVector(diffe_vector1, diffe_vector2,
1365+
SVI.getOperand(2));
13541366
#endif
1367+
return diffe;
1368+
};
13551369

1370+
auto diffe = applyChainRule(SVI.getType(), Builder2, rule, diffe_vector1,
1371+
diffe_vector2);
13561372
setDiffe(&SVI, diffe, Builder2);
13571373
return;
13581374
}
@@ -1417,15 +1433,19 @@ class AdjointGenerator
14171433
getForwardBuilder(Builder2);
14181434

14191435
Value *orig_aggregate = EVI.getAggregateOperand();
1436+
Type *agg_type = gutils->getShadowType(orig_aggregate->getType());
1437+
1438+
Value *diffe_aggregate = gutils->isConstantValue(orig_aggregate)
1439+
? Constant::getNullValue(agg_type)
1440+
: diffe(orig_aggregate, Builder2);
14201441

1421-
Value *diffe_aggregate =
1422-
gutils->isConstantValue(orig_aggregate)
1423-
? ConstantAggregate::getNullValue(orig_aggregate->getType())
1424-
: diffe(orig_aggregate, Builder2);
1425-
Value *diffe =
1426-
Builder2.CreateExtractValue(diffe_aggregate, EVI.getIndices());
1442+
auto rule = [&](Value *diffe_aggregate) {
1443+
return Builder2.CreateExtractValue(diffe_aggregate, EVI.getIndices());
1444+
};
14271445

1428-
setDiffe(&EVI, diffe, Builder2);
1446+
Value *diff =
1447+
applyChainRule(EVI.getType(), Builder2, rule, diffe_aggregate);
1448+
setDiffe(&EVI, diff, Builder2);
14291449
return;
14301450
}
14311451
case DerivativeMode::ReverseModeGradient:
@@ -1526,20 +1546,27 @@ class AdjointGenerator
15261546
IRBuilder<> Builder2(&IVI);
15271547
getForwardBuilder(Builder2);
15281548

1529-
Value *orig_inserted = IVI.getInsertedValueOperand();
1549+
Value *orig_val = IVI.getInsertedValueOperand();
15301550
Value *orig_agg = IVI.getAggregateOperand();
15311551

1532-
Value *diff_inserted = gutils->isConstantValue(orig_inserted)
1533-
? ConstantFP::get(orig_inserted->getType(), 0)
1534-
: diffe(orig_inserted, Builder2);
1552+
Type *val_type = gutils->getShadowType(orig_val->getType());
1553+
Type *agg_type = gutils->getShadowType(orig_agg->getType());
1554+
1555+
Value *diff_val = gutils->isConstantValue(orig_val)
1556+
? Constant::getNullValue(val_type)
1557+
: diffe(orig_val, Builder2);
1558+
1559+
Value *diff_agg = gutils->isConstantValue(orig_agg)
1560+
? Constant::getNullValue(agg_type)
1561+
: diffe(orig_agg, Builder2);
1562+
1563+
auto rule = [&](Value *diff_agg, Value *diff_val) {
1564+
return Builder2.CreateInsertValue(diff_agg, diff_val, IVI.getIndices());
1565+
};
15351566

1536-
Value *prediff =
1537-
gutils->isConstantValue(orig_agg)
1538-
? ConstantAggregate::getNullValue(orig_agg->getType())
1539-
: diffe(orig_agg, Builder2);
1540-
auto dindex =
1541-
Builder2.CreateInsertValue(prediff, diff_inserted, IVI.getIndices());
1542-
setDiffe(&IVI, dindex, Builder2);
1567+
Value *diff = applyChainRule(orig_agg->getType(), Builder2, rule,
1568+
diff_agg, diff_val);
1569+
setDiffe(&IVI, diff, Builder2);
15431570

15441571
return;
15451572
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
%struct.Gradients = type { double, double, double }
4+
5+
; Function Attrs: nounwind
6+
declare %struct.Gradients @__enzyme_fwddiff(double (double)*, ...)
7+
8+
; Function Attrs: noinline nounwind readnone uwtable
9+
define double @tester(double %x) {
10+
entry:
11+
%agg1 = insertvalue [3 x double] undef, double %x, 0
12+
%mul = fmul double %x, %x
13+
%agg2 = insertvalue [3 x double] %agg1, double %mul, 1
14+
%add = fadd double %mul, 2.0
15+
%agg3 = insertvalue [3 x double] %agg2, double %add, 2
16+
%res = extractvalue [3 x double] %agg2, 1
17+
ret double %res
18+
}
19+
20+
define %struct.Gradients @test_derivative(double %x) {
21+
entry:
22+
%0 = tail call %struct.Gradients (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 2.0, double 3.0)
23+
ret %struct.Gradients %0
24+
}
25+
26+
27+
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'")
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %0 = extractvalue [3 x double] %"x'", 0
30+
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
31+
; CHECK-NEXT: %2 = fmul fast double %0, %x
32+
; CHECK-NEXT: %3 = fmul fast double %1, %x
33+
; CHECK-NEXT: %4 = fadd fast double %2, %3
34+
; CHECK-NEXT: %5 = extractvalue [3 x double] %"x'", 1
35+
; CHECK-NEXT: %6 = extractvalue [3 x double] %"x'", 1
36+
; CHECK-NEXT: %7 = fmul fast double %5, %x
37+
; CHECK-NEXT: %8 = fmul fast double %6, %x
38+
; CHECK-NEXT: %9 = fadd fast double %7, %8
39+
; CHECK-NEXT: %10 = extractvalue [3 x double] %"x'", 2
40+
; CHECK-NEXT: %11 = extractvalue [3 x double] %"x'", 2
41+
; CHECK-NEXT: %12 = fmul fast double %10, %x
42+
; CHECK-NEXT: %13 = fmul fast double %11, %x
43+
; CHECK-NEXT: %14 = fadd fast double %12, %13
44+
; CHECK-NEXT: %15 = insertvalue [3 x double] undef, double %4, 0
45+
; CHECK-NEXT: %16 = insertvalue [3 x double] %15, double %9, 1
46+
; CHECK-NEXT: %17 = insertvalue [3 x double] %16, double %14, 2
47+
; CHECK-NEXT: ret [3 x double] %17
48+
; CHECK-NEXT: }
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -S | FileCheck %s
2+
3+
%struct.Gradients = type { {float, float, float}, {float, float, float} }
4+
5+
declare %struct.Gradients @__enzyme_fwddiff({float, float, float} (<4 x float>)*, ...)
6+
7+
define {float, float, float} @square(<4 x float> %x) {
8+
entry:
9+
%vec = insertelement <4 x float> %x, float 1.0, i32 3
10+
%sq = fmul <4 x float> %x, %x
11+
%cb = fmul <4 x float> %sq, %x
12+
%id = shufflevector <4 x float> %sq, <4 x float> %cb, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
13+
%res1 = extractelement <4 x float> %id, i32 1
14+
%res2 = extractelement <4 x float> %id, i32 2
15+
%res3 = extractelement <4 x float> %id, i32 3
16+
%agg1 = insertvalue {float, float, float} undef, float %res1, 0
17+
%agg2 = insertvalue {float, float, float} %agg1, float %res2, 1
18+
%agg3 = insertvalue {float, float, float} %agg2, float %res3, 2
19+
ret {float, float, float} %agg3
20+
}
21+
22+
define %struct.Gradients @dsquare(<4 x float> %x) {
23+
entry:
24+
%call = tail call %struct.Gradients ({float, float, float} (<4 x float>)*, ...) @__enzyme_fwddiff({float, float, float} (<4 x float>)* @square, metadata !"enzyme_width", i64 2, <4 x float> %x, <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>)
25+
ret %struct.Gradients %call
26+
}
27+
28+
29+
; CHECK: define internal [2 x { float, float, float }] @fwddiffe2square(<4 x float> %x, [2 x <4 x float>] %"x'")
30+
; CHECK-NEXT: entry:
31+
; CHECK-NEXT: %sq = fmul <4 x float> %x, %x
32+
; CHECK-NEXT: %0 = extractvalue [2 x <4 x float>] %"x'", 0
33+
; CHECK-NEXT: %1 = extractvalue [2 x <4 x float>] %"x'", 0
34+
; CHECK-NEXT: %2 = fmul fast <4 x float> %0, %x
35+
; CHECK-NEXT: %3 = fmul fast <4 x float> %1, %x
36+
; CHECK-NEXT: %4 = fadd fast <4 x float> %2, %3
37+
; CHECK-NEXT: %5 = insertvalue [2 x <4 x float>] undef, <4 x float> %4, 0
38+
; CHECK-NEXT: %6 = extractvalue [2 x <4 x float>] %"x'", 1
39+
; CHECK-NEXT: %7 = extractvalue [2 x <4 x float>] %"x'", 1
40+
; CHECK-NEXT: %8 = fmul fast <4 x float> %6, %x
41+
; CHECK-NEXT: %9 = fmul fast <4 x float> %7, %x
42+
; CHECK-NEXT: %10 = fadd fast <4 x float> %8, %9
43+
; CHECK-NEXT: %11 = insertvalue [2 x <4 x float>] %5, <4 x float> %10, 1
44+
; CHECK-NEXT: %cb = fmul <4 x float> %sq, %x
45+
; CHECK-NEXT: %12 = extractvalue [2 x <4 x float>] %11, 0
46+
; CHECK-NEXT: %13 = extractvalue [2 x <4 x float>] %"x'", 0
47+
; CHECK-NEXT: %14 = fmul fast <4 x float> %12, %x
48+
; CHECK-NEXT: %15 = fmul fast <4 x float> %13, %sq
49+
; CHECK-NEXT: %16 = fadd fast <4 x float> %14, %15
50+
; CHECK-NEXT: %17 = insertvalue [2 x <4 x float>] undef, <4 x float> %16, 0
51+
; CHECK-NEXT: %18 = extractvalue [2 x <4 x float>] %11, 1
52+
; CHECK-NEXT: %19 = extractvalue [2 x <4 x float>] %"x'", 1
53+
; CHECK-NEXT: %20 = fmul fast <4 x float> %18, %x
54+
; CHECK-NEXT: %21 = fmul fast <4 x float> %19, %sq
55+
; CHECK-NEXT: %22 = fadd fast <4 x float> %20, %21
56+
; CHECK-NEXT: %23 = insertvalue [2 x <4 x float>] %17, <4 x float> %22, 1
57+
; CHECK-NEXT: %id = shufflevector <4 x float> %sq, <4 x float> %cb, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
58+
; CHECK-NEXT: %24 = extractvalue [2 x <4 x float>] %11, 0
59+
; CHECK-NEXT: %25 = extractvalue [2 x <4 x float>] %23, 0
60+
; CHECK-NEXT: %26 = shufflevector <4 x float> %24, <4 x float> %25, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
61+
; CHECK-NEXT: %27 = insertvalue [2 x <4 x float>] undef, <4 x float> %26, 0
62+
; CHECK-NEXT: %28 = extractvalue [2 x <4 x float>] %11, 1
63+
; CHECK-NEXT: %29 = extractvalue [2 x <4 x float>] %23, 1
64+
; CHECK-NEXT: %30 = shufflevector <4 x float> %28, <4 x float> %29, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
65+
; CHECK-NEXT: %31 = insertvalue [2 x <4 x float>] %27, <4 x float> %30, 1
66+
; CHECK-NEXT: %res1 = extractelement <4 x float> %id, i32 1
67+
; CHECK-NEXT: %32 = extractvalue [2 x <4 x float>] %31, 0
68+
; CHECK-NEXT: %33 = extractelement <4 x float> %32, i32 1
69+
; CHECK-NEXT: %34 = insertvalue [2 x float] undef, float %33, 0
70+
; CHECK-NEXT: %35 = extractvalue [2 x <4 x float>] %31, 1
71+
; CHECK-NEXT: %36 = extractelement <4 x float> %35, i32 1
72+
; CHECK-NEXT: %37 = insertvalue [2 x float] %34, float %36, 1
73+
; CHECK-NEXT: %res2 = extractelement <4 x float> %id, i32 2
74+
; CHECK-NEXT: %38 = extractvalue [2 x <4 x float>] %31, 0
75+
; CHECK-NEXT: %39 = extractelement <4 x float> %38, i32 2
76+
; CHECK-NEXT: %40 = insertvalue [2 x float] undef, float %39, 0
77+
; CHECK-NEXT: %41 = extractvalue [2 x <4 x float>] %31, 1
78+
; CHECK-NEXT: %42 = extractelement <4 x float> %41, i32 2
79+
; CHECK-NEXT: %43 = insertvalue [2 x float] %40, float %42, 1
80+
; CHECK-NEXT: %res3 = extractelement <4 x float> %id, i32 3
81+
; CHECK-NEXT: %44 = extractvalue [2 x <4 x float>] %31, 0
82+
; CHECK-NEXT: %45 = extractelement <4 x float> %44, i32 3
83+
; CHECK-NEXT: %46 = insertvalue [2 x float] undef, float %45, 0
84+
; CHECK-NEXT: %47 = extractvalue [2 x <4 x float>] %31, 1
85+
; CHECK-NEXT: %48 = extractelement <4 x float> %47, i32 3
86+
; CHECK-NEXT: %49 = insertvalue [2 x float] %46, float %48, 1
87+
; CHECK-NEXT: %agg1 = insertvalue { float, float, float } undef, float %res1, 0
88+
; CHECK-NEXT: %50 = extractvalue [2 x float] %37, 0
89+
; CHECK-NEXT: %51 = insertvalue { float, float, float } zeroinitializer, float %50, 0
90+
; CHECK-NEXT: %52 = insertvalue [2 x { float, float, float }] undef, { float, float, float } %51, 0
91+
; CHECK-NEXT: %53 = extractvalue [2 x float] %37, 1
92+
; CHECK-NEXT: %54 = insertvalue { float, float, float } zeroinitializer, float %53, 0
93+
; CHECK-NEXT: %55 = insertvalue [2 x { float, float, float }] %52, { float, float, float } %54, 1
94+
; CHECK-NEXT: %agg2 = insertvalue { float, float, float } %agg1, float %res2, 1
95+
; CHECK-NEXT: %56 = extractvalue [2 x { float, float, float }] %55, 0
96+
; CHECK-NEXT: %57 = extractvalue [2 x float] %43, 0
97+
; CHECK-NEXT: %58 = insertvalue { float, float, float } %56, float %57, 1
98+
; CHECK-NEXT: %59 = insertvalue [2 x { float, float, float }] undef, { float, float, float } %58, 0
99+
; CHECK-NEXT: %60 = extractvalue [2 x { float, float, float }] %55, 1
100+
; CHECK-NEXT: %61 = extractvalue [2 x float] %43, 1
101+
; CHECK-NEXT: %62 = insertvalue { float, float, float } %60, float %61, 1
102+
; CHECK-NEXT: %63 = insertvalue [2 x { float, float, float }] %59, { float, float, float } %62, 1
103+
; CHECK-NEXT: %64 = extractvalue [2 x { float, float, float }] %63, 0
104+
; CHECK-NEXT: %65 = extractvalue [2 x float] %49, 0
105+
; CHECK-NEXT: %66 = insertvalue { float, float, float } %64, float %65, 2
106+
; CHECK-NEXT: %67 = insertvalue [2 x { float, float, float }] undef, { float, float, float } %66, 0
107+
; CHECK-NEXT: %68 = extractvalue [2 x { float, float, float }] %63, 1
108+
; CHECK-NEXT: %69 = extractvalue [2 x float] %49, 1
109+
; CHECK-NEXT: %70 = insertvalue { float, float, float } %68, float %69, 2
110+
; CHECK-NEXT: %71 = insertvalue [2 x { float, float, float }] %67, { float, float, float } %70, 1
111+
; CHECK-NEXT: ret [2 x { float, float, float }] %71
112+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)