@@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
121
121
ConversionPatternRewriter &rewriter) const override ;
122
122
};
123
123
124
+ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
125
+ public:
126
+ using OpConversionPattern::OpConversionPattern;
127
+
128
+ LogicalResult
129
+ matchAndRewrite (gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
130
+ ConversionPatternRewriter &rewriter) const override ;
131
+ };
132
+
124
133
} // namespace
125
134
126
135
// ===----------------------------------------------------------------------===//
@@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
597
606
}
598
607
};
599
608
609
+ // Formulate a unique variable/constant name after
610
+ // searching in the module for existing variable/constant names.
611
+ // This is to avoid name collision with existing variables.
612
+ // Example: printfMsg0, printfMsg1, printfMsg2, ...
613
+ static std::string makeVarName (spirv::ModuleOp moduleOp, llvm::Twine prefix) {
614
+ std::string name;
615
+ unsigned number = 0 ;
616
+
617
+ do {
618
+ name.clear ();
619
+ name = (prefix + llvm::Twine (number++)).str ();
620
+ } while (moduleOp.lookupSymbol (name));
621
+
622
+ return name;
623
+ }
624
+
625
+ // / Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
626
+
627
+ LogicalResult GPUPrintfConversion::matchAndRewrite (
628
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
629
+ ConversionPatternRewriter &rewriter) const {
630
+
631
+ Location loc = gpuPrintfOp.getLoc ();
632
+
633
+ auto moduleOp = gpuPrintfOp->getParentOfType <spirv::ModuleOp>();
634
+ if (!moduleOp)
635
+ return failure ();
636
+
637
+ // SPIR-V global variable is used to initialize printf
638
+ // format string value, if there are multiple printf messages,
639
+ // each global var needs to be created with a unique name.
640
+ std::string globalVarName = makeVarName (moduleOp, llvm::Twine (" printfMsg" ));
641
+ spirv::GlobalVariableOp globalVar;
642
+
643
+ IntegerType i8Type = rewriter.getI8Type ();
644
+ IntegerType i32Type = rewriter.getI32Type ();
645
+
646
+ // Each character of printf format string is
647
+ // stored as a spec constant. We need to create
648
+ // unique name for this spec constant like
649
+ // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
650
+ // for existing spec constant names.
651
+ auto createSpecConstant = [&](unsigned value) {
652
+ auto attr = rewriter.getI8IntegerAttr (value);
653
+ std::string specCstName =
654
+ makeVarName (moduleOp, llvm::Twine (globalVarName) + " _sc" );
655
+
656
+ return rewriter.create <spirv::SpecConstantOp>(
657
+ loc, rewriter.getStringAttr (specCstName), attr);
658
+ };
659
+ {
660
+ Operation *parent =
661
+ SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
662
+
663
+ ConversionPatternRewriter::InsertionGuard guard (rewriter);
664
+
665
+ Block &entryBlock = *parent->getRegion (0 ).begin ();
666
+ rewriter.setInsertionPointToStart (
667
+ &entryBlock); // insertion point at module level
668
+
669
+ // Create Constituents with SpecConstant by scanning format string
670
+ // Each character of format string is stored as a spec constant
671
+ // and then these spec constants are used to create a
672
+ // SpecConstantCompositeOp.
673
+ llvm::SmallString<20 > formatString (adaptor.getFormat ());
674
+ formatString.push_back (' \0 ' ); // Null terminate for C.
675
+ SmallVector<Attribute, 4 > constituents;
676
+ for (char c : formatString) {
677
+ spirv::SpecConstantOp cSpecConstantOp = createSpecConstant (c);
678
+ constituents.push_back (SymbolRefAttr::get (cSpecConstantOp));
679
+ }
680
+
681
+ // Create SpecConstantCompositeOp to initialize the global variable
682
+ size_t contentSize = constituents.size ();
683
+ auto globalType = spirv::ArrayType::get (i8Type, contentSize);
684
+ spirv::SpecConstantCompositeOp specCstComposite;
685
+ // There will be one SpecConstantCompositeOp per printf message/global var,
686
+ // so no need do lookup for existing ones.
687
+ std::string specCstCompositeName =
688
+ (llvm::Twine (globalVarName) + " _scc" ).str ();
689
+
690
+ specCstComposite = rewriter.create <spirv::SpecConstantCompositeOp>(
691
+ loc, TypeAttr::get (globalType),
692
+ rewriter.getStringAttr (specCstCompositeName),
693
+ rewriter.getArrayAttr (constituents));
694
+
695
+ auto ptrType = spirv::PointerType::get (
696
+ globalType, spirv::StorageClass::UniformConstant);
697
+
698
+ // Define a GlobalVarOp initialized using specialized constants
699
+ // that is used to specify the printf format string
700
+ // to be passed to the SPIRV CLPrintfOp.
701
+ globalVar = rewriter.create <spirv::GlobalVariableOp>(
702
+ loc, ptrType, globalVarName, FlatSymbolRefAttr::get (specCstComposite));
703
+
704
+ globalVar->setAttr (" Constant" , rewriter.getUnitAttr ());
705
+ }
706
+ // Get SSA value of Global variable and create pointer to i8 to point to
707
+ // the format string.
708
+ Value globalPtr = rewriter.create <spirv::AddressOfOp>(loc, globalVar);
709
+ Value fmtStr = rewriter.create <spirv::BitcastOp>(
710
+ loc,
711
+ spirv::PointerType::get (i8Type, spirv::StorageClass::UniformConstant),
712
+ globalPtr);
713
+
714
+ // Get printf arguments.
715
+ auto printfArgs = llvm::to_vector_of<Value, 4 >(adaptor.getArgs ());
716
+
717
+ rewriter.create <spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
718
+
719
+ // Need to erase the gpu.printf op as gpu.printf does not use result vs
720
+ // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
721
+ // printf op.
722
+ rewriter.eraseOp (gpuPrintfOp);
723
+
724
+ return success ();
725
+ }
726
+
600
727
// ===----------------------------------------------------------------------===//
601
728
// GPU To SPIRV Patterns.
602
729
// ===----------------------------------------------------------------------===//
@@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
620
747
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
621
748
spirv::BuiltIn::SubgroupSize>,
622
749
WorkGroupSizeConversion, GPUAllReduceConversion,
623
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext ());
750
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
751
+ patterns.getContext ());
624
752
}
0 commit comments