Skip to content

Commit ae68d53

Browse files
[RISCV][VLOPT] Allow propagation even when VL isn't VLMAX (#112228)
The original goal of this pass was to focus on vector operations with VLMAX. However, users often utilize only part of the result, and such usage may come from the vectorizer. We found that relaxing this constraint can capture more optimization opportunities, such as non-power-of-2 code generation and vector operation sequences with different VLs. --------- Co-authored-by: Kito Cheng <[email protected]>
1 parent e768b07 commit ae68d53

File tree

5 files changed

+118
-55
lines changed

5 files changed

+118
-55
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
41024102
assert(Scaled >= 3 && Scaled <= 6);
41034103
return Scaled;
41044104
}
4105+
4106+
/// Given two VL operands, do we know that LHS <= RHS?
4107+
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4108+
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
4109+
LHS.getReg() == RHS.getReg())
4110+
return true;
4111+
if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
4112+
return true;
4113+
if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
4114+
return false;
4115+
if (!LHS.isImm() || !RHS.isImm())
4116+
return false;
4117+
return LHS.getImm() <= RHS.getImm();
4118+
}

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW);
346346
// Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
347347
static constexpr int64_t VLMaxSentinel = -1LL;
348348

349+
/// Given two VL operands, do we know that LHS <= RHS?
350+
bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS);
351+
349352
// Mask assignments for floating-point
350353
static constexpr unsigned FPMASK_Negative_Infinity = 0x001;
351354
static constexpr unsigned FPMASK_Negative_Normal = 0x002;

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5151
StringRef getPassName() const override { return PASS_NAME; }
5252

5353
private:
54-
bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
54+
bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
5555
bool tryReduceVL(MachineInstr &MI);
5656
bool isCandidate(const MachineInstr &MI) const;
5757
};
@@ -658,10 +658,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
658658
if (MI.getNumDefs() != 1)
659659
return false;
660660

661+
// If we're not using VLMAX, then we need to be careful whether we are using
662+
// TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it
663+
// does not matter whether we are using TA/TU with a non-undef Passthru, since
664+
// there are no tail elements to be preserved.
661665
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
662666
const MachineOperand &VLOp = MI.getOperand(VLOpNum);
663-
if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
667+
if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) {
668+
// If MI has a non-undef passthru, we will not try to optimize it since
669+
// that requires us to preserve tail elements according to TA/TU.
670+
// Otherwise, the MI has an undef Passthru, so it doesn't matter whether we
671+
// are using TA/TU.
672+
bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
673+
unsigned PassthruOpIdx = MI.getNumExplicitDefs();
674+
if (HasPassthru &&
675+
MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) {
676+
LLVM_DEBUG(
677+
dbgs() << " Not a candidate because it uses non-undef passthru"
678+
" with non-VLMAX VL\n");
679+
return false;
680+
}
681+
}
682+
683+
// If the VL is 1, then there is no need to reduce it. This is an
684+
// optimization, not needed to preserve correctness.
685+
if (VLOp.isImm() && VLOp.getImm() == 1) {
686+
LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
664687
return false;
688+
}
665689

666690
// Some instructions that produce vectors have semantics that make it more
667691
// difficult to determine whether the VL can be reduced. For example, some
@@ -684,7 +708,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
684708
return true;
685709
}
686710

687-
bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
711+
bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
688712
MachineInstr &MI) {
689713
// FIXME: Avoid visiting each user for each time we visit something on the
690714
// worklist, combined with an extra visit from the outer loop. Restructure
@@ -730,16 +754,17 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
730754

731755
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
732756
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
733-
// Looking for a register VL that isn't X0.
734-
if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) {
735-
LLVM_DEBUG(dbgs() << " Abort due to user uses X0 as VL.\n");
736-
CanReduceVL = false;
737-
break;
738-
}
757+
758+
// Looking for an immediate or a register VL that isn't X0.
759+
assert(!VLOp.isReg() ||
760+
VLOp.getReg() != RISCV::X0 && "Did not expect X0 VL");
739761

740762
if (!CommonVL) {
741-
CommonVL = VLOp.getReg();
742-
} else if (*CommonVL != VLOp.getReg()) {
763+
CommonVL = &VLOp;
764+
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
765+
} else if (!CommonVL->isIdenticalTo(VLOp)) {
766+
// FIXME: This check requires all users to have the same VL. We can relax
767+
// this and get the largest VL amongst all users.
743768
LLVM_DEBUG(dbgs() << " Abort because users have different VL\n");
744769
CanReduceVL = false;
745770
break;
@@ -776,29 +801,42 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
776801
MachineInstr &MI = *Worklist.pop_back_val();
777802
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
778803

779-
std::optional<Register> CommonVL;
804+
const MachineOperand *CommonVL = nullptr;
780805
bool CanReduceVL = true;
781806
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
782807
CanReduceVL = checkUsers(CommonVL, MI);
783808

784809
if (!CanReduceVL || !CommonVL)
785810
continue;
786811

787-
if (!CommonVL->isVirtual()) {
788-
LLVM_DEBUG(
789-
dbgs() << " Abort due to new VL is not virtual register.\n");
812+
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
813+
"Expected VL to be an Imm or virtual Reg");
814+
815+
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
816+
MachineOperand &VLOp = MI.getOperand(VLOpNum);
817+
818+
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
819+
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
790820
continue;
791821
}
792822

793-
const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
794-
if (!MDT->dominates(VLMI, &MI))
795-
continue;
823+
if (CommonVL->isImm()) {
824+
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
825+
<< CommonVL->getImm() << " for " << MI << "\n");
826+
VLOp.ChangeToImmediate(CommonVL->getImm());
827+
} else {
828+
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
829+
if (!MDT->dominates(VLMI, &MI))
830+
continue;
831+
LLVM_DEBUG(
832+
dbgs() << " Reduce VL from " << VLOp << " to "
833+
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
834+
<< " for " << MI << "\n");
835+
836+
// All our checks passed. We can reduce VL.
837+
VLOp.ChangeToRegister(CommonVL->getReg(), false);
838+
}
796839

797-
// All our checks passed. We can reduce VL.
798-
LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n");
799-
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
800-
MachineOperand &VLOp = MI.getOperand(VLOpNum);
801-
VLOp.ChangeToRegister(*CommonVL, false);
802840
MadeChange = true;
803841

804842
// Now add all inputs to this instruction to the worklist.

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0;
8686
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
8787
false)
8888

89-
/// Given two VL operands, do we know that LHS <= RHS?
90-
static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
91-
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
92-
LHS.getReg() == RHS.getReg())
93-
return true;
94-
if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
95-
return true;
96-
if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
97-
return false;
98-
if (!LHS.isImm() || !RHS.isImm())
99-
return false;
100-
return LHS.getImm() <= RHS.getImm();
101-
}
102-
10389
/// Given \p User that has an input operand with EEW=SEW, which uses the dest
10490
/// operand of \p Src with an unknown EEW, return true if their EEWs match.
10591
bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
@@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
191177
return false;
192178

193179
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
194-
if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
180+
if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
195181
return false;
196182

197183
if (!ensureDominates(VL, *Src))
@@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
580566
MachineOperand &SrcPolicy =
581567
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));
582568

583-
if (isVLKnownLE(MIVL, SrcVL))
569+
if (RISCV::isVLKnownLE(MIVL, SrcVL))
584570
SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
585571
}
586572

@@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
631617
// so we don't need to handle a smaller source VL here. However, the
632618
// user's VL may be larger
633619
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
634-
if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
620+
if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
635621
return false;
636622

637623
// If the new passthru doesn't dominate Src, try to move Src so it does.
@@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
650636
// If MI was tail agnostic and the VL didn't increase, preserve it.
651637
int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
652638
if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
653-
isVLKnownLE(MI.getOperand(3), SrcVL))
639+
RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
654640
Policy |= RISCVII::TAIL_AGNOSTIC;
655641
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);
656642

llvm/test/CodeGen/RISCV/rvv/vl-opt.ll

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,46 @@
1111
declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
1212

1313
define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
14-
; CHECK-LABEL: different_imm_vl_with_ta:
15-
; CHECK: # %bb.0:
16-
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
17-
; CHECK-NEXT: vadd.vv v8, v10, v12
18-
; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
19-
; CHECK-NEXT: vadd.vv v8, v8, v10
20-
; CHECK-NEXT: ret
14+
; NOVLOPT-LABEL: different_imm_vl_with_ta:
15+
; NOVLOPT: # %bb.0:
16+
; NOVLOPT-NEXT: vsetivli zero, 5, e32, m2, ta, ma
17+
; NOVLOPT-NEXT: vadd.vv v8, v10, v12
18+
; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
19+
; NOVLOPT-NEXT: vadd.vv v8, v8, v10
20+
; NOVLOPT-NEXT: ret
21+
;
22+
; VLOPT-LABEL: different_imm_vl_with_ta:
23+
; VLOPT: # %bb.0:
24+
; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
25+
; VLOPT-NEXT: vadd.vv v8, v10, v12
26+
; VLOPT-NEXT: vadd.vv v8, v8, v10
27+
; VLOPT-NEXT: ret
2128
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5)
2229
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
2330
ret <vscale x 4 x i32> %w
2431
}
2532

26-
; No benificial to propagate VL since VL is larger in the use side.
33+
define <vscale x 4 x i32> @vlmax_and_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
34+
; NOVLOPT-LABEL: vlmax_and_imm_vl_with_ta:
35+
; NOVLOPT: # %bb.0:
36+
; NOVLOPT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
37+
; NOVLOPT-NEXT: vadd.vv v8, v10, v12
38+
; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
39+
; NOVLOPT-NEXT: vadd.vv v8, v8, v10
40+
; NOVLOPT-NEXT: ret
41+
;
42+
; VLOPT-LABEL: vlmax_and_imm_vl_with_ta:
43+
; VLOPT: # %bb.0:
44+
; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
45+
; VLOPT-NEXT: vadd.vv v8, v10, v12
46+
; VLOPT-NEXT: vadd.vv v8, v8, v10
47+
; VLOPT-NEXT: ret
48+
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen -1)
49+
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
50+
ret <vscale x 4 x i32> %w
51+
}
52+
53+
; Not beneficial to propagate VL since VL is larger in the use side.
2754
define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
2855
; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
2956
; CHECK: # %bb.0:
@@ -50,8 +77,7 @@ define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %pass
5077
ret <vscale x 4 x i32> %w
5178
}
5279

53-
54-
; No benificial to propagate VL since VL is already one.
80+
; Not beneficial to propagate VL since VL is already one.
5581
define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
5682
; CHECK-LABEL: different_imm_vl_with_ta_1:
5783
; CHECK: # %bb.0:
@@ -110,7 +136,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
110136
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
111137
ret <vscale x 4 x i32> %w
112138
}
113-
114-
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
115-
; NOVLOPT: {{.*}}
116-
; VLOPT: {{.*}}

0 commit comments

Comments
 (0)