Skip to content

Commit e56eb6a

Browse files
authored
Add c api inverted bundle (rust-lang#806)
1 parent f92c0eb commit e56eb6a

File tree

5 files changed

+42
-7
lines changed

5 files changed

+42
-7
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,35 @@ const char *EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils,
632632
return cstr;
633633
}
634634

635+
LLVMValueRef EnzymeGradientUtilsCallWithInvertedBundles(
636+
GradientUtils *gutils, LLVMValueRef func, LLVMValueRef *args_vr,
637+
uint64_t args_size, LLVMValueRef orig_vr, CValueType *valTys,
638+
uint64_t valTys_size, LLVMBuilderRef B, uint8_t lookup) {
639+
auto orig = cast<CallInst>(unwrap(orig_vr));
640+
641+
ArrayRef<ValueType> ar((ValueType *)valTys, valTys_size);
642+
643+
IRBuilder<> &BR = *unwrap(B);
644+
645+
auto Defs = gutils->getInvertedBundles(orig, ar, BR, lookup != 0);
646+
647+
SmallVector<Value *, 1> args;
648+
for (size_t i = 0; i < args_size; i++) {
649+
args.push_back(unwrap(args_vr[i]));
650+
}
651+
652+
auto callval = unwrap(func);
653+
654+
#if LLVM_VERSION_MAJOR > 7
655+
auto res = BR.CreateCall(
656+
cast<FunctionType>(callval->getType()->getPointerElementType()), callval,
657+
args, Defs);
658+
#else
659+
auto res = BR.CreateCall(callval, args, Defs);
660+
#endif
661+
return wrap(res);
662+
}
663+
635664
void EnzymeStringFree(const char *cstr) { delete[] cstr; }
636665

637666
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2,

enzyme/Enzyme/CApi.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ struct CTypeTree {
7070
};
7171
*/
7272

73+
typedef enum {
74+
VT_None = 0,
75+
VT_Primal = 1,
76+
VT_Shadow = 2,
77+
VT_Both = VT_Primal | VT_Shadow,
78+
} CValueType;
79+
7380
struct EnzymeTypeTree;
7481
typedef struct EnzymeTypeTree *CTypeTreeRef;
7582
CTypeTreeRef EnzymeNewTypeTree();

enzyme/Enzyme/Enzyme.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,10 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
206206
if (!byref.count(realidx))
207207
args.push_back(arg.getType());
208208
else
209-
args.push_back(
210-
cast<PointerType>(arg.getType())->getElementType());
209+
args.push_back(arg.getType()->getPointerElementType());
211210
realidx++;
212211
} else {
213-
sretTy = cast<PointerType>(arg.getType())->getElementType();
212+
sretTy = arg.getType()->getPointerElementType();
214213
}
215214
i++;
216215
}
@@ -680,7 +679,7 @@ static bool replaceOriginalCall(CallInst *CI, Function *fn, Value *diffret,
680679
} else if (CI->hasStructRetAttr()) {
681680
Value *sret = CI->getArgOperand(0);
682681
PointerType *stype = cast<PointerType>(sret->getType());
683-
StructType *st = dyn_cast<StructType>(stype->getElementType());
682+
StructType *st = dyn_cast<StructType>(stype->getPointerElementType());
684683

685684
// Assign results to struct allocated at the call site.
686685
if (st && st->isLayoutIdentical(diffretsty)) {
@@ -1285,7 +1284,7 @@ class Enzyme : public ModulePass {
12851284
}
12861285

12871286
if (byRefSize) {
1288-
Type *subTy = cast<PointerType>(res->getType())->getElementType();
1287+
Type *subTy = res->getType()->getPointerElementType();
12891288
auto &DL = fn->getParent()->getDataLayout();
12901289
auto BitSize = DL.getTypeSizeInBits(subTy);
12911290
if (BitSize / 8 != byRefSize) {

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1832,7 +1832,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
18321832
if (!foundcalled->hasParamAttribute(i, Attribute::StructRet))
18331833
args.push_back(arg.getType());
18341834
else {
1835-
sretTy = cast<PointerType>(arg.getType())->getElementType();
1835+
sretTy = arg.getType()->getPointerElementType();
18361836
// sretTy = foundcalled->getParamStructRetType(i);
18371837
}
18381838
i++;

enzyme/Enzyme/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
130130
if (ZeroMem) {
131131
auto PT = cast<PointerType>(malloccall->getType());
132132
Value *tozero = malloccall;
133-
if (!PT->getElementType()->isIntegerTy(8))
133+
if (!PT->getPointerElementType()->isIntegerTy(8))
134134
tozero = Builder.CreatePointerCast(
135135
tozero, PointerType::get(Type::getInt8Ty(PT->getContext()),
136136
PT->getAddressSpace()));

0 commit comments

Comments
 (0)