11
11
#include " gc/Dialect/LLVMIR/XeVMDialect.h"
12
12
#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13
13
#include " mlir/Conversion/LLVMCommon/Pattern.h"
14
+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
14
15
#include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15
16
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
16
17
#include " mlir/Pass/Pass.h"
@@ -53,6 +54,8 @@ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
53
54
false , true , false , {}};
54
55
// Function-attribute presets for generated device-function declarations.
// Judging by the preset names, the three leading flags are presumably
// convergent / no-unwind / will-return, in that order -- TODO confirm against
// the LLVMFuncAttributeOptions declaration. The trailing `{}` leaves the last
// member value-initialized.
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
    false, true, true, {}};
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
    true, true, true, {}};
56
59
57
60
std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
58
61
return TypeSwitch<Type, std::string>(ty)
@@ -79,6 +82,31 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
79
82
});
80
83
}
81
84
85
+ std::string mangle (StringRef baseName, ArrayRef<Type> types,
86
+ ArrayRef<bool > isUnsigned = {}) {
87
+ assert ((isUnsigned.empty () || isUnsigned.size () == types.size ()) &&
88
+ " Signedness info doesn't match" );
89
+ std::string s;
90
+ llvm::raw_string_ostream os (s);
91
+ llvm::SmallDenseMap<Type, unsigned > substitutions;
92
+ os << " _Z" << baseName.size () << baseName;
93
+ for (auto [idx, type] : llvm::enumerate (types)) {
94
+ auto it = substitutions.find (type);
95
+ if (it != substitutions.end ()) {
96
+ os << " S" ;
97
+ // First substitution is `S_`, second is `S0_`, and so on.
98
+ if (unsigned firstIdx = it->getSecond (); firstIdx > 0 )
99
+ os << firstIdx - 1 ;
100
+ os << " _" ;
101
+ } else {
102
+ if (!type.isIntOrFloat ())
103
+ substitutions[type] = substitutions.size ();
104
+ os << getTypeMangling (type, isUnsigned.empty () ? false : isUnsigned[idx]);
105
+ }
106
+ }
107
+ return os.str ();
108
+ }
109
+
82
110
template <typename OpType>
83
111
static std::optional<ArrayAttr>
84
112
getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
@@ -115,13 +143,15 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
115
143
return rewriter.getArrayAttr (combinedAttrs);
116
144
}
117
145
118
- static LLVM::CallOp
119
- createDeviceFunctionCall (ConversionPatternRewriter &rewriter,
120
- StringRef funcName, Type retType,
121
- ArrayRef<Type> argTypes, ArrayRef<Value> args,
122
- ArrayRef<std::pair<unsigned , StringRef>> paramAttrs,
123
- LLVMFuncAttributeOptions funcAttributeOptions) {
124
- auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
146
+ static LLVM::CallOp createDeviceFunctionCall (
147
+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
148
+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
149
+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
150
+ LLVMFuncAttributeOptions funcAttributeOptions) {
151
+ auto moduleOp = rewriter.getBlock ()
152
+ ->getParentOp ()
153
+ ->getParentWithTrait <OpTrait::SymbolTable>();
154
+ assert (moduleOp && " Expecting module" );
125
155
MLIRContext *ctx = rewriter.getContext ();
126
156
Location loc = UnknownLoc::get (ctx);
127
157
@@ -144,6 +174,96 @@ createDeviceFunctionCall(ConversionPatternRewriter &rewriter,
144
174
return callOp;
145
175
}
146
176
177
+ class DPASToOCLPattern : public OpConversionPattern <xevm::DPASOp> {
178
+ using OpConversionPattern::OpConversionPattern;
179
+ LogicalResult
180
+ matchAndRewrite (xevm::DPASOp op, xevm::DPASOp::Adaptor adaptor,
181
+ ConversionPatternRewriter &rewriter) const override {
182
+ constexpr uint32_t bitWidthPackedA{16 };
183
+ constexpr uint32_t bitWidthPackedB{32 };
184
+ auto loc = op.getLoc ();
185
+
186
+ auto castIfNeeded = [&](Value val, Type packedType) -> Value {
187
+ VectorType origTy = cast<VectorType>(val.getType ());
188
+ const uint32_t vecBitSize =
189
+ origTy.getNumElements () *
190
+ origTy.getElementType ().getIntOrFloatBitWidth ();
191
+ VectorType newTy = VectorType::get (
192
+ vecBitSize / packedType.getIntOrFloatBitWidth (), packedType);
193
+ if (origTy != newTy)
194
+ val = rewriter.create <LLVM::BitcastOp>(loc, newTy, val);
195
+ return val;
196
+ };
197
+
198
+ Value a = op.getA ();
199
+ Type packedAType = (op.getPa () == xevm::PrecisionType::TF32)
200
+ ? cast<Type>(rewriter.getF32Type ())
201
+ : rewriter.getIntegerType (bitWidthPackedA);
202
+ a = castIfNeeded (a, packedAType);
203
+
204
+ Value b = op.getB ();
205
+ Type packedBType = (op.getPb () == xevm::PrecisionType::TF32)
206
+ ? cast<Type>(rewriter.getF32Type ())
207
+ : rewriter.getIntegerType (bitWidthPackedB);
208
+ b = castIfNeeded (b, packedBType);
209
+
210
+ Value c = op.getC ();
211
+ VectorType cOrigTy = cast<VectorType>(c.getType ());
212
+ assert (cOrigTy == op->getResultTypes ()[0 ] &&
213
+ " Accumulator and result type mismatch" );
214
+ // OCL builtins encode bfloat16 as int16
215
+ VectorType cTy =
216
+ cOrigTy.getElementType ().isBF16 ()
217
+ ? VectorType::get (cOrigTy.getShape (), rewriter.getIntegerType (16 ))
218
+ : cOrigTy;
219
+ if (cOrigTy != cTy)
220
+ c = rewriter.create <LLVM::BitcastOp>(loc, cTy, c);
221
+
222
+ constexpr int32_t systolicDepth{8 };
223
+ std::string fnName =
224
+ llvm::formatv (" intel_sub_group_{0}_{1}_matrix_mad_k{2}" ,
225
+ stringifyPrecisionType (op.getPa ()).str (),
226
+ stringifyPrecisionType (op.getPb ()).str (),
227
+ systolicDepth * getNumOperandsPerDword (op.getPa ()))
228
+ .str ();
229
+ SmallVector<Type> argTypes{a.getType (), b.getType (), cTy};
230
+ fnName = mangle (fnName, argTypes);
231
+ SmallVector<Value> args{a, b, c};
232
+
233
+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
234
+ /* other=*/ LLVM::ModRefInfo::NoModRef,
235
+ /* argMem=*/ LLVM::ModRefInfo::NoModRef,
236
+ /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
237
+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
238
+ funcAttrs.memEffectsAttr = memAttr;
239
+ Value result = createDeviceFunctionCall (rewriter, fnName, cTy, argTypes,
240
+ args, {}, funcAttrs)
241
+ ->getResult (0 );
242
+
243
+ if (cOrigTy != cTy)
244
+ result = rewriter.create <LLVM::BitcastOp>(loc, cOrigTy, result);
245
+
246
+ rewriter.replaceOp (op, result);
247
+ return success ();
248
+ }
249
+
250
+ private:
251
+ static unsigned getNumOperandsPerDword (xevm::PrecisionType pTy) {
252
+ switch (pTy) {
253
+ case xevm::PrecisionType::TF32:
254
+ return 1 ;
255
+ case xevm::PrecisionType::BF16:
256
+ case xevm::PrecisionType::FP16:
257
+ return 2 ;
258
+ case xevm::PrecisionType::U8:
259
+ case xevm::PrecisionType::S8:
260
+ return 4 ;
261
+ default :
262
+ llvm_unreachable (" unsupported xevm::PrecisionType" );
263
+ }
264
+ }
265
+ };
266
+
147
267
template <typename OpType>
148
268
class LoadStorePrefetchToOCLPattern : public OpConversionPattern <OpType> {
149
269
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -291,10 +411,11 @@ struct ConvertXeVMToLLVMPass
291
411
// ===----------------------------------------------------------------------===//
292
412
293
413
/// Registers the XeVM-to-LLVM conversion patterns defined in this file with
/// `patterns`: the 2D block load/store/prefetch lowerings followed by the DPAS
/// lowering. Every pattern is constructed from the set's own MLIRContext.
void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>>(context);
  patterns.add<DPASToOCLPattern>(context);
}
299
420
300
421
// ===----------------------------------------------------------------------===//
0 commit comments