Skip to content

[NVPTX] Improve folding to mad with immediate 1 #93628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 87 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5614,17 +5614,98 @@ static SDValue TryMULWIDECombine(SDNode *N,
return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
}

static SDValue matchMADConstOnePattern(SDValue X, SDValue Add) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

X is unused.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

if (Add->getOpcode() != ISD::ADD)
return SDValue();

SDValue Y = Add->getOperand(0);
ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Add->getOperand(1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we guaranteed to have const operand to be last? I think we normalize them, but I'm not 100% sure it's always the case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I've added the other case as well just in case.

if (!Const || Const->getZExtValue() != 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit. Phrasing the condition in positive terms would be more readable, IMO.
if (Const && Const->getZExtValue() == 1) return Y;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

return SDValue();

return Y;
}

static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
TargetLowering::DAGCombinerInfo &DCI) {

if (SDValue Y = matchMADConstOnePattern(X, Add))
return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);

return SDValue();
}

static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
SDLoc DL,
TargetLowering::DAGCombinerInfo &DCI) {
if (Select->getOpcode() != ISD::SELECT)
return SDValue();

SDValue Cond = Select->getOperand(0);

unsigned ConstOpNo = 1;
auto *Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
if (!Const || Const->getZExtValue() != 1) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we could extract the common pattern into a helper function:

bool isConstOne(Operand) {
  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
  return Const && Const->getZExtValue() == 1;
}

and then use it in handful of instances of this pattern throughout the code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

ConstOpNo = 2;
Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
if (!Const || Const->getZExtValue() != 1)
return SDValue();
}

SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);

// Do not combine if the resulting sequence is not obviously profitable.
if (!matchMADConstOnePattern(X, Y))
return SDValue();

SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);

return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
(ConstOpNo == 1) ? X : NewMul,
(ConstOpNo == 1) ? NewMul : X);
}

static SDValue
PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
TargetLowering::DAGCombinerInfo &DCI) {

EVT VT = N0.getValueType();
if (VT.isVector())
return SDValue();

if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
return SDValue();

SDLoc DL(N);

// (mul x, (add y, 1)) -> (mad x, y, x)
if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
return Res;
if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
return Res;

// (mul x, (select y, 1)) -> (select (mul x, y), x)
if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
return Res;
if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
return Res;

return SDValue();
}

/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
static SDValue PerformMULCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
if (OptLevel > CodeGenOptLevel::None) {
// Try mul.wide combining at OptLevel > 0
if (SDValue Ret = TryMULWIDECombine(N, DCI))
return Ret;
}
if (OptLevel == CodeGenOptLevel::None)
return SDValue();

return SDValue();
if (SDValue Ret = TryMULWIDECombine(N, DCI))
return Ret;

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
return PerformMULCombineWithOperands(N, N0, N1, DCI);
}

/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
Expand Down
101 changes: 101 additions & 0 deletions llvm/test/CodeGen/NVPTX/combine-mad.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -O1 | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another test which could use autogenerated CHECK patterns.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}

define i32 @test1(i32 %n, i32 %m) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test1_param_0];
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test1_param_1];
; CHECK: mad.lo.s32 %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
; CHECK: st.param.b32 [func_retval0+0], %[[MAD]];
;
%add = add i32 %n, 1
%mul = mul i32 %add, %m
ret i32 %mul
}

define i32 @test1_rev(i32 %n, i32 %m) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test1_rev_param_0];
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test1_rev_param_1];
; CHECK: mad.lo.s32 %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
; CHECK: st.param.b32 [func_retval0+0], %[[MAD]];
;
%add = add i32 %n, 1
%mul = mul i32 %m, %add
ret i32 %mul
}

; Transpose (mul (select)) if it can then be folded to mad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it buy us anything?

mul(m,select(1,n)) will probably have the same performance as select(mul(m,n), m) as the critical path will always have mul and select, just in different order.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By itself this transform doesn't help much, I agree. However, if m or n are add(x,1) then it enables the other transformation. In the code we're checking for this case and only running the transformation when it would enable further folding. A rare case to be sure, but better to support it than not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kind of optimization is not target-specific and should probably be done somewhere in instcombine. Perhaps move the optimization of mul(m,select(1,n)) there as a separate patch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instcombine already canonicalizes in the opposite direction, select(mul(m,n), m) -> mul(m,select(1,n)). I think this is target specific because it is only worth doing to improve mad folding.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

define i32 @test2(i32 %n, i32 %m, i32 %s) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test2_param_0];
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test2_param_1];
; CHECK: ld.param.u32 %[[S:r[0-9]+]], [test2_param_2];
; CHECK: setp.lt.s32 %[[COND:p[0-9]+]], %[[S]], 1;
; CHECK: mad.lo.s32 %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
; CHECK: selp.b32 %[[SEL:r[0-9]+]], %[[M]], %[[MAD]], %[[COND]];
; CHECK: st.param.b32 [func_retval0+0], %[[SEL]];
;
%add = add i32 %n, 1
%cond = icmp slt i32 %s, 1
%sel = select i1 %cond, i32 1, i32 %add
%mul = mul i32 %sel, %m
ret i32 %mul
}

;; Transpose (mul (select)) if it can then be folded to mad
define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test2_rev1_param_0];
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test2_rev1_param_1];
; CHECK: ld.param.u32 %[[S:r[0-9]+]], [test2_rev1_param_2];
; CHECK: setp.lt.s32 %[[COND:p[0-9]+]], %[[S]], 1;
; CHECK: mad.lo.s32 %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
; CHECK: selp.b32 %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
; CHECK: st.param.b32 [func_retval0+0], %[[SEL]];
;
%add = add i32 %n, 1
%cond = icmp slt i32 %s, 1
%sel = select i1 %cond, i32 %add, i32 1
%mul = mul i32 %sel, %m
ret i32 %mul
}

;; Transpose (mul (select)) if it can then be folded to mad
define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test2_rev2_param_0];
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test2_rev2_param_1];
; CHECK: ld.param.u32 %[[S:r[0-9]+]], [test2_rev2_param_2];
; CHECK: setp.lt.s32 %[[COND:p[0-9]+]], %[[S]], 1;
; CHECK: mad.lo.s32 %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
; CHECK: selp.b32 %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
; CHECK: st.param.b32 [func_retval0+0], %[[SEL]];
;
%add = add i32 %n, 1
%cond = icmp slt i32 %s, 1
%sel = select i1 %cond, i32 %add, i32 1
%mul = mul i32 %m, %sel
ret i32 %mul
}

;; Leave (mul (select)) intact if it transposing is not profitable
define i32 @test3(i32 %n, i32 %m, i32 %s) {
;
; CHECK: ld.param.u32 %[[N:r[0-9]+]], [test3_param_0];
; CHECK: add.s32 %[[ADD:r[0-9]+]], %[[N]], 3;
; CHECK: ld.param.u32 %[[M:r[0-9]+]], [test3_param_1];
; CHECK: ld.param.u32 %[[S:r[0-9]+]], [test3_param_2];
; CHECK: setp.lt.s32 %[[COND:p[0-9]+]], %[[S]], 1;
; CHECK: selp.b32 %[[SEL:r[0-9]+]], 1, %[[ADD]], %[[COND]];
; CHECK: mul.lo.s32 %[[MUL:r[0-9]+]], %[[SEL]], %[[M]];
; CHECK: st.param.b32 [func_retval0+0], %[[MUL]];
;
%add = add i32 %n, 3
%cond = icmp slt i32 %s, 1
%sel = select i1 %cond, i32 1, i32 %add
%mul = mul i32 %sel, %m
ret i32 %mul
}
Loading