@@ -1254,14 +1254,19 @@ class Enzyme : public ModulePass {
1254
1254
#endif
1255
1255
CI->replaceAllUsesWith (cload);
1256
1256
} else {
1257
- llvm::errs () << *CI << " - " << *diffret << " \n " ;
1258
- assert (0 && " what" );
1257
+ EmitFailure (" IllegalReturnCast" , CI->getDebugLoc (), CI,
1258
+ " Cannot cast return type of gradient " ,
1259
+ *diffret->getType (), *diffret, " , to desired type " ,
1260
+ *CI->getType ());
1261
+ return false ;
1259
1262
}
1260
1263
} else if (CI->hasStructRetAttr ()) {
1261
1264
Value *sret = CI->getArgOperand (0 );
1265
+ PointerType *stype = cast<PointerType>(sret->getType ());
1266
+ StructType *st = dyn_cast<StructType>(stype->getElementType ());
1262
1267
1263
1268
// Assign results to struct allocated at the call site.
1264
- if (StructType * st = cast<StructType>(diffret-> getType () )) {
1269
+ if (st && st-> isLayoutIdentical (diffretsty )) {
1265
1270
for (unsigned int i = 0 ; i < st->getNumElements (); i++) {
1266
1271
#if LLVM_VERSION_MAJOR > 7
1267
1272
Value *sgep = Builder.CreateStructGEP (
@@ -1271,6 +1276,20 @@ class Enzyme : public ModulePass {
1271
1276
#endif
1272
1277
Builder.CreateStore (Builder.CreateExtractValue (diffret, {i}), sgep);
1273
1278
}
1279
+ } else {
1280
+ auto &DL = fn->getParent ()->getDataLayout ();
1281
+ if (DL.getTypeSizeInBits (stype->getElementType ()) !=
1282
+ DL.getTypeSizeInBits (diffret->getType ())) {
1283
+ EmitFailure (" IllegalReturnCast" , CI->getDebugLoc (), CI,
1284
+ " Cannot cast return type of gradient " ,
1285
+ *diffret->getType (), *diffret, " , to desired type " ,
1286
+ *stype->getElementType ());
1287
+ return false ;
1288
+ }
1289
+ Builder.CreateStore (
1290
+ diffret, Builder.CreatePointerCast (
1291
+ sret, PointerType::get (diffret->getType (),
1292
+ stype->getAddressSpace ())));
1274
1293
}
1275
1294
} else {
1276
1295
0 commit comments