Skip to content

Commit 4c9a761

Browse files
authored
Fix gradient struct return on ARM (rust-lang#378)
* Fix gradient struct return of ARM where the type of CI is a struct * add integration test * add unit test
1 parent 6331179 commit 4c9a761

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,13 +862,23 @@ class Enzyme : public ModulePass {
862862
}
863863
}
864864
}
865+
StructType *CIsty = dyn_cast<StructType>(CI->getType());
866+
StructType *diffretsty = dyn_cast<StructType>(diffret->getType());
865867

866868
if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy() &&
867869
!CI->getType()->isEmptyTy() &&
868870
(!CI->getType()->isVoidTy() ||
869871
CI->paramHasAttr(0, Attribute::StructRet))) {
870872
if (diffret->getType() == CI->getType()) {
871873
CI->replaceAllUsesWith(diffret);
874+
} else if (CIsty && diffretsty && CIsty->isLayoutIdentical(diffretsty)) {
875+
IRBuilder<> Builder(CI);
876+
Value *newStruct = UndefValue::get(CIsty);
877+
for (unsigned int i = 0; i < CIsty->getStructNumElements(); i++) {
878+
Value *elem = Builder.CreateExtractValue(diffret, {i});
879+
newStruct = Builder.CreateInsertValue(newStruct, elem, {i});
880+
}
881+
CI->replaceAllUsesWith(newStruct);
872882
} else if (mode == DerivativeMode::ReverseModePrimal) {
873883
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
874884
if (DL.getTypeSizeInBits(CI->getType()) >=
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
%struct.Gradients = type { double, double }
4+
5+
define dso_local double @muldd(double %x, double %y) #0 {
6+
entry:
7+
%mul = fmul double %x, %y
8+
ret double %mul
9+
}
10+
11+
define dso_local %struct.Gradients @dmuldd(double %x, double %y) local_unnamed_addr #1 {
12+
entry:
13+
%call = call %struct.Gradients (i8*, ...) @_Z17__enzyme_autodiffPvz(i8* bitcast (double (double, double)* @muldd to i8*), double %x, double %y)
14+
ret %struct.Gradients %call
15+
}
16+
17+
declare dso_local %struct.Gradients @_Z17__enzyme_autodiffPvz(i8*, ...) local_unnamed_addr #2
18+
19+
attributes #0 = { norecurse nounwind readnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="non-leaf" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon" "unsafe-fp-math"="false" "use-soft-float"="false" }
20+
attributes #1 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="non-leaf" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon" "unsafe-fp-math"="false" "use-soft-float"="false" }
21+
attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="non-leaf" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon" "unsafe-fp-math"="false" "use-soft-float"="false" }
22+
23+
; CHECK: define internal {{(dso_local )?}}{ double, double } @diffemuldd(double %x, double %y, double %differeturn)
24+
; CHECK-NEXT: entry:
25+
; CHECK-NEXT: %m0diffex = fmul fast double %differeturn, %y
26+
; CHECK-NEXT: %m1diffey = fmul fast double %differeturn, %x
27+
; CHECK-NEXT: %0 = insertvalue { double, double } undef, double %m0diffex, 0
28+
; CHECK-NEXT: %1 = insertvalue { double, double } %0, double %m1diffey, 1
29+
; CHECK-NEXT: ret { double, double } %1
30+
; CHECK-NEXT: }
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
5+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
6+
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
7+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
8+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
9+
10+
#include <stdio.h>
11+
#include "test_utils.h"
12+
13+
typedef struct {
14+
double dx,dy;
15+
} Gradients;
16+
17+
extern Gradients __enzyme_autodiff(void*, ...);
18+
19+
double mul(double x, double y) {
20+
return x * y;
21+
}
22+
Gradients dmul(double x, double y) {
23+
return __enzyme_autodiff((void*)mul, x, y);
24+
}
25+
int main() {
26+
double x = 1.0;
27+
double y = 2.0;
28+
printf("mul(x=%f,y%f)=%f\n", x, y, mul(x,y));
29+
printf("ddx dmul(x=%f,y%f)=%f\n", x, y, dmul(x,y).dx);
30+
printf("ddy dmul(x=%f,y%f)=%f\n", x, y, dmul(x,y).dy);
31+
APPROX_EQ(dmul(x,y).dx, 2.0, 10e-10);
32+
APPROX_EQ(dmul(x,y).dy, 1.0, 10e-10);
33+
}

0 commit comments

Comments
 (0)