Skip to content

Commit 0a2aaab

Browse files
authored
[SPIRV] Implement log10 for logical SPIR-V (#66921)
There is no log10 instruction in the GLSL Extended Instruction Set so to implement the HLSL log10 intrinsic when targeting Vulkan this change adds the logic to derive the result using the following formula: ``` log10(x) = log2(x) * (1 / log2(10)) = log2(x) * 0.30103 ```
1 parent 469b9cb commit 0a2aaab

File tree

4 files changed

+116
-14
lines changed

4 files changed

+116
-14
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

+14-11
Original file line numberDiff line numberDiff line change
@@ -244,24 +244,27 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
244244
MachineIRBuilder &MIRBuilder,
245245
SPIRVType *SpvType) {
246246
auto &MF = MIRBuilder.getMF();
247-
const Type *LLVMFPTy;
248-
if (SpvType) {
249-
LLVMFPTy = getTypeForSPIRVType(SpvType);
250-
assert(LLVMFPTy->isFloatingPointTy());
251-
} else {
252-
LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
247+
auto &Ctx = MF.getFunction().getContext();
248+
if (!SpvType) {
249+
const Type *LLVMFPTy = Type::getFloatTy(Ctx);
250+
SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
253251
}
254252
// Find a constant in DT or build a new one.
255-
const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
253+
const auto ConstFP = ConstantFP::get(Ctx, Val);
256254
Register Res = DT.find(ConstFP, &MF);
257255
if (!Res.isValid()) {
258-
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
259-
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
256+
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
260257
MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
261-
assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
258+
assignSPIRVTypeToVReg(SpvType, Res, MF);
262259
DT.add(ConstFP, &MF, Res);
263-
MIRBuilder.buildFConstant(Res, *ConstFP);
260+
261+
MachineInstrBuilder MIB;
262+
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
263+
.addDef(Res)
264+
.addUse(getSPIRVTypeID(SpvType));
265+
addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
264266
}
267+
265268
return Res;
266269
}
267270

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

+55-1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
174174
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
175175
MachineInstr &I, const ExtInstList &ExtInsts) const;
176176

177+
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
178+
MachineInstr &I) const;
179+
177180
Register buildI32Constant(uint32_t Val, MachineInstr &I,
178181
const SPIRVType *ResType = nullptr) const;
179182

@@ -362,7 +365,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
362365
case TargetOpcode::G_FLOG2:
363366
return selectExtInst(ResVReg, ResType, I, CL::log2, GL::Log2);
364367
case TargetOpcode::G_FLOG10:
365-
return selectExtInst(ResVReg, ResType, I, CL::log10);
368+
return selectLog10(ResVReg, ResType, I);
366369

367370
case TargetOpcode::G_FABS:
368371
return selectExtInst(ResVReg, ResType, I, CL::fabs, GL::FAbs);
@@ -1562,6 +1565,57 @@ bool SPIRVInstructionSelector::selectGlobalValue(
15621565
return Reg.isValid();
15631566
}
15641567

1568+
bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
1569+
const SPIRVType *ResType,
1570+
MachineInstr &I) const {
1571+
if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
1572+
return selectExtInst(ResVReg, ResType, I, CL::log10);
1573+
}
1574+
1575+
// There is no log10 instruction in the GLSL Extended Instruction set, so it
1576+
// is implemented as:
1577+
// log10(x) = log2(x) * (1 / log2(10))
1578+
// = log2(x) * 0.30103
1579+
1580+
MachineIRBuilder MIRBuilder(I);
1581+
MachineBasicBlock &BB = *I.getParent();
1582+
1583+
// Build log2(x).
1584+
Register VarReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1585+
bool Result =
1586+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
1587+
.addDef(VarReg)
1588+
.addUse(GR.getSPIRVTypeID(ResType))
1589+
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
1590+
.addImm(GL::Log2)
1591+
.add(I.getOperand(1))
1592+
.constrainAllUses(TII, TRI, RBI);
1593+
1594+
// Build 0.30103.
1595+
assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
1596+
ResType->getOpcode() == SPIRV::OpTypeFloat);
1597+
// TODO: Add matrix implementation once supported by the HLSL frontend.
1598+
const SPIRVType *SpirvScalarType =
1599+
ResType->getOpcode() == SPIRV::OpTypeVector
1600+
? GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg())
1601+
: ResType;
1602+
Register ScaleReg =
1603+
GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType);
1604+
1605+
// Multiply log2(x) by 0.30103 to get log10(x) result.
1606+
auto Opcode = ResType->getOpcode() == SPIRV::OpTypeVector
1607+
? SPIRV::OpVectorTimesScalar
1608+
: SPIRV::OpFMulS;
1609+
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
1610+
.addDef(ResVReg)
1611+
.addUse(GR.getSPIRVTypeID(ResType))
1612+
.addUse(VarReg)
1613+
.addUse(ScaleReg)
1614+
.constrainAllUses(TII, TRI, RBI);
1615+
1616+
return Result;
1617+
}
1618+
15651619
namespace llvm {
15661620
InstructionSelector *
15671621
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
229229
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230230
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231231

232+
// TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
233+
// tighten these requirements. Many of these math functions are only legal on
234+
// specific bitwidths, so they are not selectable for
235+
// allFloatScalarsAndVectors.
232236
getActionDefinitionsBuilder({G_FPOW,
233237
G_FEXP,
234238
G_FEXP2,
235239
G_FLOG,
236240
G_FLOG2,
241+
G_FLOG10,
237242
G_FABS,
238243
G_FMINNUM,
239244
G_FMAXNUM,
@@ -259,8 +264,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
259264
allFloatScalarsAndVectors, allIntScalarsAndVectors);
260265

261266
if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
262-
getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);
263-
264267
getActionDefinitionsBuilder(
265268
{G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
266269
.legalForCartesianProduct(allIntScalarsAndVectors,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s
2+
3+
; CHECK: %[[#extinst:]] = OpExtInstImport "GLSL.std.450"
4+
5+
; CHECK: %[[#float:]] = OpTypeFloat 32
6+
; CHECK: %[[#v4float:]] = OpTypeVector %[[#float]] 4
7+
; CHECK: %[[#float_0_30103001:]] = OpConstant %[[#float]] 0.30103000998497009
8+
; CHECK: %[[#_ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
9+
; CHECK: %[[#_ptr_Function_float:]] = OpTypePointer Function %[[#float]]
10+
11+
define void @main() {
12+
entry:
13+
; CHECK: %[[#f:]] = OpVariable %[[#_ptr_Function_float]] Function
14+
; CHECK: %[[#logf:]] = OpVariable %[[#_ptr_Function_float]] Function
15+
; CHECK: %[[#f4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
16+
; CHECK: %[[#logf4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
17+
%f = alloca float, align 4
18+
%logf = alloca float, align 4
19+
%f4 = alloca <4 x float>, align 16
20+
%logf4 = alloca <4 x float>, align 16
21+
22+
; CHECK: %[[#load:]] = OpLoad %[[#float]] %[[#f]] Aligned 4
23+
; CHECK: %[[#log2:]] = OpExtInst %[[#float]] %[[#extinst]] Log2 %[[#load]]
24+
; CHECK: %[[#res:]] = OpFMul %[[#float]] %[[#log2]] %[[#float_0_30103001]]
25+
; CHECK: OpStore %[[#logf]] %[[#res]] Aligned 4
26+
%0 = load float, ptr %f, align 4
27+
%elt.log10 = call float @llvm.log10.f32(float %0)
28+
store float %elt.log10, ptr %logf, align 4
29+
30+
; CHECK: %[[#load:]] = OpLoad %[[#v4float]] %[[#f4]] Aligned 16
31+
; CHECK: %[[#log2:]] = OpExtInst %[[#v4float]] %[[#extinst]] Log2 %[[#load]]
32+
; CHECK: %[[#res:]] = OpVectorTimesScalar %[[#v4float]] %[[#log2]] %[[#float_0_30103001]]
33+
; CHECK: OpStore %[[#logf4]] %[[#res]] Aligned 16
34+
%1 = load <4 x float>, ptr %f4, align 16
35+
%elt.log101 = call <4 x float> @llvm.log10.v4f32(<4 x float> %1)
36+
store <4 x float> %elt.log101, ptr %logf4, align 16
37+
38+
ret void
39+
}
40+
41+
declare float @llvm.log10.f32(float)
42+
declare <4 x float> @llvm.log10.v4f32(<4 x float>)

0 commit comments

Comments
 (0)