@@ -353,23 +353,39 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
353
353
if (llvm::any_of (VL, [](Value *V) { return !isa<Instruction>(V); }))
354
354
return InstructionsState (VL[BaseIndex], nullptr , nullptr );
355
355
356
+ bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
356
357
bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
357
358
unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode ();
358
359
unsigned AltOpcode = Opcode;
359
360
unsigned AltIndex = BaseIndex;
360
361
361
362
// Check for one alternate opcode from another BinaryOperator or CastInst.
362
- // TODO - can we support other operators (casts etc.)?
363
+ // TODO - generalize to support all operators (types, calls etc.).
363
364
for (int Cnt = 0 , E = VL.size (); Cnt < E; Cnt++) {
364
365
unsigned InstOpcode = cast<Instruction>(VL[Cnt])->getOpcode ();
365
- if (InstOpcode != Opcode && InstOpcode != AltOpcode) {
366
- if (Opcode == AltOpcode && IsBinOp && isa<BinaryOperator>(VL[Cnt])) {
366
+ if (IsBinOp && isa<BinaryOperator>(VL[Cnt])) {
367
+ if (InstOpcode == Opcode || InstOpcode == AltOpcode)
368
+ continue ;
369
+ if (Opcode == AltOpcode) {
367
370
AltOpcode = InstOpcode;
368
371
AltIndex = Cnt;
369
372
continue ;
370
373
}
371
- return InstructionsState (VL[BaseIndex], nullptr , nullptr );
372
- }
374
+ } else if (IsCastOp && isa<CastInst>(VL[Cnt])) {
375
+ Type *Ty0 = cast<Instruction>(VL[BaseIndex])->getOperand (0 )->getType ();
376
+ Type *Ty1 = cast<Instruction>(VL[Cnt])->getOperand (0 )->getType ();
377
+ if (Ty0 == Ty1) {
378
+ if (InstOpcode == Opcode || InstOpcode == AltOpcode)
379
+ continue ;
380
+ if (Opcode == AltOpcode) {
381
+ AltOpcode = InstOpcode;
382
+ AltIndex = Cnt;
383
+ continue ;
384
+ }
385
+ }
386
+ } else if (InstOpcode == Opcode || InstOpcode == AltOpcode)
387
+ continue ;
388
+ return InstructionsState (VL[BaseIndex], nullptr , nullptr );
373
389
}
374
390
375
391
return InstructionsState (VL[BaseIndex], cast<Instruction>(VL[BaseIndex]),
@@ -2363,32 +2379,45 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
2363
2379
return ReuseShuffleCost + VecCallCost - ScalarCallCost;
2364
2380
}
2365
2381
case Instruction::ShuffleVector: {
2366
- assert (S.isAltShuffle () && Instruction::isBinaryOp (S.getOpcode ()) &&
2367
- Instruction::isBinaryOp (S.getAltOpcode ()) &&
2382
+ assert (S.isAltShuffle () &&
2383
+ ((Instruction::isBinaryOp (S.getOpcode ()) &&
2384
+ Instruction::isBinaryOp (S.getAltOpcode ())) ||
2385
+ (Instruction::isCast (S.getOpcode ()) &&
2386
+ Instruction::isCast (S.getAltOpcode ()))) &&
2368
2387
" Invalid Shuffle Vector Operand" );
2369
2388
int ScalarCost = 0 ;
2370
2389
if (NeedToShuffleReuses) {
2371
2390
for (unsigned Idx : E->ReuseShuffleIndices ) {
2372
2391
Instruction *I = cast<Instruction>(VL[Idx]);
2373
- ReuseShuffleCost -=
2374
- TTI-> getArithmeticInstrCost (I-> getOpcode (), ScalarTy );
2392
+ ReuseShuffleCost -= TTI-> getInstructionCost (
2393
+ I, TargetTransformInfo::TCK_RecipThroughput );
2375
2394
}
2376
2395
for (Value *V : VL) {
2377
2396
Instruction *I = cast<Instruction>(V);
2378
- ReuseShuffleCost +=
2379
- TTI-> getArithmeticInstrCost (I-> getOpcode (), ScalarTy );
2397
+ ReuseShuffleCost += TTI-> getInstructionCost (
2398
+ I, TargetTransformInfo::TCK_RecipThroughput );
2380
2399
}
2381
2400
}
2382
2401
int VecCost = 0 ;
2383
2402
for (Value *i : VL) {
2384
2403
Instruction *I = cast<Instruction>(i);
2385
2404
assert (S.isOpcodeOrAlt (I) && " Unexpected main/alternate opcode" );
2386
- ScalarCost += TTI->getArithmeticInstrCost (I->getOpcode (), ScalarTy);
2405
+ ScalarCost += TTI->getInstructionCost (
2406
+ I, TargetTransformInfo::TCK_RecipThroughput);
2387
2407
}
2388
2408
// VecCost is equal to sum of the cost of creating 2 vectors
2389
2409
// and the cost of creating shuffle.
2390
- VecCost = TTI->getArithmeticInstrCost (S.getOpcode (), VecTy);
2391
- VecCost += TTI->getArithmeticInstrCost (S.getAltOpcode (), VecTy);
2410
+ if (Instruction::isBinaryOp (S.getOpcode ())) {
2411
+ VecCost = TTI->getArithmeticInstrCost (S.getOpcode (), VecTy);
2412
+ VecCost += TTI->getArithmeticInstrCost (S.getAltOpcode (), VecTy);
2413
+ } else {
2414
+ Type *Src0SclTy = S.MainOp ->getOperand (0 )->getType ();
2415
+ Type *Src1SclTy = S.AltOp ->getOperand (0 )->getType ();
2416
+ VectorType *Src0Ty = VectorType::get (Src0SclTy, VL.size ());
2417
+ VectorType *Src1Ty = VectorType::get (Src1SclTy, VL.size ());
2418
+ VecCost = TTI->getCastInstrCost (S.getOpcode (), VecTy, Src0Ty);
2419
+ VecCost += TTI->getCastInstrCost (S.getAltOpcode (), VecTy, Src1Ty);
2420
+ }
2392
2421
VecCost += TTI->getShuffleCost (TargetTransformInfo::SK_Select, VecTy, 0 );
2393
2422
return ReuseShuffleCost + VecCost - ScalarCost;
2394
2423
}
@@ -3470,30 +3499,47 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
3470
3499
}
3471
3500
case Instruction::ShuffleVector: {
3472
3501
ValueList LHSVL, RHSVL;
3473
- assert (S.isAltShuffle () && Instruction::isBinaryOp (S.getOpcode ()) &&
3474
- Instruction::isBinaryOp (S.getAltOpcode ()) &&
3502
+ assert (S.isAltShuffle () &&
3503
+ ((Instruction::isBinaryOp (S.getOpcode ()) &&
3504
+ Instruction::isBinaryOp (S.getAltOpcode ())) ||
3505
+ (Instruction::isCast (S.getOpcode ()) &&
3506
+ Instruction::isCast (S.getAltOpcode ()))) &&
3475
3507
" Invalid Shuffle Vector Operand" );
3476
- reorderAltShuffleOperands (S, E->Scalars , LHSVL, RHSVL);
3477
- setInsertPointAfterBundle (E->Scalars , S);
3478
3508
3479
- Value *LHS = vectorizeTree (LHSVL);
3480
- Value *RHS = vectorizeTree (RHSVL);
3509
+ Value *LHS, *RHS;
3510
+ if (Instruction::isBinaryOp (S.getOpcode ())) {
3511
+ reorderAltShuffleOperands (S, E->Scalars , LHSVL, RHSVL);
3512
+ setInsertPointAfterBundle (E->Scalars , S);
3513
+ LHS = vectorizeTree (LHSVL);
3514
+ RHS = vectorizeTree (RHSVL);
3515
+ } else {
3516
+ ValueList INVL;
3517
+ for (Value *V : E->Scalars )
3518
+ INVL.push_back (cast<Instruction>(V)->getOperand (0 ));
3519
+ setInsertPointAfterBundle (E->Scalars , S);
3520
+ LHS = vectorizeTree (INVL);
3521
+ }
3481
3522
3482
3523
if (E->VectorizedValue ) {
3483
3524
LLVM_DEBUG (dbgs () << " SLP: Diamond merged for " << *VL0 << " .\n " );
3484
3525
return E->VectorizedValue ;
3485
3526
}
3486
3527
3487
- // Create a vector of LHS op1 RHS
3488
- Value *V0 = Builder.CreateBinOp (
3528
+ Value *V0, *V1;
3529
+ if (Instruction::isBinaryOp (S.getOpcode ())) {
3530
+ V0 = Builder.CreateBinOp (
3489
3531
static_cast <Instruction::BinaryOps>(S.getOpcode ()), LHS, RHS);
3490
-
3491
- // Create a vector of LHS op2 RHS
3492
- Value *V1 = Builder.CreateBinOp (
3532
+ V1 = Builder.CreateBinOp (
3493
3533
static_cast <Instruction::BinaryOps>(S.getAltOpcode ()), LHS, RHS);
3534
+ } else {
3535
+ V0 = Builder.CreateCast (
3536
+ static_cast <Instruction::CastOps>(S.getOpcode ()), LHS, VecTy);
3537
+ V1 = Builder.CreateCast (
3538
+ static_cast <Instruction::CastOps>(S.getAltOpcode ()), LHS, VecTy);
3539
+ }
3494
3540
3495
3541
// Create shuffle to take alternate operations from the vector.
3496
- // Also, gather up odd and even scalar ops to propagate IR flags to
3542
+ // Also, gather up main and alt scalar ops to propagate IR flags to
3497
3543
// each vector operation.
3498
3544
ValueList OpScalars, AltScalars;
3499
3545
unsigned e = E->Scalars .size ();
0 commit comments