Skip to content

Commit 034a18f

Browse files
authored
Soft float (rust-lang#589)
1 parent c990136 commit 034a18f

File tree

13 files changed

+625
-1
lines changed

13 files changed

+625
-1
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,11 @@ const std::set<std::string> KnownInactiveFunctions = {
242242
"_msize",
243243
"ftnio_fmt_write64",
244244
"f90_strcmp_klen",
245-
"__swift_instantiateConcreteTypeFromMangledName"};
245+
"__swift_instantiateConcreteTypeFromMangledName",
246+
"logb",
247+
"logbf",
248+
"logbl",
249+
};
246250

247251
/// Is the use of value val as an argument of call CI known to be inactive
248252
/// This tool can only be used when in DOWN mode

enzyme/Enzyme/AdjointGenerator.h

+317
Original file line numberDiff line numberDiff line change
@@ -9239,6 +9239,323 @@ class AdjointGenerator
92399239
return;
92409240
}
92419241

9242+
if (funcName == "__mulsc3" || funcName == "__muldc3" ||
9243+
funcName == "__multc3" || funcName == "__mulxc3") {
9244+
if (gutils->knownRecomputeHeuristic.find(orig) !=
9245+
gutils->knownRecomputeHeuristic.end()) {
9246+
if (!gutils->knownRecomputeHeuristic[orig]) {
9247+
gutils->cacheForReverse(BuilderZ, newCall,
9248+
getIndex(orig, CacheType::Self));
9249+
}
9250+
}
9251+
9252+
eraseIfUnused(*orig);
9253+
if (gutils->isConstantInstruction(orig))
9254+
return;
9255+
9256+
Value *orig_op0 = call.getOperand(0);
9257+
Value *orig_op1 = call.getOperand(1);
9258+
Value *orig_op2 = call.getOperand(2);
9259+
Value *orig_op3 = call.getOperand(3);
9260+
9261+
bool constantval0 = gutils->isConstantValue(orig_op0);
9262+
bool constantval1 = gutils->isConstantValue(orig_op1);
9263+
bool constantval2 = gutils->isConstantValue(orig_op2);
9264+
bool constantval3 = gutils->isConstantValue(orig_op3);
9265+
9266+
Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
9267+
gutils->getNewFromOriginal(orig_op1),
9268+
gutils->getNewFromOriginal(orig_op2),
9269+
gutils->getNewFromOriginal(orig_op3)};
9270+
9271+
auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
9272+
funcName, called->getFunctionType(), called->getAttributes());
9273+
9274+
switch (Mode) {
9275+
case DerivativeMode::ForwardMode:
9276+
case DerivativeMode::ForwardModeSplit: {
9277+
IRBuilder<> Builder2(&call);
9278+
getForwardBuilder(Builder2);
9279+
9280+
Value *diff[4] = {
9281+
constantval0 ? Constant::getNullValue(orig_op0->getType())
9282+
: diffe(orig_op0, Builder2),
9283+
constantval1 ? Constant::getNullValue(orig_op1->getType())
9284+
: diffe(orig_op1, Builder2),
9285+
constantval2 ? Constant::getNullValue(orig_op2->getType())
9286+
: diffe(orig_op2, Builder2),
9287+
constantval3 ? Constant::getNullValue(orig_op3->getType())
9288+
: diffe(orig_op3, Builder2)};
9289+
9290+
auto cal1 =
9291+
Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
9292+
auto cal2 =
9293+
Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});
9294+
9295+
Value *resReal =
9296+
Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {0}),
9297+
Builder2.CreateExtractValue(cal2, {0}));
9298+
Value *resImag =
9299+
Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {1}),
9300+
Builder2.CreateExtractValue(cal2, {1}));
9301+
9302+
Value *res = Builder2.CreateInsertValue(
9303+
UndefValue::get(call.getType()), resReal, {0});
9304+
res = Builder2.CreateInsertValue(res, resImag, {1});
9305+
9306+
setDiffe(&call, res, Builder2);
9307+
return;
9308+
}
9309+
case DerivativeMode::ReverseModeGradient:
9310+
case DerivativeMode::ReverseModeCombined: {
9311+
IRBuilder<> Builder2(call.getParent());
9312+
getReverseBuilder(Builder2);
9313+
9314+
Value *idiff = diffe(&call, Builder2);
9315+
Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
9316+
Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});
9317+
9318+
Value *diff0 = nullptr;
9319+
Value *diff1 = nullptr;
9320+
9321+
if (!constantval0 || !constantval1)
9322+
diff0 = Builder2.CreateCall(mul, {idiffReal, idiffImag,
9323+
lookup(prim[2], Builder2),
9324+
lookup(prim[3], Builder2)});
9325+
9326+
if (!constantval2 || !constantval3)
9327+
diff1 = Builder2.CreateCall(mul, {lookup(prim[0], Builder2),
9328+
lookup(prim[1], Builder2),
9329+
idiffReal, idiffImag});
9330+
9331+
if (diff0 || diff1)
9332+
setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);
9333+
9334+
if (diff0) {
9335+
addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
9336+
Builder2, orig_op0->getType());
9337+
addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
9338+
Builder2, orig_op1->getType());
9339+
}
9340+
9341+
if (diff1) {
9342+
addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
9343+
Builder2, orig_op2->getType());
9344+
addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
9345+
Builder2, orig_op3->getType());
9346+
}
9347+
9348+
return;
9349+
}
9350+
case DerivativeMode::ReverseModePrimal:
9351+
return;
9352+
}
9353+
}
9354+
9355+
if (funcName == "__divsc3" || funcName == "__divdc3" ||
9356+
funcName == "__divtc3" || funcName == "__divxc3") {
9357+
if (gutils->knownRecomputeHeuristic.find(orig) !=
9358+
gutils->knownRecomputeHeuristic.end()) {
9359+
if (!gutils->knownRecomputeHeuristic[orig]) {
9360+
gutils->cacheForReverse(BuilderZ, newCall,
9361+
getIndex(orig, CacheType::Self));
9362+
}
9363+
}
9364+
9365+
if (gutils->isConstantInstruction(orig))
9366+
return;
9367+
9368+
StringMap<StringRef> map = {
9369+
{"__divsc3", "__mulsc3"},
9370+
{"__divdc3", "__muldc3"},
9371+
{"__divtc3", "__multc3"},
9372+
{"__divxc3", "__mulxc3"},
9373+
};
9374+
9375+
auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
9376+
map[funcName], called->getFunctionType(), called->getAttributes());
9377+
9378+
auto div = gutils->oldFunc->getParent()->getOrInsertFunction(
9379+
funcName, called->getFunctionType(), called->getAttributes());
9380+
9381+
Value *orig_op0 = call.getOperand(0);
9382+
Value *orig_op1 = call.getOperand(1);
9383+
Value *orig_op2 = call.getOperand(2);
9384+
Value *orig_op3 = call.getOperand(3);
9385+
9386+
bool constantval0 = gutils->isConstantValue(orig_op0);
9387+
bool constantval1 = gutils->isConstantValue(orig_op1);
9388+
bool constantval2 = gutils->isConstantValue(orig_op2);
9389+
bool constantval3 = gutils->isConstantValue(orig_op3);
9390+
9391+
Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
9392+
gutils->getNewFromOriginal(orig_op1),
9393+
gutils->getNewFromOriginal(orig_op2),
9394+
gutils->getNewFromOriginal(orig_op3)};
9395+
9396+
switch (Mode) {
9397+
case DerivativeMode::ForwardMode:
9398+
case DerivativeMode::ForwardModeSplit: {
9399+
IRBuilder<> Builder2(&call);
9400+
getForwardBuilder(Builder2);
9401+
9402+
Value *diff[4] = {
9403+
constantval0 ? Constant::getNullValue(orig_op0->getType())
9404+
: diffe(orig_op0, Builder2),
9405+
constantval1 ? Constant::getNullValue(orig_op1->getType())
9406+
: diffe(orig_op1, Builder2),
9407+
constantval2 ? Constant::getNullValue(orig_op2->getType())
9408+
: diffe(orig_op2, Builder2),
9409+
constantval3 ? Constant::getNullValue(orig_op3->getType())
9410+
: diffe(orig_op3, Builder2)};
9411+
9412+
auto mul1 =
9413+
Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
9414+
auto mul2 =
9415+
Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});
9416+
auto sq1 =
9417+
Builder2.CreateCall(mul, {prim[2], prim[3], prim[2], prim[3]});
9418+
9419+
Value *subReal =
9420+
Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {0}),
9421+
Builder2.CreateExtractValue(mul2, {0}));
9422+
Value *subImag =
9423+
Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {1}),
9424+
Builder2.CreateExtractValue(mul2, {1}));
9425+
9426+
auto div1 = Builder2.CreateCall(
9427+
div, {subReal, subImag, Builder2.CreateExtractValue(sq1, {0}),
9428+
Builder2.CreateExtractValue(sq1, {1})});
9429+
9430+
setDiffe(&call, div1, Builder2);
9431+
9432+
eraseIfUnused(*orig);
9433+
9434+
return;
9435+
}
9436+
case DerivativeMode::ReverseModeGradient:
9437+
case DerivativeMode::ReverseModeCombined: {
9438+
IRBuilder<> Builder2(call.getParent());
9439+
getReverseBuilder(Builder2);
9440+
9441+
Value *idiff = diffe(&call, Builder2);
9442+
Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
9443+
Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});
9444+
9445+
Value *diff0 = nullptr;
9446+
Value *diff1 = nullptr;
9447+
9448+
if (!constantval0 || !constantval1)
9449+
diff0 = Builder2.CreateCall(div, {idiffReal, idiffImag,
9450+
lookup(prim[2], Builder2),
9451+
lookup(prim[3], Builder2)});
9452+
9453+
if (!constantval2 || !constantval3) {
9454+
auto fdiv = Builder2.CreateCall(div, {idiffReal, idiffImag,
9455+
lookup(prim[1], Builder2),
9456+
lookup(prim[2], Builder2)});
9457+
9458+
Value *newcall = gutils->getNewFromOriginal(&call);
9459+
9460+
diff1 = Builder2.CreateCall(
9461+
mul,
9462+
{Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {0})),
9463+
Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {1})),
9464+
Builder2.CreateExtractValue(fdiv, {0}),
9465+
Builder2.CreateExtractValue(fdiv, {1})});
9466+
}
9467+
9468+
if (diff0 || diff1)
9469+
setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);
9470+
9471+
if (diff0) {
9472+
addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
9473+
Builder2, orig_op0->getType());
9474+
addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
9475+
Builder2, orig_op1->getType());
9476+
}
9477+
9478+
if (diff1) {
9479+
addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
9480+
Builder2, orig_op2->getType());
9481+
addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
9482+
Builder2, orig_op3->getType());
9483+
}
9484+
9485+
if (constantval2 && constantval3)
9486+
eraseIfUnused(*orig);
9487+
9488+
return;
9489+
}
9490+
case DerivativeMode::ReverseModePrimal:;
9491+
return;
9492+
}
9493+
}
9494+
9495+
if (funcName == "scalbn" || funcName == "scalbnf" ||
9496+
funcName == "scalbnl" || funcName == "scalbln" ||
9497+
funcName == "scalblnf" || funcName == "scalblnl") {
9498+
eraseIfUnused(*orig);
9499+
9500+
Value *orig_op0 = call.getOperand(0);
9501+
Value *orig_op1 = call.getOperand(1);
9502+
9503+
bool constantval0 = gutils->isConstantValue(orig_op0);
9504+
9505+
if (gutils->isConstantInstruction(orig) || constantval0)
9506+
return;
9507+
9508+
Value *op0 = gutils->getNewFromOriginal(orig_op0);
9509+
Value *op1 = gutils->getNewFromOriginal(orig_op1);
9510+
9511+
auto scal = gutils->oldFunc->getParent()->getOrInsertFunction(
9512+
funcName, called->getFunctionType(), called->getAttributes());
9513+
9514+
switch (Mode) {
9515+
case DerivativeMode::ForwardMode:
9516+
case DerivativeMode::ForwardModeSplit: {
9517+
IRBuilder<> Builder2(&call);
9518+
getForwardBuilder(Builder2);
9519+
9520+
Value *diff0 = diffe(orig_op0, Builder2);
9521+
9522+
auto cal1 = Builder2.CreateCall(scal, {op0, op1});
9523+
auto cal2 = Builder2.CreateCall(scal, {diff0, op1});
9524+
9525+
Value *diff = Builder2.CreateFMul(
9526+
cal1, ConstantFP::get(call.getType(), 0.3010299957));
9527+
diff = Builder2.CreateFAdd(diff, cal2);
9528+
9529+
setDiffe(&call, diff, Builder2);
9530+
return;
9531+
}
9532+
case DerivativeMode::ReverseModeGradient:
9533+
case DerivativeMode::ReverseModeCombined: {
9534+
IRBuilder<> Builder2(call.getParent());
9535+
getReverseBuilder(Builder2);
9536+
9537+
Value *idiff = diffe(&call, Builder2);
9538+
9539+
if (idiff && !constantval0) {
9540+
op1 = lookup(op1, Builder2);
9541+
9542+
auto cal1 = Builder2.CreateCall(scal, {op0, op1});
9543+
auto cal2 = Builder2.CreateCall(scal, {idiff, op1});
9544+
9545+
Value *diff = Builder2.CreateFMul(
9546+
cal1, ConstantFP::get(call.getType(), 0.3010299957));
9547+
diff = Builder2.CreateFAdd(diff, cal2);
9548+
9549+
addToDiffe(orig_op0, diff, Builder2, call.getType());
9550+
}
9551+
9552+
return;
9553+
}
9554+
case DerivativeMode::ReverseModePrimal:;
9555+
return;
9556+
}
9557+
}
9558+
92429559
if (called) {
92439560
if (funcName == "erf" || funcName == "erfi" || funcName == "erfc" ||
92449561
funcName == "Faddeeva_erf" || funcName == "Faddeeva_erfi" ||

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,23 @@ const std::map<std::string, llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
9292
{"log1p", Intrinsic::not_intrinsic},
9393
{"log2", Intrinsic::log2},
9494
{"logb", Intrinsic::not_intrinsic},
95+
{"logbf", Intrinsic::not_intrinsic},
96+
{"logbl", Intrinsic::not_intrinsic},
9597
{"pow", Intrinsic::pow},
9698
{"sqrt", Intrinsic::sqrt},
9799
{"cbrt", Intrinsic::not_intrinsic},
98100
{"hypot", Intrinsic::not_intrinsic},
99101

102+
{"__mulsc3", Intrinsic::not_intrinsic},
103+
{"__muldc3", Intrinsic::not_intrinsic},
104+
{"__multc3", Intrinsic::not_intrinsic},
105+
{"__mulxc3", Intrinsic::not_intrinsic},
106+
107+
{"__divsc3", Intrinsic::not_intrinsic},
108+
{"__divdc3", Intrinsic::not_intrinsic},
109+
{"__divtc3", Intrinsic::not_intrinsic},
110+
{"__divxc3", Intrinsic::not_intrinsic},
111+
100112
{"Faddeeva_erf", Intrinsic::not_intrinsic},
101113
{"Faddeeva_erfc", Intrinsic::not_intrinsic},
102114
{"Faddeeva_erfcx", Intrinsic::not_intrinsic},
@@ -139,6 +151,11 @@ const std::map<std::string, llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
139151
{"fma", Intrinsic::fma},
140152
{"ilogb", Intrinsic::not_intrinsic},
141153
{"scalbn", Intrinsic::not_intrinsic},
154+
{"scalbnf", Intrinsic::not_intrinsic},
155+
{"scalbnl", Intrinsic::not_intrinsic},
156+
{"scalbln", Intrinsic::not_intrinsic},
157+
{"scalblnf", Intrinsic::not_intrinsic},
158+
{"scalblnl", Intrinsic::not_intrinsic},
142159
{"powi", Intrinsic::powi},
143160
{"cabs", Intrinsic::not_intrinsic},
144161
{"ldexp", Intrinsic::not_intrinsic},

0 commit comments

Comments
 (0)