Skip to content

Commit 0169946

Browse files
committed
[DirectX] Add support for scalarization of Target intrinsics
1 parent 8965795 commit 0169946

File tree

11 files changed

+123
-43
lines changed

11 files changed

+123
-43
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,8 @@ class TargetTransformInfo {
882882
/// should use coldcc calling convention.
883883
bool useColdCCForColdCall(Function &F) const;
884884

885+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
886+
885887
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
886888
/// are set if the demanded result elements need to be inserted and/or
887889
/// extracted from vectors.
@@ -1928,6 +1930,7 @@ class TargetTransformInfo::Concept {
19281930
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
19291931
virtual bool shouldBuildRelLookupTables() = 0;
19301932
virtual bool useColdCCForColdCall(Function &F) = 0;
1933+
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
19311934
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19321935
const APInt &DemandedElts,
19331936
bool Insert, bool Extract,
@@ -2467,7 +2470,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
24672470
bool useColdCCForColdCall(Function &F) override {
24682471
return Impl.useColdCCForColdCall(F);
24692472
}
2470-
2473+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) override {
2474+
return Impl.isTargetIntrinsicTriviallyScalarizable(ID);
2475+
}
24712476
InstructionCost getScalarizationOverhead(VectorType *Ty,
24722477
const APInt &DemandedElts,
24732478
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ class TargetTransformInfoImplBase {
373373

374374
bool useColdCCForColdCall(Function &F) const { return false; }
375375

376+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const {
377+
return false;
378+
}
379+
376380
InstructionCost getScalarizationOverhead(VectorType *Ty,
377381
const APInt &DemandedElts,
378382
bool Insert, bool Extract,

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
789789
return Cost;
790790
}
791791

792+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const {
793+
return false;
794+
}
795+
792796
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
793797
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
794798
bool Extract,

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def int_dx_udot :
7272
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
7373
[IntrNoMem, Commutative] >;
7474

75-
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
75+
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
7676

7777
def int_dx_isinf :
7878
DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
@@ -85,7 +85,7 @@ def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyf
8585
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
8686
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
8787
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
88-
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
88+
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
8989
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9090
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
9191
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>]>;

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,11 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
587587
return TTIImpl->useColdCCForColdCall(F);
588588
}
589589

590+
bool TargetTransformInfo::isTargetIntrinsicTriviallyScalarizable(
591+
Intrinsic::ID ID) const {
592+
return TTIImpl->isTargetIntrinsicTriviallyScalarizable(ID);
593+
}
594+
590595
InstructionCost TargetTransformInfo::getScalarizationOverhead(
591596
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
592597
TTI::TargetCostKind CostKind) const {

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_llvm_target(DirectXCodeGen
1818
DirectXRegisterInfo.cpp
1919
DirectXSubtarget.cpp
2020
DirectXTargetMachine.cpp
21+
DirectXTargetTransformInfo.cpp
2122
DXContainerGlobals.cpp
2223
DXILFinalizeLinkage.cpp
2324
DXILIntrinsicExpansion.cpp
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "DirectXTargetTransformInfo.h"
13+
#include "llvm/IR/Intrinsics.h"
14+
#include "llvm/IR/IntrinsicsDirectX.h"
15+
16+
bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
17+
Intrinsic::ID ID) const {
18+
switch (ID) {
19+
case Intrinsic::dx_frac:
20+
case Intrinsic::dx_rsqrt:
21+
return true;
22+
default:
23+
return false;
24+
}
25+
}

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3434
: BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)),
3535
TLI(ST->getTargetLowering()) {}
3636
unsigned getMinVectorRegisterBitWidth() const { return 32; }
37+
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
3738
};
3839
} // namespace llvm
3940

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/PostOrderIterator.h"
1919
#include "llvm/ADT/SmallVector.h"
2020
#include "llvm/ADT/Twine.h"
21+
#include "llvm/Analysis/TargetTransformInfo.h"
2122
#include "llvm/Analysis/VectorUtils.h"
2223
#include "llvm/IR/Argument.h"
2324
#include "llvm/IR/BasicBlock.h"
@@ -32,6 +33,7 @@
3233
#include "llvm/IR/Instruction.h"
3334
#include "llvm/IR/Instructions.h"
3435
#include "llvm/IR/Intrinsics.h"
36+
#include "llvm/IR/IntrinsicsDirectX.h"
3537
#include "llvm/IR/LLVMContext.h"
3638
#include "llvm/IR/Module.h"
3739
#include "llvm/IR/Type.h"
@@ -281,17 +283,20 @@ T getWithDefaultOverride(const cl::opt<T> &ClOption,
281283

282284
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
283285
public:
284-
ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options)
285-
: DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride(
286-
ClScalarizeVariableInsertExtract,
287-
Options.ScalarizeVariableInsertExtract)),
286+
ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
287+
ScalarizerPassOptions Options)
288+
: DT(DT), TTI(TTI), ScalarizeVariableInsertExtract(getWithDefaultOverride(
289+
ClScalarizeVariableInsertExtract,
290+
Options.ScalarizeVariableInsertExtract)),
288291
ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
289292
Options.ScalarizeLoadStore)),
290293
ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
291294
Options.ScalarizeMinBits)) {}
292295

293296
bool visit(Function &F);
294297

298+
bool isTriviallyScalarizable(Intrinsic::ID ID);
299+
295300
// InstVisitor methods. They return true if the instruction was scalarized,
296301
// false if nothing changed.
297302
bool visitInstruction(Instruction &I) { return false; }
@@ -335,6 +340,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
335340
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
336341

337342
DominatorTree *DT;
343+
const TargetTransformInfo *TTI;
338344

339345
const bool ScalarizeVariableInsertExtract;
340346
const bool ScalarizeLoadStore;
@@ -358,6 +364,7 @@ ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
358364

359365
void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
360366
AU.addRequired<DominatorTreeWrapperPass>();
367+
AU.addRequired<TargetTransformInfoWrapperPass>();
361368
AU.addPreserved<DominatorTreeWrapperPass>();
362369
}
363370

@@ -445,7 +452,9 @@ bool ScalarizerLegacyPass::runOnFunction(Function &F) {
445452
return false;
446453

447454
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
448-
ScalarizerVisitor Impl(DT, Options);
455+
const TargetTransformInfo *TTI =
456+
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
457+
ScalarizerVisitor Impl(DT, TTI, Options);
449458
return Impl.visit(F);
450459
}
451460

@@ -689,8 +698,10 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
689698
return true;
690699
}
691700

692-
static bool isTriviallyScalariable(Intrinsic::ID ID) {
693-
return isTriviallyVectorizable(ID);
701+
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
702+
703+
return TTI->isTargetIntrinsicTriviallyScalarizable(ID) ||
704+
isTriviallyVectorizable(ID);
694705
}
695706

696707
/// If a call to a vector typed intrinsic function, split into a scalar call per
@@ -705,7 +716,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
705716
return false;
706717

707718
Intrinsic::ID ID = F->getIntrinsicID();
708-
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
719+
720+
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
709721
return false;
710722

711723
// unsigned NumElems = VT->getNumElements();
@@ -1249,7 +1261,8 @@ bool ScalarizerVisitor::finish() {
12491261

12501262
PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
12511263
DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1252-
ScalarizerVisitor Impl(DT, Options);
1264+
const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
1265+
ScalarizerVisitor Impl(DT, TTI, Options);
12531266
bool Changed = Impl.visit(F);
12541267
PreservedAnalyses PA;
12551268
PA.preserve<DominatorTreeAnalysis>();

llvm/test/CodeGen/DirectX/frac.ll

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
1-
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
1+
; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
22

33
; Make sure dxil operation function calls for frac are generated for float and half.
4-
; CHECK:call float @dx.op.unary.f32(i32 22, float %{{.*}})
5-
; CHECK:call half @dx.op.unary.f16(i32 22, half %{{.*}})
64

7-
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
8-
target triple = "dxil-pc-shadermodel6.7-library"
5+
define noundef half @frac_half(half noundef %a) {
6+
entry:
7+
; CHECK:call half @dx.op.unary.f16(i32 22, half %{{.*}})
8+
%dx.frac = call half @llvm.dx.frac.f16(half %a)
9+
ret half %dx.frac
10+
}
911

10-
; Function Attrs: noinline nounwind optnone
1112
define noundef float @frac_float(float noundef %a) #0 {
1213
entry:
13-
%a.addr = alloca float, align 4
14-
store float %a, ptr %a.addr, align 4
15-
%0 = load float, ptr %a.addr, align 4
16-
%dx.frac = call float @llvm.dx.frac.f32(float %0)
14+
; CHECK:call float @dx.op.unary.f32(i32 22, float %{{.*}})
15+
%dx.frac = call float @llvm.dx.frac.f32(float %a)
1716
ret float %dx.frac
1817
}
1918

20-
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
21-
declare float @llvm.dx.frac.f32(float) #1
22-
23-
; Function Attrs: noinline nounwind optnone
24-
define noundef half @frac_half(half noundef %a) #0 {
19+
define noundef <4 x float> @frac_float4(<4 x float> noundef %a) #0 {
2520
entry:
26-
%a.addr = alloca half, align 2
27-
store half %a, ptr %a.addr, align 2
28-
%0 = load half, ptr %a.addr, align 2
29-
%dx.frac = call half @llvm.dx.frac.f16(half %0)
30-
ret half %dx.frac
21+
; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i64 0
22+
; CHECK: [[ie0:%.*]] = call float @dx.op.unary.f32(i32 22, float [[ee0]])
23+
; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i64 1
24+
; CHECK: [[ie1:%.*]] = call float @dx.op.unary.f32(i32 22, float [[ee1]])
25+
; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i64 2
26+
; CHECK: [[ie2:%.*]] = call float @dx.op.unary.f32(i32 22, float [[ee2]])
27+
; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i64 3
28+
; CHECK: [[ie3:%.*]] = call float @dx.op.unary.f32(i32 22, float [[ee3]])
29+
; CHECK: insertelement <4 x float> poison, float [[ie0]], i64 0
30+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie1]], i64 1
31+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie2]], i64 2
32+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie3]], i64 3
33+
%2 = call <4 x float> @llvm.dx.frac.v4f32(<4 x float> %a)
34+
ret <4 x float> %2
3135
}
36+
37+
declare half @llvm.dx.frac.f16(half)
38+
declare float @llvm.dx.frac.f32(float)
39+
declare <4 x float> @llvm.dx.frac.v4f32(<4 x float>)

llvm/test/CodeGen/DirectX/rsqrt.ll

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,42 @@
1-
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
1+
; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
22

33
; Make sure dxil operation function calls for rsqrt are generated for float and half.
44

55
; CHECK-LABEL: rsqrt_float
6-
; CHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
76
define noundef float @rsqrt_float(float noundef %a) {
87
entry:
9-
%a.addr = alloca float, align 4
10-
store float %a, ptr %a.addr, align 4
11-
%0 = load float, ptr %a.addr, align 4
12-
%dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
8+
; CHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
9+
%dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %a)
1310
ret float %dx.rsqrt
1411
}
1512

1613
; CHECK-LABEL: rsqrt_half
17-
; CHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
1814
define noundef half @rsqrt_half(half noundef %a) {
1915
entry:
20-
%a.addr = alloca half, align 2
21-
store half %a, ptr %a.addr, align 2
22-
%0 = load half, ptr %a.addr, align 2
23-
%dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
16+
; CHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
17+
%dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %a)
2418
ret half %dx.rsqrt
2519
}
2620

21+
define noundef <4 x float> @rsqrt_float4(<4 x float> noundef %a) #0 {
22+
entry:
23+
; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i64 0
24+
; CHECK: [[ie0:%.*]] = call float @dx.op.unary.f32(i32 25, float [[ee0]])
25+
; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i64 1
26+
; CHECK: [[ie1:%.*]] = call float @dx.op.unary.f32(i32 25, float [[ee1]])
27+
; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i64 2
28+
; CHECK: [[ie2:%.*]] = call float @dx.op.unary.f32(i32 25, float [[ee2]])
29+
; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i64 3
30+
; CHECK: [[ie3:%.*]] = call float @dx.op.unary.f32(i32 25, float [[ee3]])
31+
; CHECK: insertelement <4 x float> poison, float [[ie0]], i64 0
32+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie1]], i64 1
33+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie2]], i64 2
34+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ie3]], i64 3
35+
%2 = call <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float> %a)
36+
ret <4 x float> %2
37+
}
38+
39+
2740
declare half @llvm.dx.rsqrt.f16(half)
2841
declare float @llvm.dx.rsqrt.f32(float)
42+
declare <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float>)

0 commit comments

Comments
 (0)