Skip to content

Commit 3f39571

Browse files
authored
[DirectX][DXIL] Distinguish return type for overload type resolution. (#85646)
Return type of DXIL Ops may be different from valid overload type of the parameters, if any. Such DXIL Ops are correctly represented in DXIL.td. However, DXILEmitter assumes the return type to be the same as parameter overload type, if one exists. This results in generation in incorrect overload index value in DXILOperation.inc for the DXIL Op and incorrect DXIL operation function call in DXILOpLowering pass. This change distinguishes return types correctly from parameter overload types in DXILEmitter backend to handle such DXIL ops. Add specification for DXIL Op `isinf` and corresponding tests to verify the above change. Fixes issue #85125
1 parent 891172d commit 3f39571

File tree

7 files changed

+83
-30
lines changed

7 files changed

+83
-30
lines changed

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,9 @@ class DXILOpMapping<int opCode, DXILOpClass opClass,
255255
}
256256

257257
// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
258+
def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
259+
"Determines if the specified value is infinite.",
260+
[llvm_i1_ty, llvm_halforfloat_ty]>;
258261
def Sin : DXILOpMapping<13, unary, int_sin,
259262
"Returns sine(theta) for theta in radians.",
260263
[llvm_halforfloat_ty, LLVMMatchType<0>]>;

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
229229
/// its specification in DXIL.td.
230230
/// \param OverloadTy Return type to be used to construct DXIL function type.
231231
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
232-
Type *OverloadTy) {
232+
Type *ReturnTy, Type *OverloadTy) {
233233
SmallVector<Type *> ArgTys;
234234

235235
auto ParamKinds = getOpCodeParameterKind(*Prop);
236236

237-
// Add OverloadTy as return type of the function
238-
ArgTys.emplace_back(OverloadTy);
237+
// Add ReturnTy as return type of the function
238+
ArgTys.emplace_back(ReturnTy);
239239

240240
// Add DXIL Opcode value type viz., Int32 as first argument
241241
ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
@@ -249,34 +249,33 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
249249
ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
250250
}
251251

252-
static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
253-
Type *OverloadTy, Module &M) {
254-
const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
252+
namespace llvm {
253+
namespace dxil {
254+
255+
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
256+
Type *OverloadTy,
257+
llvm::iterator_range<Use *> Args) {
258+
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
255259

256260
OverloadKind Kind = getOverloadKind(OverloadTy);
257261
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
258262
report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
259263
}
260264

261-
std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
262-
// Dependent on name to dedup.
263-
if (auto *Fn = M.getFunction(FnName))
264-
return FunctionCallee(Fn);
265-
266-
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
267-
return M.getOrInsertFunction(FnName, DXILOpFT);
268-
}
269-
270-
namespace llvm {
271-
namespace dxil {
272-
273-
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
274-
llvm::iterator_range<Use *> Args) {
275-
auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
265+
std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
266+
FunctionCallee DXILFn;
267+
// Get the function with name DXILFnName, if one exists
268+
if (auto *Func = M.getFunction(DXILFnName)) {
269+
DXILFn = FunctionCallee(Func);
270+
} else {
271+
// Construct and add a function with name DXILFnName
272+
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
273+
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
274+
}
276275
SmallVector<Value *> FullArgs;
277276
FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
278277
FullArgs.append(Args.begin(), Args.end());
279-
return B.CreateCall(Fn, FullArgs);
278+
return B.CreateCall(DXILFn, FullArgs);
280279
}
281280

282281
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ namespace dxil {
2929
class DXILOpBuilder {
3030
public:
3131
DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
32-
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
32+
/// Create an instruction that calls DXIL Op with return type, specified
33+
/// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call
34+
/// constructed \param ReturnTy Return type of the DXIL Op call constructed
35+
/// \param OverloadTy Overload type of the DXIL Op call constructed
36+
/// \return DXIL Op call constructed
37+
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
38+
Type *OverloadTy,
3339
llvm::iterator_range<Use *> Args);
3440
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
3541
static const char *getOpCodeName(dxil::OpCode DXILOp);

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ using namespace llvm::dxil;
3232

3333
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
3434
IRBuilder<> B(M.getContext());
35-
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
3635
DXILOpBuilder DXILB(M, B);
3736
Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
3837
for (User *U : make_early_inc_range(F.users())) {
3938
CallInst *CI = dyn_cast<CallInst>(U);
4039
if (!CI)
4140
continue;
4241

43-
SmallVector<Value *> Args;
44-
Args.emplace_back(DXILOpArg);
45-
Args.append(CI->arg_begin(), CI->arg_end());
4642
B.SetInsertPoint(CI);
47-
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
43+
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
44+
OverloadTy, CI->args());
4845

4946
CI->replaceAllUsesWith(DXILCI);
5047
CI->eraseFromParent();

llvm/test/CodeGen/DirectX/isinf.ll

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
2+
3+
; Make sure dxil operation function calls for isinf are generated for float and half.
4+
; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
5+
; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}})
6+
7+
; Function Attrs: noinline nounwind optnone
8+
define noundef i1 @isinf_float(float noundef %a) #0 {
9+
entry:
10+
%a.addr = alloca float, align 4
11+
store float %a, ptr %a.addr, align 4
12+
%0 = load float, ptr %a.addr, align 4
13+
%dx.isinf = call i1 @llvm.dx.isinf.f32(float %0)
14+
ret i1 %dx.isinf
15+
}
16+
17+
; Function Attrs: noinline nounwind optnone
18+
define noundef i1 @isinf_half(half noundef %p0) #0 {
19+
entry:
20+
%p0.addr = alloca half, align 2
21+
store half %p0, ptr %p0.addr, align 2
22+
%0 = load half, ptr %p0.addr, align 2
23+
%dx.isinf = call i1 @llvm.dx.isinf.f16(half %0)
24+
ret i1 %dx.isinf
25+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; DXIL operation isinf does not support double overload type
4+
; CHECK: LLVM ERROR: Invalid Overload Type
5+
6+
define noundef i1 @isinf_double(double noundef %a) #0 {
7+
entry:
8+
%a.addr = alloca double, align 8
9+
store double %a, ptr %a.addr, align 8
10+
%0 = load double, ptr %a.addr, align 8
11+
%dx.isinf = call i1 @llvm.dx.isinf.f64(double %0)
12+
ret i1 %dx.isinf
13+
}

llvm/utils/TableGen/DXILEmitter.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
119119
// Populate OpTypes with return type and parameter types
120120

121121
// Parameter indices of overloaded parameters.
122-
// This vector contains overload parameters in the order order used to
122+
// This vector contains overload parameters in the order used to
123123
// resolve an LLVMMatchType in accordance with convention outlined in
124124
// the comment before the definition of class LLVMMatchType in
125125
// llvm/IR/Intrinsics.td
@@ -398,10 +398,20 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
398398

399399
OS << " static const OpCodeProperty OpCodeProps[] = {\n";
400400
for (auto &Op : Ops) {
401+
// Consider Op.OverloadParamIndex as the overload parameter index, by
402+
// default
403+
auto OLParamIdx = Op.OverloadParamIndex;
404+
// If no overload parameter index is set, treat first parameter type as
405+
// overload type - unless the Op has no parameters, in which case treat the
406+
// return type - as overload parameter to emit the appropriate overload kind
407+
// enum.
408+
if (OLParamIdx < 0) {
409+
OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
410+
}
401411
OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
402412
<< ", OpCodeClass::" << Op.OpClass << ", "
403413
<< OpClassStrings.get(Op.OpClass.data()) << ", "
404-
<< getOverloadKindStr(Op.OpTypes[0]) << ", "
414+
<< getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
405415
<< emitDXILOperationAttr(Op.OpAttributes) << ", "
406416
<< Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
407417
<< Parameters.get(ParameterMap[Op.OpClass]) << " },\n";

0 commit comments

Comments
 (0)