Skip to content

Commit 1be438f

Browse files
committed
Add xevm integration tests
1 parent 6c407ec commit 1be438f

File tree

16 files changed

+678
-42
lines changed

16 files changed

+678
-42
lines changed

include/gc/Dialect/LLVMIR/XeVMOps.td

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,76 @@ def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
221221
let hasVerifier = 1;
222222
}
223223

224+
def XeVM_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;
225+
226+
/// Enum attribute of the different precision types.
227+
def XeVM_PrecisionTypeAttr : I32EnumAttr<"PrecisionType",
228+
"XeVM precision type",
229+
[
230+
I32EnumAttrCase<"UNUSED", 0, "unused">,
231+
I32EnumAttrCase<"U8", 1, "u8">,
232+
I32EnumAttrCase<"U4", 2, "u4">,
233+
I32EnumAttrCase<"U2", 3, "u2">,
234+
I32EnumAttrCase<"S8", 4, "i8">,
235+
I32EnumAttrCase<"S4", 5, "i4">,
236+
I32EnumAttrCase<"S2", 6, "i2">,
237+
I32EnumAttrCase<"BF8", 7, "bf8">,
238+
I32EnumAttrCase<"TF32", 8, "tf32">,
239+
I32EnumAttrCase<"BF16", 9, "bf16">,
240+
I32EnumAttrCase<"FP16", 10, "f16">
241+
]> {
242+
let cppNamespace = "::mlir::xevm";
243+
}
244+
245+
def XeVM_DPASOp : XeVM_Op<"dpas">,
246+
Results<(outs FixedVectorOf<[XeVM_MatrixElemType]>:$d)>,
247+
Arguments<(ins
248+
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c,
249+
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
250+
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
251+
XeVM_PrecisionTypeAttr:$pa,
252+
XeVM_PrecisionTypeAttr:$pb,
253+
I32Attr:$rc
254+
)> {
255+
256+
let summary = "Matrix multiply-add";
257+
258+
let description = [{
259+
The `xevm.dpas` operation is a matrix multiplication plus accumulation:
260+
261+
D = C + A x B
262+
263+
where the A, B, C input matrices and the result D have shapes:
264+
D : MxN
265+
C : MxN
266+
A : MxK
267+
B : KxN
268+
269+
Shape restrictions:
270+
M : must be 1, 2, 4, or 8
271+
N : fixed execution size, must be 16
272+
K : systolic_depth * OPS_PER_CHAN
273+
OPS_PER_CHAN
274+
1 : for TF32
275+
2 : for 16-bit precision (BF, HF)
276+
4 : for 8-bit precision (FP8, UB, B)
277+
8 : for less-than 8 bit precision (U4/S4, U2/S2).
278+
279+
If systolic_depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN).
280+
$a, $b, $c, $d - matrix A, B, C, D, respectively
281+
$pa, $pb - precision of matrix A and B respectively
282+
$rc - repeat count
283+
284+
Further restrictions as well as more details can be found here:
285+
https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
286+
}];
287+
288+
let assemblyFormat = [{
289+
operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results)
290+
}];
291+
292+
// let hasVerifier = 1;
293+
}
224294

225295
def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
226296
let description = [{

include/gc/ExecutionEngine/Driver/Driver.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace mlir {
1818
class DialectRegistry;
1919
namespace gc {
2020

21-
const DialectRegistry &initCompilerAndGetDialects();
21+
DialectRegistry &initCompilerAndGetDialects();
2222

2323
// the pointers to XXXMemRefType
2424
using GeneralMemrefPtr = void *;

lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 132 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "gc/Dialect/LLVMIR/XeVMDialect.h"
1212
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313
#include "mlir/Conversion/LLVMCommon/Pattern.h"
14+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1415
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/Pass/Pass.h"
@@ -53,6 +54,8 @@ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
5354
false, true, false, {}};
5455
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
5556
false, true, true, {}};
57+
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
58+
true, true, true, {}};
5659

5760
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
5861
return TypeSwitch<Type, std::string>(ty)
@@ -79,6 +82,31 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
7982
});
8083
}
8184

85+
/// Builds a mangled builtin name of the form `_Z<len><baseName><params...>`.
///
/// Each entry of `types` contributes one parameter encoding produced by
/// `getTypeMangling`. A type that is not a plain integer or float is
/// remembered the first time it is seen; any later occurrence of the same type
/// is emitted as a back-reference (`S_` for the first remembered type, `S0_`
/// for the second, and so on) instead of being spelled out again.
///
/// \param baseName   unmangled function name.
/// \param types      parameter types, in call order.
/// \param isUnsigned optional per-parameter signedness flags; when non-empty
///                   it must have exactly one entry per element of `types`.
///                   When empty, every parameter is mangled as signed.
/// \returns the mangled name.
std::string mangle(StringRef baseName, ArrayRef<Type> types,
                   ArrayRef<bool> isUnsigned = {}) {
  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
         "Signedness info doesn't match");
  std::string mangled;
  llvm::raw_string_ostream stream(mangled);
  stream << "_Z" << baseName.size() << baseName;

  // Maps each remembered (non-scalar) type to the order in which it was first
  // encountered.
  llvm::SmallDenseMap<Type, unsigned> seen;
  for (auto [pos, ty] : llvm::enumerate(types)) {
    if (auto found = seen.find(ty); found != seen.end()) {
      // Back-reference: `S_` for sequence id 0, `S<id-1>_` afterwards.
      const unsigned seqId = found->getSecond();
      stream << "S";
      if (seqId > 0)
        stream << seqId - 1;
      stream << "_";
      continue;
    }

    if (!ty.isIntOrFloat()) {
      // Take the id before inserting so the new entry does not count itself.
      const unsigned nextId = seen.size();
      seen[ty] = nextId;
    }
    const bool treatAsUnsigned = !isUnsigned.empty() && isUnsigned[pos];
    stream << getTypeMangling(ty, treatAsUnsigned);
  }
  return stream.str();
}
109+
82110
template <typename OpType>
83111
static std::optional<ArrayAttr>
84112
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
@@ -115,13 +143,15 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
115143
return rewriter.getArrayAttr(combinedAttrs);
116144
}
117145

118-
static LLVM::CallOp
119-
createDeviceFunctionCall(ConversionPatternRewriter &rewriter,
120-
StringRef funcName, Type retType,
121-
ArrayRef<Type> argTypes, ArrayRef<Value> args,
122-
ArrayRef<std::pair<unsigned, StringRef>> paramAttrs,
123-
LLVMFuncAttributeOptions funcAttributeOptions) {
124-
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
146+
static LLVM::CallOp createDeviceFunctionCall(
147+
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
148+
ArrayRef<Type> argTypes, ArrayRef<Value> args,
149+
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
150+
LLVMFuncAttributeOptions funcAttributeOptions) {
151+
auto moduleOp = rewriter.getBlock()
152+
->getParentOp()
153+
->getParentWithTrait<OpTrait::SymbolTable>();
154+
assert(moduleOp && "Expecting module");
125155
MLIRContext *ctx = rewriter.getContext();
126156
Location loc = UnknownLoc::get(ctx);
127157

@@ -144,6 +174,96 @@ createDeviceFunctionCall(ConversionPatternRewriter &rewriter,
144174
return callOp;
145175
}
146176

177+
/// Lowers `xevm.dpas` to a call to the mangled OpenCL builtin
/// `intel_sub_group_<pa>_<pb>_matrix_mad_k<K>`, where
/// `K = systolicDepth * operands-per-dword(pa)`.
///
/// Operand handling:
///  - A is bitcast to a vector of 16-bit integers and B to a vector of 32-bit
///    integers, unless the respective precision is TF32, in which case f32
///    elements are used.
///  - A bf16 accumulator is passed to the builtin as a vector of i16 and the
///    call result is bitcast back to the original accumulator type.
///
/// The pattern fails to match (without emitting any IR) when the precision of
/// A is not one of TF32/BF16/FP16/U8/S8, or when the accumulator type differs
/// from the result type.
class DPASToOCLPattern : public OpConversionPattern<xevm::DPASOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xevm::DPASOp op, xevm::DPASOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    constexpr int32_t systolicDepth{8};
    auto loc = op.getLoc();

    // Run every check up front so that a rejected op leaves the IR untouched.
    std::optional<unsigned> opsPerDword = getNumOperandsPerDword(op.getPa());
    if (!opsPerDword)
      return rewriter.notifyMatchFailure(
          op, "unsupported xevm::PrecisionType for matrix A");

    VectorType cOrigTy = cast<VectorType>(op.getC().getType());
    if (cOrigTy != op->getResultTypes()[0])
      return rewriter.notifyMatchFailure(
          op, "Accumulator and result type mismatch");

    // Reinterprets `val` as a vector of `packedType` elements covering the
    // same number of bits; no op is created when the type already matches.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
      return val;
    };

    Value a = op.getA();
    Type packedAType = (op.getPa() == xevm::PrecisionType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);

    Value b = op.getB();
    Type packedBType = (op.getPb() == xevm::PrecisionType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);

    Value c = op.getC();
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    if (cOrigTy != cTy)
      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);

    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
                      stringifyPrecisionType(op.getPa()).str(),
                      stringifyPrecisionType(op.getPb()).str(),
                      systolicDepth * *opsPerDword)
            .str();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c};

    // The builtin is declared as touching no memory at all.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
                                            args, {}, funcAttrs)
                       ->getResult(0);

    // Undo the bf16 -> i16 reinterpretation applied to the accumulator.
    if (cOrigTy != cTy)
      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Returns how many operands of precision `pTy` are counted per dword when
  /// computing K for the builtin name, or `std::nullopt` when the precision is
  /// not handled by this lowering.
  static std::optional<unsigned>
  getNumOperandsPerDword(xevm::PrecisionType pTy) {
    switch (pTy) {
    case xevm::PrecisionType::TF32:
      return 1;
    case xevm::PrecisionType::BF16:
    case xevm::PrecisionType::FP16:
      return 2;
    case xevm::PrecisionType::U8:
    case xevm::PrecisionType::S8:
      return 4;
    default:
      return std::nullopt;
    }
  }
};
266+
147267
template <typename OpType>
148268
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
149269
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -291,10 +411,11 @@ struct ConvertXeVMToLLVMPass
291411
//===----------------------------------------------------------------------===//
292412

293413
void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
294-
patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
295-
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
296-
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>>(
297-
patterns.getContext());
414+
patterns
415+
.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
416+
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
417+
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, DPASToOCLPattern>(
418+
patterns.getContext());
298419
}
299420

300421
//===----------------------------------------------------------------------===//

lib/gc/ExecutionEngine/Driver/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ else()
2525
MLIRToLLVMIRTranslationRegistration
2626
)
2727
endif()
28+
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
2829

2930
set(GC_PASSES GcInterface GcPasses)
3031
if(GC_ENABLE_IMEX)
@@ -38,6 +39,8 @@ gc_add_mlir_library(GcJitWrapper
3839
${MLIR_LINK_COMPONENTS}
3940
${dialect_libs}
4041
${conversion_libs}
42+
${extension_libs}
4143
${GC_PASSES}
4244
GcAnalysis
45+
MLIRXeVMToLLVMIRTranslation
4346
)

lib/gc/ExecutionEngine/Driver/Driver.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
#ifdef GC_HAS_ONEDNN_DIALECT
1212
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
1313
#endif
14+
#include "gc/Conversion/Passes.h"
15+
#include "gc/Target/LLVM/XeVM/Target.h"
16+
#include "gc/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
1417
#include "gc/Transforms/Passes.h"
1518
#include "mlir/InitAllDialects.h"
19+
#include "mlir/InitAllExtensions.h"
1620
#include "mlir/InitAllPasses.h"
1721
#include "mlir/Pass/PassManager.h"
1822
#include "mlir/Target/LLVMIR/Dialect/All.h"
@@ -26,22 +30,29 @@ namespace gc {
2630
static DialectRegistry initDialects() {
2731
mlir::registerAllPasses();
2832
mlir::gc::registerGraphCompilerPasses();
33+
mlir::registerGCConversionPasses();
2934
mlir::cpuruntime::registerCPURuntimePasses();
3035
mlir::DialectRegistry registry;
3136
registry.insert<mlir::cpuruntime::CPURuntimeDialect>();
3237
mlir::registerAllDialects(registry);
3338
mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
39+
mlir::registerAllExtensions(registry);
40+
// Adds missing `LLVMTranslationDialectInterface` registration for dialect for
41+
// gpu.module op
42+
mlir::registerAllToLLVMIRTranslations(registry);
43+
mlir::registerConvertXeVMToLLVMInterface(registry);
44+
mlir::registerXeVMDialectTranslation(registry);
45+
mlir::xevm::registerXeVMTargetInterfaceExternalModels(registry);
3446
#ifdef GC_HAS_ONEDNN_DIALECT
3547
registry.insert<mlir::onednn_graph::OneDNNGraphDialect>();
3648
#endif
3749
llvm::InitializeNativeTarget();
3850
llvm::InitializeNativeTargetAsmPrinter();
3951
llvm::InitializeNativeTargetAsmParser();
40-
mlir::registerAllToLLVMIRTranslations(registry);
4152
return registry;
4253
}
4354

44-
const DialectRegistry &initCompilerAndGetDialects() {
55+
DialectRegistry &initCompilerAndGetDialects() {
4556
static DialectRegistry reg = initDialects();
4657
return reg;
4758
}

0 commit comments

Comments
 (0)