Skip to content

Commit dab1f7c

Browse files
authored
AMDGPU: Emit 1/llvm.sqrt(x) instead of rsqrt calls in libcall handling (#92863)
With the contract flag we should end up codegening to the rsqrt instruction, or denormal corrected rsqrt sequence present in the library.
1 parent 67e3514 commit dab1f7c

File tree

3 files changed

+56
-29
lines changed

3 files changed

+56
-29
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,16 +1215,36 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
12151215
"__rootn2div");
12161216
replaceCall(FPOp, nval);
12171217
return true;
1218-
} else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
1219-
if (FunctionCallee FPExpr =
1220-
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
1221-
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
1222-
<< ")\n");
1223-
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1224-
replaceCall(FPOp, nval);
1225-
return true;
1226-
}
12271218
}
1219+
1220+
if (ci_opr1 == -2 &&
1221+
shouldReplaceLibcallWithIntrinsic(CI,
1222+
/*AllowMinSizeF32=*/true,
1223+
/*AllowF64=*/true)) {
1224+
// rootn(x, -2) = rsqrt(x)
1225+
1226+
// The original rootn had looser ulp requirements than the resultant sqrt
1227+
// and fdiv.
1228+
MDBuilder MDHelper(M->getContext());
1229+
MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
1230+
1231+
// TODO: Could handle strictfp but need to fix strict sqrt emission
1232+
FastMathFlags FMF = FPOp->getFastMathFlags();
1233+
FMF.setAllowContract(true);
1234+
1235+
CallInst *Sqrt = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
1236+
Instruction *RSqrt = cast<Instruction>(
1237+
B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0), Sqrt));
1238+
Sqrt->setFastMathFlags(FMF);
1239+
RSqrt->setFastMathFlags(FMF);
1240+
RSqrt->setMetadata(LLVMContext::MD_fpmath, FPMD);
1241+
1242+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
1243+
<< ")\n");
1244+
replaceCall(CI, RSqrt);
1245+
return true;
1246+
}
1247+
12281248
return false;
12291249
}
12301250

llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ define half @test_rootn_f16_neg1(half %x) {
302302
define half @test_rootn_f16_neg2(half %x) {
303303
; CHECK-LABEL: define half @test_rootn_f16_neg2(
304304
; CHECK-SAME: half [[X:%.*]]) {
305-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call half @_Z5rsqrtDh(half [[X]])
305+
; CHECK-NEXT: [[TMP1:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
306+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract half 0xH3C00, [[TMP1]], !fpmath [[META0]]
306307
; CHECK-NEXT: ret half [[__ROOTN2RSQRT]]
307308
;
308309
%call = tail call half @_Z5rootnDhi(half %x, i32 -2)
@@ -371,7 +372,8 @@ define <2 x half> @test_rootn_v2f16_neg1(<2 x half> %x) {
371372
define <2 x half> @test_rootn_v2f16_neg2(<2 x half> %x) {
372373
; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_neg2(
373374
; CHECK-SAME: <2 x half> [[X:%.*]]) {
374-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x half> @_Z5rsqrtDv2_Dh(<2 x half> [[X]])
375+
; CHECK-NEXT: [[TMP1:%.*]] = call contract <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]])
376+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x half> <half 0xH3C00, half 0xH3C00>, [[TMP1]], !fpmath [[META0]]
375377
; CHECK-NEXT: ret <2 x half> [[__ROOTN2RSQRT]]
376378
;
377379
%call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 -2, i32 -2>)
@@ -865,7 +867,8 @@ define float @test_rootn_f32__y_neg2(float %x) {
865867
; CHECK-LABEL: define float @test_rootn_f32__y_neg2(
866868
; CHECK-SAME: float [[X:%.*]]) {
867869
; CHECK-NEXT: entry:
868-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
870+
; CHECK-NEXT: [[TMP0:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
871+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
869872
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
870873
;
871874
entry:
@@ -877,7 +880,8 @@ define float @test_rootn_f32__y_neg2__flags(float %x) {
877880
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__flags(
878881
; CHECK-SAME: float [[X:%.*]]) {
879882
; CHECK-NEXT: entry:
880-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call nnan nsz float @_Z5rsqrtf(float [[X]])
883+
; CHECK-NEXT: [[TMP0:%.*]] = call nnan nsz contract float @llvm.sqrt.f32(float [[X]])
884+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract float 1.000000e+00, [[TMP0]], !fpmath [[META0]]
881885
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
882886
;
883887
entry:
@@ -889,7 +893,7 @@ define float @test_rootn_f32__y_neg2__strictfp(float %x) #1 {
889893
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__strictfp(
890894
; CHECK-SAME: float [[X:%.*]]) #[[ATTR0]] {
891895
; CHECK-NEXT: entry:
892-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]]) #[[ATTR0]]
896+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR0]]
893897
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
894898
;
895899
entry:
@@ -901,7 +905,7 @@ define float @test_rootn_f32__y_neg2__noinline(float %x) {
901905
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__noinline(
902906
; CHECK-SAME: float [[X:%.*]]) {
903907
; CHECK-NEXT: entry:
904-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call float @_Z5rsqrtf(float [[X]])
908+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
905909
; CHECK-NEXT: ret float [[__ROOTN2RSQRT]]
906910
;
907911
entry:
@@ -913,7 +917,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
913917
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
914918
; CHECK-SAME: float [[X:%.*]]) {
915919
; CHECK-NEXT: entry:
916-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
920+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4:[0-9]+]]
917921
; CHECK-NEXT: ret float [[CALL]]
918922
;
919923
entry:
@@ -925,7 +929,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2(<2 x float> %x) {
925929
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2(
926930
; CHECK-SAME: <2 x float> [[X:%.*]]) {
927931
; CHECK-NEXT: entry:
928-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
932+
; CHECK-NEXT: [[TMP0:%.*]] = call contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
933+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
929934
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
930935
;
931936
entry:
@@ -937,7 +942,8 @@ define <2 x float> @test_rootn_v2f32__y_neg2__flags(<2 x float> %x) {
937942
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__flags(
938943
; CHECK-SAME: <2 x float> [[X:%.*]]) {
939944
; CHECK-NEXT: entry:
940-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call nnan nsz <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]])
945+
; CHECK-NEXT: [[TMP0:%.*]] = call nnan nsz contract <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
946+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = fdiv nnan nsz contract <2 x float> <float 1.000000e+00, float 1.000000e+00>, [[TMP0]], !fpmath [[META0]]
941947
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
942948
;
943949
entry:
@@ -949,7 +955,7 @@ define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(<2 x float> %x) #1 {
949955
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_neg2__strictfp(
950956
; CHECK-SAME: <2 x float> [[X:%.*]]) #[[ATTR0]] {
951957
; CHECK-NEXT: entry:
952-
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = call <2 x float> @_Z5rsqrtDv2_f(<2 x float> [[X]]) #[[ATTR0]]
958+
; CHECK-NEXT: [[__ROOTN2RSQRT:%.*]] = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> [[X]], <2 x i32> <i32 -2, i32 -2>) #[[ATTR0]]
953959
; CHECK-NEXT: ret <2 x float> [[__ROOTN2RSQRT]]
954960
;
955961
entry:
@@ -1125,7 +1131,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
11251131
; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
11261132
; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
11271133
; CHECK-NEXT: entry:
1128-
; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
1134+
; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR4]]
11291135
; CHECK-NEXT: ret float [[CALL]]
11301136
;
11311137
entry:
@@ -1420,7 +1426,7 @@ entry:
14201426
define float @test_rootn_f32__y_0_nobuiltin(float %x) {
14211427
; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
14221428
; CHECK-SAME: float [[X:%.*]]) {
1423-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
1429+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR4]]
14241430
; CHECK-NEXT: ret float [[CALL]]
14251431
;
14261432
%call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1436,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
14301436
define float @test_rootn_f32__y_1_nobuiltin(float %x) {
14311437
; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
14321438
; CHECK-SAME: float [[X:%.*]]) {
1433-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
1439+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR4]]
14341440
; CHECK-NEXT: ret float [[CALL]]
14351441
;
14361442
%call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1446,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
14401446
define float @test_rootn_f32__y_2_nobuiltin(float %x) {
14411447
; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
14421448
; CHECK-SAME: float [[X:%.*]]) {
1443-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
1449+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR4]]
14441450
; CHECK-NEXT: ret float [[CALL]]
14451451
;
14461452
%call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1456,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
14501456
define float @test_rootn_f32__y_3_nobuiltin(float %x) {
14511457
; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
14521458
; CHECK-SAME: float [[X:%.*]]) {
1453-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
1459+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR4]]
14541460
; CHECK-NEXT: ret float [[CALL]]
14551461
;
14561462
%call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1466,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
14601466
define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
14611467
; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
14621468
; CHECK-SAME: float [[X:%.*]]) {
1463-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
1469+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR4]]
14641470
; CHECK-NEXT: ret float [[CALL]]
14651471
;
14661472
%call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1476,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
14701476
define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
14711477
; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
14721478
; CHECK-SAME: float [[X:%.*]]) {
1473-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
1479+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR4]]
14741480
; CHECK-NEXT: ret float [[CALL]]
14751481
;
14761482
%call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1487,7 +1493,8 @@ attributes #2 = { noinline }
14871493
; CHECK: attributes #[[ATTR0]] = { strictfp }
14881494
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
14891495
; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
1490-
; CHECK: attributes #[[ATTR3]] = { nobuiltin }
1496+
; CHECK: attributes #[[ATTR3]] = { noinline }
1497+
; CHECK: attributes #[[ATTR4]] = { nobuiltin }
14911498
;.
14921499
; CHECK: [[META0]] = !{float 2.000000e+00}
14931500
; CHECK: [[META1]] = !{float 3.000000e+00}

llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,8 @@ entry:
506506
}
507507

508508
; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_m2
509-
; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 -2)
510-
; GCN-PRELINK: %__rootn2rsqrt = tail call fast float @_Z5rsqrtf(float %tmp)
509+
; GCN: [[SQRT:%.+]] = tail call fast float @llvm.sqrt.f32(float %tmp)
510+
; GCN-NEXT: fdiv fast float 1.000000e+00, [[SQRT]]
511511
define amdgpu_kernel void @test_rootn_m2(ptr addrspace(1) nocapture %a) {
512512
entry:
513513
%tmp = load float, ptr addrspace(1) %a, align 4

0 commit comments

Comments
 (0)