Skip to content

Commit f8ba021

Browse files
drprajap and kuhar authored
[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (#78510)
This change contains the following: - Adds lowering of the gpu.printf op to the spirv.CL.printf op in the GPUToSPIRV pass. - Fixes Constant decoration parsing for spirv.GlobalVariable. - Makes a minor modification to the spirv.CL.printf op assembly format. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 4dfed69 commit f8ba021

File tree

6 files changed

+207
-6
lines changed

6 files changed

+207
-6
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
875875
#### Example:
876876

877877
```mlir
878-
%0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
878+
%0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
879879
```
880880
}];
881881

@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
889889
);
890890

891891
let assemblyFormat = [{
892-
$format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
892+
$format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
893893
}];
894894

895895
let hasVerifier = 0;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
121121
ConversionPatternRewriter &rewriter) const override;
122122
};
123123

124+
/// Conversion pattern that rewrites a gpu.printf op into a spirv.CL.printf op.
/// The rewrite logic lives in the out-of-line matchAndRewrite definition.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  // Inherit the base pattern's constructors unchanged.
  using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
132+
124133
} // namespace
125134

126135
//===----------------------------------------------------------------------===//
@@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
597606
}
598607
};
599608

609+
// Formulate a unique variable/constant name after
610+
// searching in the module for existing variable/constant names.
611+
// This is to avoid name collision with existing variables.
612+
// Example: printfMsg0, printfMsg1, printfMsg2, ...
613+
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
614+
std::string name;
615+
unsigned number = 0;
616+
617+
do {
618+
name.clear();
619+
name = (prefix + llvm::Twine(number++)).str();
620+
} while (moduleOp.lookupSymbol(name));
621+
622+
return name;
623+
}
624+
625+
/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
626+
627+
LogicalResult GPUPrintfConversion::matchAndRewrite(
628+
gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
629+
ConversionPatternRewriter &rewriter) const {
630+
631+
Location loc = gpuPrintfOp.getLoc();
632+
633+
auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
634+
if (!moduleOp)
635+
return failure();
636+
637+
// SPIR-V global variable is used to initialize printf
638+
// format string value, if there are multiple printf messages,
639+
// each global var needs to be created with a unique name.
640+
std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
641+
spirv::GlobalVariableOp globalVar;
642+
643+
IntegerType i8Type = rewriter.getI8Type();
644+
IntegerType i32Type = rewriter.getI32Type();
645+
646+
// Each character of printf format string is
647+
// stored as a spec constant. We need to create
648+
// unique name for this spec constant like
649+
// @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
650+
// for existing spec constant names.
651+
auto createSpecConstant = [&](unsigned value) {
652+
auto attr = rewriter.getI8IntegerAttr(value);
653+
std::string specCstName =
654+
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
655+
656+
return rewriter.create<spirv::SpecConstantOp>(
657+
loc, rewriter.getStringAttr(specCstName), attr);
658+
};
659+
{
660+
Operation *parent =
661+
SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
662+
663+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
664+
665+
Block &entryBlock = *parent->getRegion(0).begin();
666+
rewriter.setInsertionPointToStart(
667+
&entryBlock); // insertion point at module level
668+
669+
// Create Constituents with SpecConstant by scanning format string
670+
// Each character of format string is stored as a spec constant
671+
// and then these spec constants are used to create a
672+
// SpecConstantCompositeOp.
673+
llvm::SmallString<20> formatString(adaptor.getFormat());
674+
formatString.push_back('\0'); // Null terminate for C.
675+
SmallVector<Attribute, 4> constituents;
676+
for (char c : formatString) {
677+
spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
678+
constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
679+
}
680+
681+
// Create SpecConstantCompositeOp to initialize the global variable
682+
size_t contentSize = constituents.size();
683+
auto globalType = spirv::ArrayType::get(i8Type, contentSize);
684+
spirv::SpecConstantCompositeOp specCstComposite;
685+
// There will be one SpecConstantCompositeOp per printf message/global var,
686+
// so no need do lookup for existing ones.
687+
std::string specCstCompositeName =
688+
(llvm::Twine(globalVarName) + "_scc").str();
689+
690+
specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
691+
loc, TypeAttr::get(globalType),
692+
rewriter.getStringAttr(specCstCompositeName),
693+
rewriter.getArrayAttr(constituents));
694+
695+
auto ptrType = spirv::PointerType::get(
696+
globalType, spirv::StorageClass::UniformConstant);
697+
698+
// Define a GlobalVarOp initialized using specialized constants
699+
// that is used to specify the printf format string
700+
// to be passed to the SPIRV CLPrintfOp.
701+
globalVar = rewriter.create<spirv::GlobalVariableOp>(
702+
loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
703+
704+
globalVar->setAttr("Constant", rewriter.getUnitAttr());
705+
}
706+
// Get SSA value of Global variable and create pointer to i8 to point to
707+
// the format string.
708+
Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
709+
Value fmtStr = rewriter.create<spirv::BitcastOp>(
710+
loc,
711+
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
712+
globalPtr);
713+
714+
// Get printf arguments.
715+
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
716+
717+
rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
718+
719+
// Need to erase the gpu.printf op as gpu.printf does not use result vs
720+
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
721+
// printf op.
722+
rewriter.eraseOp(gpuPrintfOp);
723+
724+
return success();
725+
}
726+
600727
//===----------------------------------------------------------------------===//
601728
// GPU To SPIRV Patterns.
602729
//===----------------------------------------------------------------------===//
@@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
620747
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
621748
spirv::BuiltIn::SubgroupSize>,
622749
WorkGroupSizeConversion, GPUAllReduceConversion,
623-
GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
750+
GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
751+
patterns.getContext());
624752
}

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
319319
case spirv::Decoration::Restrict:
320320
case spirv::Decoration::RestrictPointer:
321321
case spirv::Decoration::NoContraction:
322+
case spirv::Decoration::Constant:
322323
if (words.size() != 2) {
323324
return emitError(unknownLoc, "OpDecoration with ")
324325
<< decorationName << "needs a single target <id>";

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
286286
case spirv::Decoration::Restrict:
287287
case spirv::Decoration::RestrictPointer:
288288
case spirv::Decoration::NoContraction:
289+
case spirv::Decoration::Constant:
289290
// For unit attributes and decoration attributes, the args list
290291
// has no values so we do nothing.
291292
if (isa<UnitAttr, DecorationAttr>(attr))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s | FileCheck %s
2+
3+
module attributes {
4+
gpu.container_module,
5+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
6+
} {
7+
func.func @main() {
8+
%c1 = arith.constant 1 : index
9+
10+
gpu.launch_func @kernels::@printf
11+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
12+
args()
13+
return
14+
}
15+
16+
gpu.module @kernels {
17+
// CHECK: spirv.module @{{.*}} Physical32 OpenCL
18+
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
19+
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
20+
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
21+
gpu.func @printf() kernel
22+
attributes
23+
{spirv.entry_point_abi = #spirv.entry_point_abi<>} {
24+
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
25+
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
26+
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : !spirv.ptr<i8, UniformConstant> -> i32
27+
gpu.printf "\nHello\n"
28+
// CHECK: spirv.Return
29+
gpu.return
30+
}
31+
}
32+
}
33+
34+
// -----
35+
36+
module attributes {
37+
gpu.container_module,
38+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
39+
} {
40+
func.func @main() {
41+
%c1 = arith.constant 1 : index
42+
%c100 = arith.constant 100: i32
43+
%cst_f32 = arith.constant 314.4: f32
44+
45+
gpu.launch_func @kernels1::@printf_args
46+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
47+
args(%c100: i32, %cst_f32: f32)
48+
return
49+
}
50+
51+
gpu.module @kernels1 {
52+
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
53+
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
54+
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
55+
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
56+
gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
57+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
58+
%0 = gpu.block_id x
59+
%1 = gpu.block_id y
60+
%2 = gpu.thread_id x
61+
62+
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
63+
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
64+
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32
65+
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
66+
67+
// CHECK: spirv.Return
68+
gpu.return
69+
}
70+
}
71+
}

mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
274274
// spirv.CL.printf
275275
//===----------------------------------------------------------------------===//
276276
// CHECK-LABEL: func.func @printf(
277-
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
278-
// CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
279-
%0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
277+
func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
278+
// CHECK: spirv.CL.printf {{%.*}} {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
279+
%0 = spirv.CL.printf %fmt %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
280280
return %0 : i32
281281
}
282282

0 commit comments

Comments
 (0)