Skip to content

Commit 0f97b48

Browse files
authored
[Scalarizer][DirectX] Add support for scalarization of Target intrinsics (#108776)
Since we are using the Scalarizer pass in the backend, we needed a way to allow this pass to operate on target intrinsics. We achieved this by adding `TargetTransformInfo` to the Scalarizer pass. This allows the pass to call a function provided by the DirectX backend to determine whether an intrinsic is a target intrinsic that should be scalarized.
1 parent f36580f commit 0f97b48

File tree

11 files changed

+152
-41
lines changed

11 files changed

+152
-41
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,8 @@ class TargetTransformInfo {
882882
/// should use coldcc calling convention.
883883
bool useColdCCForColdCall(Function &F) const;
884884

885+
/// Return true if the target reports the target-specific intrinsic \p ID as
/// trivially scalarizable. The base implementations return false; targets
/// opt in by overriding the hook in their TTI implementation.
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
886+
885887
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
886888
/// are set if the demanded result elements need to be inserted and/or
887889
/// extracted from vectors.
@@ -1928,6 +1930,7 @@ class TargetTransformInfo::Concept {
19281930
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
19291931
virtual bool shouldBuildRelLookupTables() = 0;
19301932
virtual bool useColdCCForColdCall(Function &F) = 0;
1933+
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
19311934
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19321935
const APInt &DemandedElts,
19331936
bool Insert, bool Extract,
@@ -2467,7 +2470,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
24672470
bool useColdCCForColdCall(Function &F) override {
24682471
return Impl.useColdCCForColdCall(F);
24692472
}
2470-
2473+
// Forwards the query unchanged to the wrapped implementation object (Impl).
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) override {
2474+
return Impl.isTargetIntrinsicTriviallyScalarizable(ID);
2475+
}
24712476
InstructionCost getScalarizationOverhead(VectorType *Ty,
24722477
const APInt &DemandedElts,
24732478
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ class TargetTransformInfoImplBase {
373373

374374
bool useColdCCForColdCall(Function &F) const { return false; }
375375

376+
/// Default answer for the hook: no intrinsic \p ID is reported as a
/// trivially scalarizable target intrinsic.
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const {
377+
return false;
378+
}
379+
376380
InstructionCost getScalarizationOverhead(VectorType *Ty,
377381
const APInt &DemandedElts,
378382
bool Insert, bool Extract,

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
789789
return Cost;
790790
}
791791

792+
/// Default for targets built on BasicTTIImplBase: always returns false, so
/// no target intrinsic \p ID is scalarized unless a target provides its own
/// definition of this method.
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const {
793+
return false;
794+
}
795+
792796
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
793797
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
794798
bool Extract,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,11 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
587587
return TTIImpl->useColdCCForColdCall(F);
588588
}
589589

590+
// Public entry point for the hook: delegates to the target implementation
// held in TTIImpl and returns its answer for intrinsic ID unchanged.
bool TargetTransformInfo::isTargetIntrinsicTriviallyScalarizable(
591+
Intrinsic::ID ID) const {
592+
return TTIImpl->isTargetIntrinsicTriviallyScalarizable(ID);
593+
}
594+
590595
InstructionCost TargetTransformInfo::getScalarizationOverhead(
591596
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
592597
TTI::TargetCostKind CostKind) const {

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_llvm_target(DirectXCodeGen
1818
DirectXRegisterInfo.cpp
1919
DirectXSubtarget.cpp
2020
DirectXTargetMachine.cpp
21+
DirectXTargetTransformInfo.cpp
2122
DXContainerGlobals.cpp
2223
DXILFinalizeLinkage.cpp
2324
DXILIntrinsicExpansion.cpp
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "DirectXTargetTransformInfo.h"
13+
#include "llvm/IR/Intrinsics.h"
14+
#include "llvm/IR/IntrinsicsDirectX.h"
15+
16+
/// DirectX implementation of the TTI hook that tells callers which DirectX
/// intrinsics are trivially scalarizable.
///
/// Returns true only for Intrinsic::dx_frac and Intrinsic::dx_rsqrt; every
/// other intrinsic \p ID falls through to the default case and returns false.
/// Further intrinsics are enabled by adding a case label to the switch.
bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
17+
Intrinsic::ID ID) const {
18+
switch (ID) {
19+
case Intrinsic::dx_frac:
20+
case Intrinsic::dx_rsqrt:
21+
return true;
22+
default:
23+
return false;
24+
}
25+
}

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3434
: BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)),
3535
TLI(ST->getTargetLowering()) {}
3636
unsigned getMinVectorRegisterBitWidth() const { return 32; }
37+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
3738
};
3839
} // namespace llvm
3940

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/PostOrderIterator.h"
1919
#include "llvm/ADT/SmallVector.h"
2020
#include "llvm/ADT/Twine.h"
21+
#include "llvm/Analysis/TargetTransformInfo.h"
2122
#include "llvm/Analysis/VectorUtils.h"
2223
#include "llvm/IR/Argument.h"
2324
#include "llvm/IR/BasicBlock.h"
@@ -281,17 +282,20 @@ T getWithDefaultOverride(const cl::opt<T> &ClOption,
281282

282283
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
283284
public:
284-
ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options)
285-
: DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride(
286-
ClScalarizeVariableInsertExtract,
287-
Options.ScalarizeVariableInsertExtract)),
285+
ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
286+
ScalarizerPassOptions Options)
287+
: DT(DT), TTI(TTI), ScalarizeVariableInsertExtract(getWithDefaultOverride(
288+
ClScalarizeVariableInsertExtract,
289+
Options.ScalarizeVariableInsertExtract)),
288290
ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
289291
Options.ScalarizeLoadStore)),
290292
ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
291293
Options.ScalarizeMinBits)) {}
292294

293295
bool visit(Function &F);
294296

297+
bool isTriviallyScalarizable(Intrinsic::ID ID);
298+
295299
// InstVisitor methods. They return true if the instruction was scalarized,
296300
// false if nothing changed.
297301
bool visitInstruction(Instruction &I) { return false; }
@@ -335,6 +339,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
335339
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
336340

337341
DominatorTree *DT;
342+
const TargetTransformInfo *TTI;
338343

339344
const bool ScalarizeVariableInsertExtract;
340345
const bool ScalarizeLoadStore;
@@ -358,6 +363,7 @@ ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
358363

359364
void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
360365
AU.addRequired<DominatorTreeWrapperPass>();
366+
AU.addRequired<TargetTransformInfoWrapperPass>();
361367
AU.addPreserved<DominatorTreeWrapperPass>();
362368
}
363369

@@ -445,7 +451,9 @@ bool ScalarizerLegacyPass::runOnFunction(Function &F) {
445451
return false;
446452

447453
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
448-
ScalarizerVisitor Impl(DT, Options);
454+
const TargetTransformInfo *TTI =
455+
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
456+
ScalarizerVisitor Impl(DT, TTI, Options);
449457
return Impl.visit(F);
450458
}
451459

@@ -689,8 +697,11 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
689697
return true;
690698
}
691699

692-
static bool isTriviallyScalariable(Intrinsic::ID ID) {
693-
return isTriviallyVectorizable(ID);
700+
/// Decide whether calls to intrinsic \p ID are treated as trivially
/// scalarizable by this visitor.
///
/// Returns true when isTriviallyVectorizable(ID) holds. Otherwise the answer
/// comes from the target: \p ID must be a target intrinsic
/// (Function::isTargetIntrinsic) and the TTI hook
/// isTargetIntrinsicTriviallyScalarizable must accept it. The `&&`
/// short-circuits, so TTI is only consulted for target intrinsics.
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
701+
if (isTriviallyVectorizable(ID))
702+
return true;
703+
return Function::isTargetIntrinsic(ID) &&
704+
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
694705
}
695706

696707
/// If a call to a vector typed intrinsic function, split into a scalar call per
@@ -705,7 +716,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
705716
return false;
706717

707718
Intrinsic::ID ID = F->getIntrinsicID();
708-
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
719+
720+
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
709721
return false;
710722

711723
// unsigned NumElems = VT->getNumElements();
@@ -1249,7 +1261,8 @@ bool ScalarizerVisitor::finish() {
12491261

12501262
PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
12511263
DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1252-
ScalarizerVisitor Impl(DT, Options);
1264+
const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
1265+
ScalarizerVisitor Impl(DT, TTI, Options);
12531266
bool Changed = Impl.visit(F);
12541267
PreservedAnalyses PA;
12551268
PA.preserve<DominatorTreeAnalysis>();

llvm/test/CodeGen/DirectX/frac.ll

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,55 @@
1-
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
23

34
; Make sure dxil operation function calls for frac are generated for float and half.
4-
; CHECK:call float @dx.op.unary.f32(i32 22, float %{{.*}})
5-
; CHECK:call half @dx.op.unary.f16(i32 22, half %{{.*}})
65

7-
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
8-
target triple = "dxil-pc-shadermodel6.7-library"
6+
define noundef half @frac_half(half noundef %a) {
7+
; CHECK-LABEL: define noundef half @frac_half(
8+
; CHECK-SAME: half noundef [[A:%.*]]) {
9+
; CHECK-NEXT: [[ENTRY:.*:]]
10+
; CHECK-NEXT: [[DX_FRAC1:%.*]] = call half @dx.op.unary.f16(i32 22, half [[A]])
11+
; CHECK-NEXT: ret half [[DX_FRAC1]]
12+
;
13+
entry:
14+
%dx.frac = call half @llvm.dx.frac.f16(half %a)
15+
ret half %dx.frac
16+
}
917

10-
; Function Attrs: noinline nounwind optnone
1118
define noundef float @frac_float(float noundef %a) #0 {
19+
; CHECK-LABEL: define noundef float @frac_float(
20+
; CHECK-SAME: float noundef [[A:%.*]]) {
21+
; CHECK-NEXT: [[ENTRY:.*:]]
22+
; CHECK-NEXT: [[DX_FRAC1:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A]])
23+
; CHECK-NEXT: ret float [[DX_FRAC1]]
24+
;
1225
entry:
13-
%a.addr = alloca float, align 4
14-
store float %a, ptr %a.addr, align 4
15-
%0 = load float, ptr %a.addr, align 4
16-
%dx.frac = call float @llvm.dx.frac.f32(float %0)
26+
%dx.frac = call float @llvm.dx.frac.f32(float %a)
1727
ret float %dx.frac
1828
}
1929

20-
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
21-
declare float @llvm.dx.frac.f32(float) #1
22-
23-
; Function Attrs: noinline nounwind optnone
24-
define noundef half @frac_half(half noundef %a) #0 {
30+
define noundef <4 x float> @frac_float4(<4 x float> noundef %a) #0 {
31+
; CHECK-LABEL: define noundef <4 x float> @frac_float4(
32+
; CHECK-SAME: <4 x float> noundef [[A:%.*]]) {
33+
; CHECK-NEXT: [[ENTRY:.*:]]
34+
; CHECK-NEXT: [[A_I0:%.*]] = extractelement <4 x float> [[A]], i64 0
35+
; CHECK-NEXT: [[DOTI04:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I0]])
36+
; CHECK-NEXT: [[A_I1:%.*]] = extractelement <4 x float> [[A]], i64 1
37+
; CHECK-NEXT: [[DOTI13:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I1]])
38+
; CHECK-NEXT: [[A_I2:%.*]] = extractelement <4 x float> [[A]], i64 2
39+
; CHECK-NEXT: [[DOTI22:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I2]])
40+
; CHECK-NEXT: [[A_I3:%.*]] = extractelement <4 x float> [[A]], i64 3
41+
; CHECK-NEXT: [[DOTI31:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I3]])
42+
; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x float> poison, float [[DOTI04]], i64 0
43+
; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x float> [[DOTUPTO0]], float [[DOTI13]], i64 1
44+
; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x float> [[DOTUPTO1]], float [[DOTI22]], i64 2
45+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> [[DOTUPTO2]], float [[DOTI31]], i64 3
46+
; CHECK-NEXT: ret <4 x float> [[TMP0]]
47+
;
2548
entry:
26-
%a.addr = alloca half, align 2
27-
store half %a, ptr %a.addr, align 2
28-
%0 = load half, ptr %a.addr, align 2
29-
%dx.frac = call half @llvm.dx.frac.f16(half %0)
30-
ret half %dx.frac
49+
%2 = call <4 x float> @llvm.dx.frac.v4f32(<4 x float> %a)
50+
ret <4 x float> %2
3151
}
52+
53+
declare half @llvm.dx.frac.f16(half)
54+
declare float @llvm.dx.frac.f32(float)
55+
declare <4 x float> @llvm.dx.frac.v4f32(<4 x float>)

llvm/test/CodeGen/DirectX/llc-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
; CHECK-LABEL: Pass Arguments:
77
; CHECK-NEXT: Target Library Information
8+
; CHECK-NEXT: Target Transform Information
89
; CHECK-NEXT: ModulePass Manager
910
; CHECK-NEXT: DXIL Intrinsic Expansion
1011
; CHECK-NEXT: FunctionPass Manager

llvm/test/CodeGen/DirectX/rsqrt.ll

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,56 @@
1-
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
23

34
; Make sure dxil operation function calls for rsqrt are generated for float and half.
45

56
; CHECK-LABEL: rsqrt_float
6-
; CHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
77
define noundef float @rsqrt_float(float noundef %a) {
8+
; CHECK-SAME: float noundef [[A:%.*]]) {
9+
; CHECK-NEXT: [[ENTRY:.*:]]
10+
; CHECK-NEXT: [[DX_RSQRT1:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A]])
11+
; CHECK-NEXT: ret float [[DX_RSQRT1]]
12+
;
813
entry:
9-
%a.addr = alloca float, align 4
10-
store float %a, ptr %a.addr, align 4
11-
%0 = load float, ptr %a.addr, align 4
12-
%dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
14+
%dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %a)
1315
ret float %dx.rsqrt
1416
}
1517

1618
; CHECK-LABEL: rsqrt_half
17-
; CHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
1819
define noundef half @rsqrt_half(half noundef %a) {
20+
; CHECK-SAME: half noundef [[A:%.*]]) {
21+
; CHECK-NEXT: [[ENTRY:.*:]]
22+
; CHECK-NEXT: [[DX_RSQRT1:%.*]] = call half @dx.op.unary.f16(i32 25, half [[A]])
23+
; CHECK-NEXT: ret half [[DX_RSQRT1]]
24+
;
1925
entry:
20-
%a.addr = alloca half, align 2
21-
store half %a, ptr %a.addr, align 2
22-
%0 = load half, ptr %a.addr, align 2
23-
%dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
26+
%dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %a)
2427
ret half %dx.rsqrt
2528
}
2629

30+
define noundef <4 x float> @rsqrt_float4(<4 x float> noundef %a) #0 {
31+
; CHECK-LABEL: define noundef <4 x float> @rsqrt_float4(
32+
; CHECK-SAME: <4 x float> noundef [[A:%.*]]) {
33+
; CHECK-NEXT: [[ENTRY:.*:]]
34+
; CHECK-NEXT: [[A_I0:%.*]] = extractelement <4 x float> [[A]], i64 0
35+
; CHECK-NEXT: [[DOTI04:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I0]])
36+
; CHECK-NEXT: [[A_I1:%.*]] = extractelement <4 x float> [[A]], i64 1
37+
; CHECK-NEXT: [[DOTI13:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I1]])
38+
; CHECK-NEXT: [[A_I2:%.*]] = extractelement <4 x float> [[A]], i64 2
39+
; CHECK-NEXT: [[DOTI22:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I2]])
40+
; CHECK-NEXT: [[A_I3:%.*]] = extractelement <4 x float> [[A]], i64 3
41+
; CHECK-NEXT: [[DOTI31:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I3]])
42+
; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x float> poison, float [[DOTI04]], i64 0
43+
; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x float> [[DOTUPTO0]], float [[DOTI13]], i64 1
44+
; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x float> [[DOTUPTO1]], float [[DOTI22]], i64 2
45+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> [[DOTUPTO2]], float [[DOTI31]], i64 3
46+
; CHECK-NEXT: ret <4 x float> [[TMP0]]
47+
;
48+
entry:
49+
%2 = call <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float> %a)
50+
ret <4 x float> %2
51+
}
52+
53+
2754
declare half @llvm.dx.rsqrt.f16(half)
2855
declare float @llvm.dx.rsqrt.f32(float)
56+
declare <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float>)

0 commit comments

Comments
 (0)