@@ -9239,6 +9239,323 @@ class AdjointGenerator
9239
9239
return;
9240
9240
}
9241
9241
9242
+ if (funcName == "__mulsc3" || funcName == "__muldc3" ||
9243
+ funcName == "__multc3" || funcName == "__mulxc3") {
9244
+ if (gutils->knownRecomputeHeuristic.find(orig) !=
9245
+ gutils->knownRecomputeHeuristic.end()) {
9246
+ if (!gutils->knownRecomputeHeuristic[orig]) {
9247
+ gutils->cacheForReverse(BuilderZ, newCall,
9248
+ getIndex(orig, CacheType::Self));
9249
+ }
9250
+ }
9251
+
9252
+ eraseIfUnused(*orig);
9253
+ if (gutils->isConstantInstruction(orig))
9254
+ return;
9255
+
9256
+ Value *orig_op0 = call.getOperand(0);
9257
+ Value *orig_op1 = call.getOperand(1);
9258
+ Value *orig_op2 = call.getOperand(2);
9259
+ Value *orig_op3 = call.getOperand(3);
9260
+
9261
+ bool constantval0 = gutils->isConstantValue(orig_op0);
9262
+ bool constantval1 = gutils->isConstantValue(orig_op1);
9263
+ bool constantval2 = gutils->isConstantValue(orig_op2);
9264
+ bool constantval3 = gutils->isConstantValue(orig_op3);
9265
+
9266
+ Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
9267
+ gutils->getNewFromOriginal(orig_op1),
9268
+ gutils->getNewFromOriginal(orig_op2),
9269
+ gutils->getNewFromOriginal(orig_op3)};
9270
+
9271
+ auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
9272
+ funcName, called->getFunctionType(), called->getAttributes());
9273
+
9274
+ switch (Mode) {
9275
+ case DerivativeMode::ForwardMode:
9276
+ case DerivativeMode::ForwardModeSplit: {
9277
+ IRBuilder<> Builder2(&call);
9278
+ getForwardBuilder(Builder2);
9279
+
9280
+ Value *diff[4] = {
9281
+ constantval0 ? Constant::getNullValue(orig_op0->getType())
9282
+ : diffe(orig_op0, Builder2),
9283
+ constantval1 ? Constant::getNullValue(orig_op1->getType())
9284
+ : diffe(orig_op1, Builder2),
9285
+ constantval2 ? Constant::getNullValue(orig_op2->getType())
9286
+ : diffe(orig_op2, Builder2),
9287
+ constantval3 ? Constant::getNullValue(orig_op3->getType())
9288
+ : diffe(orig_op3, Builder2)};
9289
+
9290
+ auto cal1 =
9291
+ Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
9292
+ auto cal2 =
9293
+ Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});
9294
+
9295
+ Value *resReal =
9296
+ Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {0}),
9297
+ Builder2.CreateExtractValue(cal2, {0}));
9298
+ Value *resImag =
9299
+ Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {1}),
9300
+ Builder2.CreateExtractValue(cal2, {1}));
9301
+
9302
+ Value *res = Builder2.CreateInsertValue(
9303
+ UndefValue::get(call.getType()), resReal, {0});
9304
+ res = Builder2.CreateInsertValue(res, resImag, {1});
9305
+
9306
+ setDiffe(&call, res, Builder2);
9307
+ return;
9308
+ }
9309
+ case DerivativeMode::ReverseModeGradient:
9310
+ case DerivativeMode::ReverseModeCombined: {
9311
+ IRBuilder<> Builder2(call.getParent());
9312
+ getReverseBuilder(Builder2);
9313
+
9314
+ Value *idiff = diffe(&call, Builder2);
9315
+ Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
9316
+ Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});
9317
+
9318
+ Value *diff0 = nullptr;
9319
+ Value *diff1 = nullptr;
9320
+
9321
+ if (!constantval0 || !constantval1)
9322
+ diff0 = Builder2.CreateCall(mul, {idiffReal, idiffImag,
9323
+ lookup(prim[2], Builder2),
9324
+ lookup(prim[3], Builder2)});
9325
+
9326
+ if (!constantval2 || !constantval3)
9327
+ diff1 = Builder2.CreateCall(mul, {lookup(prim[0], Builder2),
9328
+ lookup(prim[1], Builder2),
9329
+ idiffReal, idiffImag});
9330
+
9331
+ if (diff0 || diff1)
9332
+ setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);
9333
+
9334
+ if (diff0) {
9335
+ addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
9336
+ Builder2, orig_op0->getType());
9337
+ addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
9338
+ Builder2, orig_op1->getType());
9339
+ }
9340
+
9341
+ if (diff1) {
9342
+ addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
9343
+ Builder2, orig_op2->getType());
9344
+ addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
9345
+ Builder2, orig_op3->getType());
9346
+ }
9347
+
9348
+ return;
9349
+ }
9350
+ case DerivativeMode::ReverseModePrimal:
9351
+ return;
9352
+ }
9353
+ }
9354
+
9355
+ if (funcName == "__divsc3" || funcName == "__divdc3" ||
9356
+ funcName == "__divtc3" || funcName == "__divxc3") {
9357
+ if (gutils->knownRecomputeHeuristic.find(orig) !=
9358
+ gutils->knownRecomputeHeuristic.end()) {
9359
+ if (!gutils->knownRecomputeHeuristic[orig]) {
9360
+ gutils->cacheForReverse(BuilderZ, newCall,
9361
+ getIndex(orig, CacheType::Self));
9362
+ }
9363
+ }
9364
+
9365
+ if (gutils->isConstantInstruction(orig))
9366
+ return;
9367
+
9368
+ StringMap<StringRef> map = {
9369
+ {"__divsc3", "__mulsc3"},
9370
+ {"__divdc3", "__muldc3"},
9371
+ {"__divtc3", "__multc3"},
9372
+ {"__divxc3", "__mulxc3"},
9373
+ };
9374
+
9375
+ auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
9376
+ map[funcName], called->getFunctionType(), called->getAttributes());
9377
+
9378
+ auto div = gutils->oldFunc->getParent()->getOrInsertFunction(
9379
+ funcName, called->getFunctionType(), called->getAttributes());
9380
+
9381
+ Value *orig_op0 = call.getOperand(0);
9382
+ Value *orig_op1 = call.getOperand(1);
9383
+ Value *orig_op2 = call.getOperand(2);
9384
+ Value *orig_op3 = call.getOperand(3);
9385
+
9386
+ bool constantval0 = gutils->isConstantValue(orig_op0);
9387
+ bool constantval1 = gutils->isConstantValue(orig_op1);
9388
+ bool constantval2 = gutils->isConstantValue(orig_op2);
9389
+ bool constantval3 = gutils->isConstantValue(orig_op3);
9390
+
9391
+ Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
9392
+ gutils->getNewFromOriginal(orig_op1),
9393
+ gutils->getNewFromOriginal(orig_op2),
9394
+ gutils->getNewFromOriginal(orig_op3)};
9395
+
9396
+ switch (Mode) {
9397
+ case DerivativeMode::ForwardMode:
9398
+ case DerivativeMode::ForwardModeSplit: {
9399
+ IRBuilder<> Builder2(&call);
9400
+ getForwardBuilder(Builder2);
9401
+
9402
+ Value *diff[4] = {
9403
+ constantval0 ? Constant::getNullValue(orig_op0->getType())
9404
+ : diffe(orig_op0, Builder2),
9405
+ constantval1 ? Constant::getNullValue(orig_op1->getType())
9406
+ : diffe(orig_op1, Builder2),
9407
+ constantval2 ? Constant::getNullValue(orig_op2->getType())
9408
+ : diffe(orig_op2, Builder2),
9409
+ constantval3 ? Constant::getNullValue(orig_op3->getType())
9410
+ : diffe(orig_op3, Builder2)};
9411
+
9412
+ auto mul1 =
9413
+ Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
9414
+ auto mul2 =
9415
+ Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});
9416
+ auto sq1 =
9417
+ Builder2.CreateCall(mul, {prim[2], prim[3], prim[2], prim[3]});
9418
+
9419
+ Value *subReal =
9420
+ Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {0}),
9421
+ Builder2.CreateExtractValue(mul2, {0}));
9422
+ Value *subImag =
9423
+ Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {1}),
9424
+ Builder2.CreateExtractValue(mul2, {1}));
9425
+
9426
+ auto div1 = Builder2.CreateCall(
9427
+ div, {subReal, subImag, Builder2.CreateExtractValue(sq1, {0}),
9428
+ Builder2.CreateExtractValue(sq1, {1})});
9429
+
9430
+ setDiffe(&call, div1, Builder2);
9431
+
9432
+ eraseIfUnused(*orig);
9433
+
9434
+ return;
9435
+ }
9436
+ case DerivativeMode::ReverseModeGradient:
9437
+ case DerivativeMode::ReverseModeCombined: {
9438
+ IRBuilder<> Builder2(call.getParent());
9439
+ getReverseBuilder(Builder2);
9440
+
9441
+ Value *idiff = diffe(&call, Builder2);
9442
+ Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
9443
+ Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});
9444
+
9445
+ Value *diff0 = nullptr;
9446
+ Value *diff1 = nullptr;
9447
+
9448
+ if (!constantval0 || !constantval1)
9449
+ diff0 = Builder2.CreateCall(div, {idiffReal, idiffImag,
9450
+ lookup(prim[2], Builder2),
9451
+ lookup(prim[3], Builder2)});
9452
+
9453
+ if (!constantval2 || !constantval3) {
9454
+ auto fdiv = Builder2.CreateCall(div, {idiffReal, idiffImag,
9455
+ lookup(prim[1], Builder2),
9456
+ lookup(prim[2], Builder2)});
9457
+
9458
+ Value *newcall = gutils->getNewFromOriginal(&call);
9459
+
9460
+ diff1 = Builder2.CreateCall(
9461
+ mul,
9462
+ {Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {0})),
9463
+ Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {1})),
9464
+ Builder2.CreateExtractValue(fdiv, {0}),
9465
+ Builder2.CreateExtractValue(fdiv, {1})});
9466
+ }
9467
+
9468
+ if (diff0 || diff1)
9469
+ setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);
9470
+
9471
+ if (diff0) {
9472
+ addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
9473
+ Builder2, orig_op0->getType());
9474
+ addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
9475
+ Builder2, orig_op1->getType());
9476
+ }
9477
+
9478
+ if (diff1) {
9479
+ addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
9480
+ Builder2, orig_op2->getType());
9481
+ addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
9482
+ Builder2, orig_op3->getType());
9483
+ }
9484
+
9485
+ if (constantval2 && constantval3)
9486
+ eraseIfUnused(*orig);
9487
+
9488
+ return;
9489
+ }
9490
+ case DerivativeMode::ReverseModePrimal:;
9491
+ return;
9492
+ }
9493
+ }
9494
+
9495
+ if (funcName == "scalbn" || funcName == "scalbnf" ||
9496
+ funcName == "scalbnl" || funcName == "scalbln" ||
9497
+ funcName == "scalblnf" || funcName == "scalblnl") {
9498
+ eraseIfUnused(*orig);
9499
+
9500
+ Value *orig_op0 = call.getOperand(0);
9501
+ Value *orig_op1 = call.getOperand(1);
9502
+
9503
+ bool constantval0 = gutils->isConstantValue(orig_op0);
9504
+
9505
+ if (gutils->isConstantInstruction(orig) || constantval0)
9506
+ return;
9507
+
9508
+ Value *op0 = gutils->getNewFromOriginal(orig_op0);
9509
+ Value *op1 = gutils->getNewFromOriginal(orig_op1);
9510
+
9511
+ auto scal = gutils->oldFunc->getParent()->getOrInsertFunction(
9512
+ funcName, called->getFunctionType(), called->getAttributes());
9513
+
9514
+ switch (Mode) {
9515
+ case DerivativeMode::ForwardMode:
9516
+ case DerivativeMode::ForwardModeSplit: {
9517
+ IRBuilder<> Builder2(&call);
9518
+ getForwardBuilder(Builder2);
9519
+
9520
+ Value *diff0 = diffe(orig_op0, Builder2);
9521
+
9522
+ auto cal1 = Builder2.CreateCall(scal, {op0, op1});
9523
+ auto cal2 = Builder2.CreateCall(scal, {diff0, op1});
9524
+
9525
+ Value *diff = Builder2.CreateFMul(
9526
+ cal1, ConstantFP::get(call.getType(), 0.3010299957));
9527
+ diff = Builder2.CreateFAdd(diff, cal2);
9528
+
9529
+ setDiffe(&call, diff, Builder2);
9530
+ return;
9531
+ }
9532
+ case DerivativeMode::ReverseModeGradient:
9533
+ case DerivativeMode::ReverseModeCombined: {
9534
+ IRBuilder<> Builder2(call.getParent());
9535
+ getReverseBuilder(Builder2);
9536
+
9537
+ Value *idiff = diffe(&call, Builder2);
9538
+
9539
+ if (idiff && !constantval0) {
9540
+ op1 = lookup(op1, Builder2);
9541
+
9542
+ auto cal1 = Builder2.CreateCall(scal, {op0, op1});
9543
+ auto cal2 = Builder2.CreateCall(scal, {idiff, op1});
9544
+
9545
+ Value *diff = Builder2.CreateFMul(
9546
+ cal1, ConstantFP::get(call.getType(), 0.3010299957));
9547
+ diff = Builder2.CreateFAdd(diff, cal2);
9548
+
9549
+ addToDiffe(orig_op0, diff, Builder2, call.getType());
9550
+ }
9551
+
9552
+ return;
9553
+ }
9554
+ case DerivativeMode::ReverseModePrimal:;
9555
+ return;
9556
+ }
9557
+ }
9558
+
9242
9559
if (called) {
9243
9560
if (funcName == "erf" || funcName == "erfi" || funcName == "erfc" ||
9244
9561
funcName == "Faddeeva_erf" || funcName == "Faddeeva_erfi" ||
0 commit comments