11
11
#include " gc/Dialect/LLVMIR/XeVMDialect.h"
12
12
#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13
13
#include " mlir/Conversion/LLVMCommon/Pattern.h"
14
+ #include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14
15
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
15
16
#include " mlir/Pass/Pass.h"
16
17
#include " mlir/Support/LLVM.h"
18
+ #include " llvm/Support/FormatVariadic.h"
19
+
20
+ #include " mlir/IR/BuiltinTypes.h"
21
+ #include " mlir/IR/Types.h"
22
+
23
+ #include " llvm/ADT/STLExtras.h"
24
+ #include " llvm/ADT/TypeSwitch.h"
25
+ #include " llvm/Support/raw_ostream.h"
17
26
18
27
#define DEBUG_TYPE " xevm-to-llvm"
19
28
@@ -26,6 +35,231 @@ using namespace mlir;
26
35
using namespace xevm ;
27
36
28
37
namespace {
38
+ struct LLVMFuncAttributeOptions {
39
+ bool isConvergent = false ;
40
+ bool isNoUnwind = false ;
41
+ bool isWillReturn = false ;
42
+ LLVM::MemoryEffectsAttr memEffectsAttr{};
43
+ };
44
+ // static constexpr LLVMFuncAttributeOptions convergentAttrs = {
45
+ // true, false, false, {}};
46
+ // static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
47
+ // false, true, false, {}};
48
+ static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
49
+ false , true , true , {}};
50
+ // static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs =
51
+ // {
52
+ // true, true, true, {}};
53
+
54
+ std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
55
+ return TypeSwitch<Type, std::string>(ty)
56
+ .Case ([isUnsigned](VectorType ty) -> std::string {
57
+ return " Dv" + std::to_string (ty.getNumElements ()) + " _" +
58
+ getTypeMangling (ty.getElementType (), isUnsigned);
59
+ })
60
+ .Case ([](Float16Type) -> std::string { return " Dh" ; })
61
+ .Case ([](Float32Type) -> std::string { return " f" ; })
62
+ .Case ([](Float64Type) -> std::string { return " d" ; })
63
+ .Case ([isUnsigned](IntegerType ty) -> std::string {
64
+ switch (ty.getWidth ()) {
65
+ case 8 :
66
+ return isUnsigned ? " h" : " c" ;
67
+ case 16 :
68
+ return isUnsigned ? " t" : " s" ;
69
+ case 32 :
70
+ return isUnsigned ? " j" : " i" ;
71
+ case 64 :
72
+ return isUnsigned ? " m" : " l" ;
73
+ default :
74
+ llvm_unreachable (" unhandled integer type" );
75
+ }
76
+ });
77
+ }
78
+
79
+ template <typename OpType>
80
+ static std::optional<ArrayAttr>
81
+ getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
82
+ const bool isLoad) {
83
+ if ((op.getL1CacheControlAttr () ==
84
+ xevm::L1StoreCacheControlAttr::get (
85
+ rewriter.getContext (), xevm::L1StoreCacheControl::DEFAULT) &&
86
+ op.getL3CacheControlAttr () ==
87
+ xevm::L3StoreCacheControlAttr::get (
88
+ rewriter.getContext (), xevm::L3StoreCacheControl::DEFAULT)) ||
89
+
90
+ (op.getL1CacheControlAttr () ==
91
+ xevm::L1LoadCacheControlAttr::get (
92
+ rewriter.getContext (), xevm::L1LoadCacheControl::DEFAULT) &&
93
+ op.getL3CacheControlAttr () ==
94
+ xevm::L3LoadCacheControlAttr::get (
95
+ rewriter.getContext (), xevm::L3LoadCacheControl::DEFAULT))) {
96
+ return {};
97
+ }
98
+ constexpr int32_t decorationCacheControlArity{4 };
99
+ constexpr int32_t loadCacheControlKey{6442 };
100
+ constexpr int32_t storeCacheControlKey{6443 };
101
+ constexpr int32_t l1Level{0 };
102
+ constexpr int32_t l3Level{1 };
103
+ const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
104
+ SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
105
+ controlKey, l1Level, static_cast <int32_t >(op.getL1CacheControl ()), 0 };
106
+ SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
107
+ controlKey, l3Level, static_cast <int32_t >(op.getL3CacheControl ()), 0 };
108
+ auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
109
+ auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
110
+
111
+ SmallVector<Attribute, 2 > combinedAttrs = {arrayAttrL1, arrayAttrL3};
112
+ return rewriter.getArrayAttr (combinedAttrs);
113
+ }
114
+
115
+ static LLVM::CallOp createDeviceFunctionCall (
116
+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
117
+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
118
+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
119
+ LLVMFuncAttributeOptions funcAttributeOptions) {
120
+ auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
121
+ MLIRContext *ctx = rewriter.getContext ();
122
+ Location loc = UnknownLoc::get (ctx);
123
+
124
+ LLVM::LLVMFuncOp funcOp =
125
+ LLVM::lookupOrCreateFn (moduleOp, funcName, argTypes, retType);
126
+ funcOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
127
+ funcOp.setConvergent (funcAttributeOptions.isConvergent );
128
+ funcOp.setNoUnwind (funcAttributeOptions.isNoUnwind );
129
+ funcOp.setWillReturn (funcAttributeOptions.isWillReturn );
130
+
131
+ if (funcAttributeOptions.memEffectsAttr )
132
+ funcOp.setMemoryEffectsAttr (funcAttributeOptions.memEffectsAttr );
133
+
134
+ for (auto [idx, attrName] : paramAttrs)
135
+ funcOp.setArgAttr (idx, attrName, rewriter.getUnitAttr ());
136
+
137
+ // if (!passthroughAttrs.getFnAttributes().empty())
138
+ // funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));
139
+
140
+ auto callOp = rewriter.create <LLVM::CallOp>(loc, funcOp, args);
141
+ callOp->setAttrs (funcOp->getAttrs ());
142
+
143
+ return callOp;
144
+ }
145
+
146
+ template <typename OpType>
147
+ class LoadStorePrefetchNdToOCLPattern : public OpConversionPattern <OpType> {
148
+ using OpConversionPattern<OpType>::OpConversionPattern;
149
+ LogicalResult
150
+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
151
+ ConversionPatternRewriter &rewriter) const override {
152
+ constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
153
+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStore2dOp>;
154
+ constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
155
+ auto loc = op.getLoc ();
156
+ VectorType vecType;
157
+ if constexpr (isLoad) {
158
+ vecType = op.getRes ().getType ();
159
+ } else if constexpr (isStore) {
160
+ vecType = op.getStoredVal ().getType ();
161
+ }
162
+
163
+ auto i32Type = rewriter.getI32Type ();
164
+ bool vnni = false ;
165
+ bool transpose = false ;
166
+ if constexpr (isLoad) {
167
+ vnni = op.getVnniTransform ();
168
+ transpose = op.getTranspose ();
169
+ }
170
+
171
+ Value byteCoord =
172
+ rewriter.create <LLVM::UndefOp>(loc, VectorType::get (2 , i32Type));
173
+ Value zero = rewriter.create <LLVM::ConstantOp>(
174
+ loc, i32Type, rewriter.getI32IntegerAttr (0 ));
175
+ Value one = rewriter.create <LLVM::ConstantOp>(
176
+ loc, i32Type, rewriter.getI32IntegerAttr (1 ));
177
+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
178
+ loc, VectorType::get (2 , i32Type), byteCoord, op.getX (), zero);
179
+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
180
+ loc, VectorType::get (2 , i32Type), byteCoord, op.getY (), one);
181
+ SmallVector<Value> args{op.getPtr (), op.getBaseWidth (), op.getBaseHeight (),
182
+ op.getBasePitch (), byteCoord};
183
+ SmallVector<Type> retTypes;
184
+ Value spvLoadDstPtr;
185
+ std::string funcName, bitWidthId;
186
+ SmallVector<std::pair<unsigned , mlir::StringRef>, 4 > paramAttrs;
187
+ if constexpr (isPrefetch) { // Prefetch
188
+ funcName = " intel_sub_group_2d_block_prefetch" ;
189
+ paramAttrs = {std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ())};
190
+ } else {
191
+ auto vecElemType = vecType.getElementType ();
192
+ auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth ();
193
+ Value numElems = rewriter.create <LLVM::ConstantOp>(
194
+ loc, i32Type, vecType.getNumElements ());
195
+ auto dstOrSrcPtr = rewriter.create <LLVM::AllocaOp>(
196
+ loc, LLVM::LLVMPointerType::get (rewriter.getContext ()), vecElemType,
197
+ numElems);
198
+ args.push_back (dstOrSrcPtr);
199
+ if constexpr (isLoad) { // Load
200
+ funcName = " intel_sub_group_2d_block_read" ;
201
+ bitWidthId = getTypeMangling (vecElemType, /* isUnsigned=*/ true );
202
+ if (vnni)
203
+ funcName += " _transform" ;
204
+ else if (transpose)
205
+ funcName += " _transpose" ;
206
+ spvLoadDstPtr = dstOrSrcPtr;
207
+ retTypes.push_back (vecType);
208
+ paramAttrs = {
209
+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
210
+ std::make_pair (0 , LLVM::LLVMDialect::getReadonlyAttrName ()),
211
+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
212
+ std::make_pair (5 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
213
+ };
214
+ } else { // Store
215
+ funcName = " intel_sub_group_2d_block_write" ;
216
+ bitWidthId = (vecElemBitWidth == 32 )
217
+ ? " j"
218
+ : ((vecElemBitWidth == 16 ) ? " t" : " h" );
219
+ rewriter.create <LLVM::StoreOp>(loc, op.getStoredVal (), dstOrSrcPtr);
220
+ paramAttrs = {
221
+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
222
+ std::make_pair (0 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
223
+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
224
+ std::make_pair (5 , LLVM::LLVMDialect::getReadonlyAttrName ()),
225
+ };
226
+ }
227
+ }
228
+
229
+ // !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of
230
+ // the instruction to decorate%}
231
+ funcName =
232
+ llvm::formatv (" {0}_{1}b_{2}r{3}x{4}c" , funcName, op.getElemSizeInBits (),
233
+ op.getTileHeight (), op.getTileWidth (), op.getVBlocks ())
234
+ .str ();
235
+ funcName = llvm::formatv (" _Z{0}{1}PU3AS1viiiDv2_i{2}{3}" , funcName.size (),
236
+ funcName, isPrefetch ? " " : " P" , bitWidthId)
237
+ .str ();
238
+ SmallVector<Type> argTypes;
239
+ for (auto arg : args) {
240
+ argTypes.push_back (arg.getType ());
241
+ }
242
+ LLVM::CallOp call = createDeviceFunctionCall (
243
+ rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
244
+ argTypes, args, paramAttrs, noUnwindWillReturnAttrs);
245
+ if (std::optional<ArrayAttr> optCacheControls =
246
+ getCacheControlMetadata (rewriter, op, isLoad || isPrefetch)) {
247
+ call->setAttr (xevm::XeVMDialect::getCacheControlsAttrName (),
248
+ *optCacheControls);
249
+ }
250
+ if constexpr (isLoad)
251
+ rewriter.replaceOp (
252
+ op, rewriter.create <LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
253
+ else
254
+ rewriter.eraseOp (op);
255
+ return success ();
256
+ }
257
+ };
258
+
259
+ // ===----------------------------------------------------------------------===//
260
+ // Pass Definition
261
+ // ===----------------------------------------------------------------------===//
262
+
29
263
struct ConvertXeVMToLLVMPass
30
264
: public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31
265
using Base::Base;
@@ -37,19 +271,51 @@ struct ConvertXeVMToLLVMPass
37
271
void runOnOperation () override {
38
272
ConversionTarget target (getContext ());
39
273
target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
40
- RewritePatternSet pattern (&getContext ());
41
- mlir::populateXeVMToLLVMConversionPatterns (pattern);
42
- if (failed (
43
- applyPartialConversion (getOperation (), target, std::move (pattern))))
274
+ target.addIllegalDialect <xevm::XeVMDialect>();
275
+ RewritePatternSet patterns (&getContext ());
276
+ mlir::populateXeVMToLLVMConversionPatterns (patterns);
277
+ if (failed (applyPartialConversion (getOperation (), target,
278
+ std::move (patterns))))
44
279
signalPassFailure ();
45
280
}
46
281
};
47
282
} // namespace
48
283
284
+ // ===----------------------------------------------------------------------===//
285
+ // Pattern Population
286
+ // ===----------------------------------------------------------------------===//
287
+
49
288
void mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
50
- /* TODO*/
289
+ patterns.add <LoadStorePrefetchNdToOCLPattern<xevm::BlockLoad2dOp>,
290
+ LoadStorePrefetchNdToOCLPattern<xevm::BlockStore2dOp>,
291
+ LoadStorePrefetchNdToOCLPattern<xevm::BlockPrefetch2dOp>>(
292
+ patterns.getContext ());
51
293
}
52
294
295
+ // ===----------------------------------------------------------------------===//
296
+ // ConvertToLLVMPatternInterface implementation
297
+ // ===----------------------------------------------------------------------===//
298
+
299
+ namespace {
300
+ // / Implement the interface to convert XeVM to LLVM.
301
+ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
302
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
303
+ void loadDependentDialects (MLIRContext *context) const final {
304
+ context->loadDialect <LLVM::LLVMDialect>();
305
+ }
306
+
307
+ // / Hook for derived dialect interface to provide conversion patterns
308
+ // / and mark dialect legal for the conversion target.
309
+ void populateConvertToLLVMConversionPatterns (
310
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
311
+ RewritePatternSet &patterns) const final {
312
+ populateXeVMToLLVMConversionPatterns (patterns);
313
+ }
314
+ };
315
+ } // namespace
316
+
53
317
void mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
54
- /* TODO*/
318
+ registry.addExtension (+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
319
+ dialect->addInterfaces <XeVMToLLVMDialectInterface>();
320
+ });
55
321
}
0 commit comments