Skip to content

Commit 3b9094d

Browse files
committed
add initial conversions for xevm block ops
1 parent 238706b commit 3b9094d

File tree

9 files changed

+1033
-8
lines changed

9 files changed

+1033
-8
lines changed

include/gc/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef GC_CONVERSION_PASSES_H
1010
#define GC_CONVERSION_PASSES_H
1111

12-
#include "gc/Conversion/XeVMToLLVM.h"
12+
#include "gc/Conversion/XeVMToLLVM/XeVMToLLVM.h"
1313

1414
namespace mlir {
1515

include/gc/Dialect/LLVMIR/XeVMOps.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def XeVM_Dialect : Dialect {
1919
let name = "xevm";
2020
let cppNamespace = "::mlir::xevm";
2121
let dependentDialects = ["LLVM::LLVMDialect"];
22+
23+
let extraClassDeclaration = [{
24+
/// Get the name for the attribute used to specify cache control
25+
/// decorations.
26+
static constexpr ::llvm::StringRef getCacheControlsAttrName() {
27+
return ::llvm::StringLiteral("xevm.DecorationCacheControlINTEL");
28+
}
29+
}];
30+
2231
let useDefaultAttributePrinterParser = 1;
2332
}
2433

@@ -161,6 +170,52 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
161170
let hasVerifier = 1;
162171
}
163172

173+
// 2D block prefetch: takes the same addressing operands as the 2D block
// load/store ops (base pointer, surface shape, tile offsets) but produces no
// result. Cache behaviour is selected through the two load cache-control
// attributes, both of which default to DEFAULT.
def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
    Arguments<(ins
      Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
      I32:$base_width,
      I32:$base_height,
      I32:$base_pitch,
      I32:$x,
      I32:$y,
      I32Attr:$elem_size_in_bits,
      I32Attr:$tile_width,
      I32Attr:$tile_height,
      I32Attr:$v_blocks,
      DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
      DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
    )> {

  let summary = "2D block prefetch";

  let description = [{
    The `xevm.blockprefetch2d` operation prefetches a two-dimensional tile
    from a larger matrix residing in memory. The parameters are:
      $ptr - the base address of the matrix containing the tile to prefetch
      $base_width, $base_height, $base_pitch - the shape of the matrix
      $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to prefetch
      $elem_size_in_bits - the size in bits of the matrix element
        - 32 for f32, tf32
        - 16 for f16, int16, bf16
        - 8 for int8, int4, int2
      $v_blocks - number of tiles to prefetch
      $l1_cache_control, $l3_cache_control - enumerators that set the L1 and L3 cache behaviour

    Notes:
      - coordinates are provided in elements, while width and pitch are provided in bytes.
  }];

  let assemblyFormat = [{
    operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,`
      `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,`
      `l3_cache_control` `=` $l3_cache_control `}`
      attr-dict `:` `(` type(operands) `)`
  }];

  let hasVerifier = 1;
}
217+
218+
164219
def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
165220
let description = [{
166221
GPU target attribute for controlling compilation of targets. All

lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 272 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,18 @@
1111
#include "gc/Dialect/LLVMIR/XeVMDialect.h"
1212
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313
#include "mlir/Conversion/LLVMCommon/Pattern.h"
14+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Pass/Pass.h"
1617
#include "mlir/Support/LLVM.h"
18+
#include "llvm/Support/FormatVariadic.h"
19+
20+
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/Types.h"
22+
23+
#include "llvm/ADT/STLExtras.h"
24+
#include "llvm/ADT/TypeSwitch.h"
25+
#include "llvm/Support/raw_ostream.h"
1726

1827
#define DEBUG_TYPE "xevm-to-llvm"
1928

@@ -26,6 +35,231 @@ using namespace mlir;
2635
using namespace xevm;
2736

2837
namespace {
38+
struct LLVMFuncAttributeOptions {
39+
bool isConvergent = false;
40+
bool isNoUnwind = false;
41+
bool isWillReturn = false;
42+
LLVM::MemoryEffectsAttr memEffectsAttr{};
43+
};
44+
// static constexpr LLVMFuncAttributeOptions convergentAttrs = {
45+
// true, false, false, {}};
46+
// static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
47+
// false, true, false, {}};
48+
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
49+
false, true, true, {}};
50+
// static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs =
51+
// {
52+
// true, true, true, {}};
53+
54+
/// Returns the Itanium-style mangled spelling of `ty`, used to build the
/// parameter suffix of mangled builtin names.
///
/// Supported types:
///   - VectorType  -> "Dv<numElements>_" followed by the element mangling
///   - f16 / f32 / f64 -> "Dh" / "f" / "d"
///   - 8/16/32/64-bit integers -> "c"/"s"/"i"/"l", or "h"/"t"/"j"/"m" when
///     `isUnsigned` is set
///
/// `isUnsigned` only affects integer (and vector-of-integer) types. Any other
/// type is a programming error and hits llvm_unreachable rather than silently
/// falling off the end of the TypeSwitch.
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
  return TypeSwitch<Type, std::string>(ty)
      .Case([isUnsigned](VectorType ty) -> std::string {
        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
               getTypeMangling(ty.getElementType(), isUnsigned);
      })
      .Case([](Float16Type) -> std::string { return "Dh"; })
      .Case([](Float32Type) -> std::string { return "f"; })
      .Case([](Float64Type) -> std::string { return "d"; })
      .Case([isUnsigned](IntegerType ty) -> std::string {
        switch (ty.getWidth()) {
        case 8:
          return isUnsigned ? "h" : "c";
        case 16:
          return isUnsigned ? "t" : "s";
        case 32:
          return isUnsigned ? "j" : "i";
        case 64:
          return isUnsigned ? "m" : "l";
        default:
          llvm_unreachable("unhandled integer type");
        }
      })
      // Without an explicit default, an unsupported type would fall off the
      // end of the switch and read an empty result.
      .Default([](Type) -> std::string {
        llvm_unreachable("unhandled type for mangling");
      });
}
78+
79+
template <typename OpType>
80+
static std::optional<ArrayAttr>
81+
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
82+
const bool isLoad) {
83+
if ((op.getL1CacheControlAttr() ==
84+
xevm::L1StoreCacheControlAttr::get(
85+
rewriter.getContext(), xevm::L1StoreCacheControl::DEFAULT) &&
86+
op.getL3CacheControlAttr() ==
87+
xevm::L3StoreCacheControlAttr::get(
88+
rewriter.getContext(), xevm::L3StoreCacheControl::DEFAULT)) ||
89+
90+
(op.getL1CacheControlAttr() ==
91+
xevm::L1LoadCacheControlAttr::get(
92+
rewriter.getContext(), xevm::L1LoadCacheControl::DEFAULT) &&
93+
op.getL3CacheControlAttr() ==
94+
xevm::L3LoadCacheControlAttr::get(
95+
rewriter.getContext(), xevm::L3LoadCacheControl::DEFAULT))) {
96+
return {};
97+
}
98+
constexpr int32_t decorationCacheControlArity{4};
99+
constexpr int32_t loadCacheControlKey{6442};
100+
constexpr int32_t storeCacheControlKey{6443};
101+
constexpr int32_t l1Level{0};
102+
constexpr int32_t l3Level{1};
103+
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
104+
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
105+
controlKey, l1Level, static_cast<int32_t>(op.getL1CacheControl()), 0};
106+
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
107+
controlKey, l3Level, static_cast<int32_t>(op.getL3CacheControl()), 0};
108+
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
109+
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
110+
111+
SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
112+
return rewriter.getArrayAttr(combinedAttrs);
113+
}
114+
115+
/// Emits a call to the device function `funcName`, declaring it in the
/// enclosing module first if necessary.
///
/// The declaration gets the SPIR_FUNC calling convention, the flags from
/// `funcAttributeOptions`, and a unit attribute for every (argument index,
/// attribute name) pair in `paramAttrs`. The call is created at the
/// rewriter's current insertion point with an unknown location, and the
/// declaration's attributes are then copied onto the call.
///
/// \param rewriter  rewriter used to create the call.
/// \param funcName  symbol name of the callee.
/// \param retType   return type used when declaring the callee.
/// \param argTypes  argument types used when declaring the callee.
/// \param args      call operands.
/// \param paramAttrs  (argument index, attribute name) pairs to set.
/// \param funcAttributeOptions  function-level attribute flags.
/// \returns the created call operation.
static LLVM::CallOp createDeviceFunctionCall(
    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
    ArrayRef<Type> argTypes, ArrayRef<Value> args,
    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
    LLVMFuncAttributeOptions funcAttributeOptions) {
  MLIRContext *ctx = rewriter.getContext();
  auto enclosingModule =
      rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();

  // Find the existing declaration or create a new one.
  LLVM::LLVMFuncOp callee =
      LLVM::lookupOrCreateFn(enclosingModule, funcName, argTypes, retType);

  // Function-level attributes.
  callee.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
  callee.setConvergent(funcAttributeOptions.isConvergent);
  callee.setNoUnwind(funcAttributeOptions.isNoUnwind);
  callee.setWillReturn(funcAttributeOptions.isWillReturn);
  if (funcAttributeOptions.memEffectsAttr)
    callee.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);

  // Per-argument unit attributes.
  for (const auto &paramAttr : paramAttrs)
    callee.setArgAttr(paramAttr.first, paramAttr.second,
                      rewriter.getUnitAttr());

  // if (!passthroughAttrs.getFnAttributes().empty())
  //   funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));

  auto call =
      rewriter.create<LLVM::CallOp>(UnknownLoc::get(ctx), callee, args);
  // Copy the declaration's attributes onto the call.
  call->setAttrs(callee->getAttrs());
  return call;
}
145+
146+
template <typename OpType>
147+
class LoadStorePrefetchNdToOCLPattern : public OpConversionPattern<OpType> {
148+
using OpConversionPattern<OpType>::OpConversionPattern;
149+
LogicalResult
150+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
151+
ConversionPatternRewriter &rewriter) const override {
152+
constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
153+
constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStore2dOp>;
154+
constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
155+
auto loc = op.getLoc();
156+
VectorType vecType;
157+
if constexpr (isLoad) {
158+
vecType = op.getRes().getType();
159+
} else if constexpr (isStore) {
160+
vecType = op.getStoredVal().getType();
161+
}
162+
163+
auto i32Type = rewriter.getI32Type();
164+
bool vnni = false;
165+
bool transpose = false;
166+
if constexpr (isLoad) {
167+
vnni = op.getVnniTransform();
168+
transpose = op.getTranspose();
169+
}
170+
171+
Value byteCoord =
172+
rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
173+
Value zero = rewriter.create<LLVM::ConstantOp>(
174+
loc, i32Type, rewriter.getI32IntegerAttr(0));
175+
Value one = rewriter.create<LLVM::ConstantOp>(
176+
loc, i32Type, rewriter.getI32IntegerAttr(1));
177+
byteCoord = rewriter.create<LLVM::InsertElementOp>(
178+
loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
179+
byteCoord = rewriter.create<LLVM::InsertElementOp>(
180+
loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
181+
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
182+
op.getBasePitch(), byteCoord};
183+
SmallVector<Type> retTypes;
184+
Value spvLoadDstPtr;
185+
std::string funcName, bitWidthId;
186+
SmallVector<std::pair<unsigned, mlir::StringRef>, 4> paramAttrs;
187+
if constexpr (isPrefetch) { // Prefetch
188+
funcName = "intel_sub_group_2d_block_prefetch";
189+
paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
190+
} else {
191+
auto vecElemType = vecType.getElementType();
192+
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
193+
Value numElems = rewriter.create<LLVM::ConstantOp>(
194+
loc, i32Type, vecType.getNumElements());
195+
auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
196+
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
197+
numElems);
198+
args.push_back(dstOrSrcPtr);
199+
if constexpr (isLoad) { // Load
200+
funcName = "intel_sub_group_2d_block_read";
201+
bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
202+
if (vnni)
203+
funcName += "_transform";
204+
else if (transpose)
205+
funcName += "_transpose";
206+
spvLoadDstPtr = dstOrSrcPtr;
207+
retTypes.push_back(vecType);
208+
paramAttrs = {
209+
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
210+
std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
211+
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
212+
std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
213+
};
214+
} else { // Store
215+
funcName = "intel_sub_group_2d_block_write";
216+
bitWidthId = (vecElemBitWidth == 32)
217+
? "j"
218+
: ((vecElemBitWidth == 16) ? "t" : "h");
219+
rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
220+
paramAttrs = {
221+
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
222+
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
223+
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
224+
std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
225+
};
226+
}
227+
}
228+
229+
// !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of
230+
// the instruction to decorate%}
231+
funcName =
232+
llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
233+
op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
234+
.str();
235+
funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
236+
funcName, isPrefetch ? "" : "P", bitWidthId)
237+
.str();
238+
SmallVector<Type> argTypes;
239+
for (auto arg : args) {
240+
argTypes.push_back(arg.getType());
241+
}
242+
LLVM::CallOp call = createDeviceFunctionCall(
243+
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
244+
argTypes, args, paramAttrs, noUnwindWillReturnAttrs);
245+
if (std::optional<ArrayAttr> optCacheControls =
246+
getCacheControlMetadata(rewriter, op, isLoad || isPrefetch)) {
247+
call->setAttr(xevm::XeVMDialect::getCacheControlsAttrName(),
248+
*optCacheControls);
249+
}
250+
if constexpr (isLoad)
251+
rewriter.replaceOp(
252+
op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
253+
else
254+
rewriter.eraseOp(op);
255+
return success();
256+
}
257+
};
258+
259+
//===----------------------------------------------------------------------===//
260+
// Pass Definition
261+
//===----------------------------------------------------------------------===//
262+
29263
struct ConvertXeVMToLLVMPass
30264
: public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31265
using Base::Base;
@@ -37,19 +271,51 @@ struct ConvertXeVMToLLVMPass
37271
void runOnOperation() override {
38272
ConversionTarget target(getContext());
39273
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
40-
RewritePatternSet pattern(&getContext());
41-
mlir::populateXeVMToLLVMConversionPatterns(pattern);
42-
if (failed(
43-
applyPartialConversion(getOperation(), target, std::move(pattern))))
274+
target.addIllegalDialect<xevm::XeVMDialect>();
275+
RewritePatternSet patterns(&getContext());
276+
mlir::populateXeVMToLLVMConversionPatterns(patterns);
277+
if (failed(applyPartialConversion(getOperation(), target,
278+
std::move(patterns))))
44279
signalPassFailure();
45280
}
46281
};
47282
} // namespace
48283

284+
//===----------------------------------------------------------------------===//
285+
// Pattern Population
286+
//===----------------------------------------------------------------------===//
287+
49288
/// Adds the 2D block load, store and prefetch lowering patterns to
/// `patterns`, in that order.
void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockLoad2dOp>>(context);
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockStore2dOp>>(context);
  patterns.add<LoadStorePrefetchNdToOCLPattern<xevm::BlockPrefetch2dOp>>(
      context);
}
52294

295+
//===----------------------------------------------------------------------===//
296+
// ConvertToLLVMPatternInterface implementation
297+
//===----------------------------------------------------------------------===//
298+
299+
namespace {
300+
/// Implement the interface to convert XeVM to LLVM.
301+
struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
302+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
303+
void loadDependentDialects(MLIRContext *context) const final {
304+
context->loadDialect<LLVM::LLVMDialect>();
305+
}
306+
307+
/// Hook for derived dialect interface to provide conversion patterns
308+
/// and mark dialect legal for the conversion target.
309+
void populateConvertToLLVMConversionPatterns(
310+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
311+
RewritePatternSet &patterns) const final {
312+
populateXeVMToLLVMConversionPatterns(patterns);
313+
}
314+
};
315+
} // namespace
316+
53317
/// Registers a dialect extension on `registry` that attaches
/// XeVMToLLVMDialectInterface to the XeVM dialect.
void mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
  auto attachInterface = +[](MLIRContext * /*ctx*/,
                             xevm::XeVMDialect *xevmDialect) {
    xevmDialect->addInterfaces<XeVMToLLVMDialectInterface>();
  };
  registry.addExtension(attachInterface);
}

0 commit comments

Comments
 (0)