Skip to content

Commit 68db988

Browse files
authored
Better sret (rust-lang#708)
1 parent f6fed13 commit 68db988

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,14 +1254,19 @@ class Enzyme : public ModulePass {
12541254
#endif
12551255
CI->replaceAllUsesWith(cload);
12561256
} else {
1257-
llvm::errs() << *CI << " - " << *diffret << "\n";
1258-
assert(0 && " what");
1257+
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
1258+
"Cannot cast return type of gradient ",
1259+
*diffret->getType(), *diffret, ", to desired type ",
1260+
*CI->getType());
1261+
return false;
12591262
}
12601263
} else if (CI->hasStructRetAttr()) {
12611264
Value *sret = CI->getArgOperand(0);
1265+
PointerType *stype = cast<PointerType>(sret->getType());
1266+
StructType *st = dyn_cast<StructType>(stype->getElementType());
12621267

12631268
// Assign results to struct allocated at the call site.
1264-
if (StructType *st = cast<StructType>(diffret->getType())) {
1269+
if (st && st->isLayoutIdentical(diffretsty)) {
12651270
for (unsigned int i = 0; i < st->getNumElements(); i++) {
12661271
#if LLVM_VERSION_MAJOR > 7
12671272
Value *sgep = Builder.CreateStructGEP(
@@ -1271,6 +1276,20 @@ class Enzyme : public ModulePass {
12711276
#endif
12721277
Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}), sgep);
12731278
}
1279+
} else {
1280+
auto &DL = fn->getParent()->getDataLayout();
1281+
if (DL.getTypeSizeInBits(stype->getElementType()) !=
1282+
DL.getTypeSizeInBits(diffret->getType())) {
1283+
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
1284+
"Cannot cast return type of gradient ",
1285+
*diffret->getType(), *diffret, ", to desired type ",
1286+
*stype->getElementType());
1287+
return false;
1288+
}
1289+
Builder.CreateStore(
1290+
diffret, Builder.CreatePointerCast(
1291+
sret, PointerType::get(diffret->getType(),
1292+
stype->getAddressSpace())));
12741293
}
12751294
} else {
12761295

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
9+
10+
#include "test_utils.h"
11+
#include <iostream>
12+
#include <sstream>
13+
#include <utility>
14+
15+
typedef struct {
16+
double df[3];
17+
} Gradient;
18+
extern Gradient __enzyme_autodiff(void*, double, double , double);
19+
20+
double myfunction(double x, double y, double z){
21+
return x * y * z;
22+
}
23+
24+
void dmyfunction(double x, double y, double z, double* res) {
25+
Gradient g = __enzyme_autodiff((void*)myfunction, x, y, z);
26+
27+
res[0]=g.df[0];
28+
res[1]=g.df[1];
29+
res[2]=g.df[2];
30+
}
31+
32+
int main() {
33+
double *res=new double(3);
34+
dmyfunction(3,4,5,res);
35+
APPROX_EQ(res[0], 4*5., 1e-7);
36+
APPROX_EQ(res[1], 3*5., 1e-7);
37+
APPROX_EQ(res[1], 3*4., 1e-7);
38+
}

0 commit comments

Comments
 (0)