Commit 6f9f446

committed
add dpas test

1 parent b656a80 commit 6f9f446

5 files changed: +337 -22 lines changed
include/gc/Dialect/LLVMIR/XeVMOps.td

Lines changed: 67 additions & 0 deletions
```diff
@@ -221,6 +221,73 @@ def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
   let hasVerifier = 1;
 }
 
+def XeVM_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;
+
+/// Enum attribute of the different precision types.
+def XeVM_PrecisionTypeAttr : I32EnumAttr<"PrecisionType",
+    "XeVM precision type",
+    [
+      I32EnumAttrCase<"UNUSED", 0, "unused">,
+      I32EnumAttrCase<"U8", 1, "u8">,
+      I32EnumAttrCase<"U4", 2, "u4">,
+      I32EnumAttrCase<"U2", 3, "u2">,
+      I32EnumAttrCase<"S8", 4, "i8">,
+      I32EnumAttrCase<"S4", 5, "i4">,
+      I32EnumAttrCase<"S2", 6, "i2">,
+      I32EnumAttrCase<"BF8", 7, "bf8">,
+      I32EnumAttrCase<"TF32", 8, "tf32">,
+      I32EnumAttrCase<"BF16", 9, "bf16">,
+      I32EnumAttrCase<"FP16", 10, "f16">
+    ]> {
+  let cppNamespace = "::mlir::xevm";
+}
+
+def XeVM_DPASOp : XeVM_Op<"dpas">,
+  Results<(outs FixedVectorOf<[XeVM_MatrixElemType]>:$d)>,
+  Arguments<(ins
+    FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c,
+    FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+    FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+    XeVM_PrecisionTypeAttr:$pa,
+    XeVM_PrecisionTypeAttr:$pb,
+    I32Attr:$rc
+  )> {
+
+  let summary = "Matrix multiply-add";
+
+  let description = [{
+    The `xevm.dpas` operation is a matrix multiplication plus accumulation:
+
+        D = C + A x B
+
+    where the A, B, C input matrices and the result D have shapes:
+        D : MxN
+        C : MxN
+        A : MxK
+        B : KxN
+
+        M : repeat count, must be 1, 2, 4, or 8
+        N : fixed execution size, must be 16
+        K : depth * OPS_PER_CHAN
+        OPS_PER_CHAN
+            1 : for TF32
+            2 : for 16-bit precision (BF, HF)
+            4 : for 8-bit precision (FP8, UB, B)
+            8 : for less-than-8-bit precision (U4/S4, U2/S2)
+
+    If depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN).
+
+    $a, $b, $c, $d - matrix A, B, C, D, respectively
+    $pa, $pb - precision of matrix A and B, respectively
+    $rc - repeat count
+  }];
+
+  let assemblyFormat = [{
+    operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results)
+  }];
+
+  // let hasVerifier = 1;
+}
 
 def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
   let description = [{
```
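
The OPS_PER_CHAN table is what ties a precision pair to a concrete K value: K = systolic depth * OPS_PER_CHAN. A standalone sketch of that arithmetic (the enum and helper below are illustrative stand-ins, not code from this commit); for the f16 test added below it yields k16, matching `intel_sub_group_f16_f16_matrix_mad_k16`:

```cpp
// Minimal sketch of K = systolic_depth * OPS_PER_CHAN from the description
// above. Precision and opsPerChannel are hypothetical names for illustration.
#include <cassert>

enum class Precision { TF32, BF16, FP16, U8, S8, U4, S4, U2, S2 };

constexpr int opsPerChannel(Precision p) {
  switch (p) {
  case Precision::TF32:                       return 1; // 32-bit input
  case Precision::BF16: case Precision::FP16: return 2; // 16-bit input
  case Precision::U8:   case Precision::S8:   return 4; // 8-bit input
  default:                                    return 8; // sub-8-bit input
  }
}

int main() {
  constexpr int systolicDepth = 8;
  assert(systolicDepth * opsPerChannel(Precision::FP16) == 16); // k16
  assert(systolicDepth * opsPerChannel(Precision::TF32) == 8);  // k8
  assert(systolicDepth * opsPerChannel(Precision::U8) == 32);   // k32
}
```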

lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 122 additions & 4 deletions
```diff
@@ -54,6 +54,8 @@ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
     false, true, false, {}};
 static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
     false, true, true, {}};
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    true, true, true, {}};
 
 std::string getTypeMangling(Type ty, bool isUnsigned = false) {
   return TypeSwitch<Type, std::string>(ty)
```
```diff
@@ -80,6 +82,31 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
       });
 }
 
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string s;
+  llvm::raw_string_ostream os(s);
+  llvm::SmallDenseMap<Type, unsigned> substitutions;
+  os << "_Z" << baseName.size() << baseName;
+  for (auto [idx, type] : llvm::enumerate(types)) {
+    auto it = substitutions.find(type);
+    if (it != substitutions.end()) {
+      os << "S";
+      // First substitution is `S_`, second is `S0_`, and so on.
+      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
+        os << firstIdx - 1;
+      os << "_";
+    } else {
+      if (!type.isIntOrFloat())
+        substitutions[type] = substitutions.size();
+      os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
+    }
+  }
+  return os.str();
+}
+
 template <typename OpType>
 static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
```
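
To see the substitution rule in action: for the argument types the new test exercises (vector<8xi16>, vector<8xi32>, vector<8xf32>) no substitution fires, and, assuming `getTypeMangling` emits the Itanium vector forms `Dv8_s`, `Dv8_i`, and `Dv8_f` for them, `mangle` would return `_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f`. A string-keyed sketch of the same logic (not the commit's code; MLIR types and the int/float exemption are elided):

```cpp
// Standalone sketch of the Itanium substitution rule in mangle() above,
// with strings standing in for MLIR types. Unlike the real code, every
// type is treated as substitutable here.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

static std::string mangleDemo(const std::string &base,
                              const std::vector<std::string> &types) {
  std::string out = "_Z" + std::to_string(base.size()) + base;
  std::map<std::string, unsigned> subs; // type -> first-seen index
  for (const auto &ty : types) {
    auto it = subs.find(ty);
    if (it != subs.end()) {
      out += "S"; // back-reference: S_, S0_, S1_, ...
      if (it->second > 0)
        out += std::to_string(it->second - 1);
      out += "_";
    } else {
      subs.emplace(ty, subs.size());
      out += ty;
    }
  }
  return out;
}

int main() {
  // Three distinct types: no substitution fires.
  printf("%s\n", mangleDemo("intel_sub_group_f16_f16_matrix_mad_k16",
                            {"Dv8_s", "Dv8_i", "Dv8_f"}));
  // A repeated type becomes S_ on its second occurrence.
  printf("%s\n", mangleDemo("foo", {"Dv8_i", "Dv8_i"}));
  // prints: _Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f
  //         _Z3fooDv8_iS_
}
```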
```diff
@@ -145,6 +172,96 @@ static LLVM::CallOp createDeviceFunctionCall(
   return callOp;
 }
 
+class DPASToOCLPattern : public OpConversionPattern<xevm::DPASOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xevm::DPASOp op, xevm::DPASOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    // Bitcast `val` to a vector of `packedType` with the same total bit
+    // width, if it is not already in that form.
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+      return val;
+    };
+
+    Value a = op.getA();
+    Type packedAType = (op.getPa() == xevm::PrecisionType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
+    a = castIfNeeded(a, packedAType);
+
+    Value b = op.getB();
+    Type packedBType = (op.getPb() == xevm::PrecisionType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedB);
+    b = castIfNeeded(b, packedBType);
+
+    Value c = op.getC();
+    VectorType cOrigTy = cast<VectorType>(c.getType());
+    assert(cOrigTy == op->getResultTypes()[0] &&
+           "Accumulator and result type mismatch");
+    // OCL builtins encode bfloat16 as int16
+    VectorType cTy =
+        cOrigTy.getElementType().isBF16()
+            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+            : cOrigTy;
+    if (cOrigTy != cTy)
+      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+
+    constexpr int32_t systolicDepth{8};
+    std::string fnName =
+        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
+                      stringifyPrecisionType(op.getPa()).str(),
+                      stringifyPrecisionType(op.getPb()).str(),
+                      systolicDepth * getNumOperandsPerDword(op.getPa()))
+            .str();
+    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
+    fnName = mangle(fnName, argTypes);
+    SmallVector<Value> args{a, b, c};
+
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
+                                            args, {}, funcAttrs)
+                       ->getResult(0);
+
+    if (cOrigTy != cTy)
+      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  static unsigned getNumOperandsPerDword(xevm::PrecisionType pTy) {
+    switch (pTy) {
+    case xevm::PrecisionType::TF32:
+      return 1;
+    case xevm::PrecisionType::BF16:
+    case xevm::PrecisionType::FP16:
+      return 2;
+    case xevm::PrecisionType::U8:
+    case xevm::PrecisionType::S8:
+      return 4;
+    default:
+      llvm_unreachable("unsupported xevm::PrecisionType");
+    }
+  }
+};
+
 template <typename OpType>
 class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
   using OpConversionPattern<OpType>::OpConversionPattern;
```
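
For the shapes in the new test, `castIfNeeded` is a no-op on A (vector<8xi16> is already i16-packed) and repacks B from vector<16xi16> to vector<8xi32>. A tiny sketch of the width-preserving arithmetic (hypothetical helper, not the commit's code):

```cpp
// Width-preserving repack: new element count = total bits / packed bit width,
// mirroring castIfNeeded's computation of `newTy` above.
#include <cassert>

constexpr int repackedElems(int elems, int elemBits, int packedBits) {
  return elems * elemBits / packedBits;
}

int main() {
  assert(repackedElems(8, 16, 16) == 8);  // A: vector<8xi16>  stays as-is
  assert(repackedElems(16, 16, 32) == 8); // B: vector<16xi16> -> vector<8xi32>
  assert(repackedElems(8, 32, 32) == 8);  // C: vector<8xf32>  stays 8 wide
}
```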
```diff
@@ -292,10 +409,11 @@ struct ConvertXeVMToLLVMPass
 //===----------------------------------------------------------------------===//
 
 void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
-               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
-               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>>(
-      patterns.getContext());
+  patterns
+      .add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
+           LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
+           LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, DPASToOCLPattern>(
+          patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
```

lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp

Lines changed: 14 additions & 1 deletion
```diff
@@ -358,7 +358,14 @@ static cl_program loadModule(GPUCLQUEUE *queue, const unsigned char *data,
           "-DPASTokenReduction -Xfinalizer -SWSBDepReduction -Xfinalizer "
           "'-printregusage -enableBCR' -cl-kernel-arg-info -x spir";
   }
-  CL_SAFE_CALL(clBuildProgram(program, 0, NULL, build_flags, NULL, NULL));
+  err = clBuildProgram(program, 1, &queue->device_, build_flags, NULL, NULL);
+  if (err != CL_SUCCESS) {
+    char log[10240];
+    clGetProgramBuildInfo(program, queue->device_, CL_PROGRAM_BUILD_LOG,
+                          sizeof(log), log, nullptr);
+    fprintf(stderr, "Build failed: %s\n", log);
+    abort();
+  }
   if (takeOwnership)
     queue->programs_.push_back(program);
   return program;
```
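
A fixed 10 KiB buffer can truncate long IGC build logs. A sketch of a size-query variant of the error branch, using the same OpenCL entry points (a possible follow-up, not what the commit does):

```cpp
// Sketch: ask for the build-log size first so long logs are not truncated.
// Drop-in replacement for the fixed-size buffer above; requires <string>.
size_t logSize = 0;
clGetProgramBuildInfo(program, queue->device_, CL_PROGRAM_BUILD_LOG, 0,
                      nullptr, &logSize);
std::string log(logSize, '\0');
clGetProgramBuildInfo(program, queue->device_, CL_PROGRAM_BUILD_LOG, logSize,
                      log.data(), nullptr);
fprintf(stderr, "Build failed: %s\n", log.c_str());
abort();
```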
```diff
@@ -414,6 +421,12 @@ static void launchKernel(GPUCLQUEUE *queue, cl_kernel kernel, size_t gridX,
   }
   size_t globalSize[3] = {gridX * blockX, gridY * blockY, gridZ * blockZ};
   size_t localSize[3] = {blockX, blockY, blockZ};
+  size_t sgSize;
+  CL_SAFE_CALL(clGetKernelSubGroupInfo(
+      kernel, queue->device_, CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
+      sizeof(globalSize), &globalSize, sizeof(sgSize), &sgSize, nullptr));
+  // printf("Kernel's sub-group size: %zu\n", sgSize);
+
   CL_SAFE_CALL(clEnqueueNDRangeKernel(queue->queue_, kernel, 3, NULL,
                                       globalSize, localSize, 0, NULL, NULL));
 }
```
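
The queried sgSize is only inspected under a debugger here, since the printf is commented out. Note that the spec defines the input of CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE as a local work size, so passing globalSize relies on the two coinciding, as they do in this test's launch. To enumerate every sub-group size a device supports, the cl_intel_required_subgroup_size extension can be queried instead; a sketch (the 0x4108 constant is quoted from that extension from memory, double-check it against the header):

```cpp
// Sketch: list the sub-group sizes a device supports, per the
// cl_intel_required_subgroup_size extension.
#include <CL/cl.h>
#include <cstdio>
#include <vector>

#ifndef CL_DEVICE_SUB_GROUP_SIZES_INTEL
#define CL_DEVICE_SUB_GROUP_SIZES_INTEL 0x4108 // from the Intel extension
#endif

void printSubGroupSizes(cl_device_id device) {
  size_t bytes = 0;
  clGetDeviceInfo(device, CL_DEVICE_SUB_GROUP_SIZES_INTEL, 0, nullptr, &bytes);
  std::vector<size_t> sizes(bytes / sizeof(size_t));
  clGetDeviceInfo(device, CL_DEVICE_SUB_GROUP_SIZES_INTEL, bytes, sizes.data(),
                  nullptr);
  for (size_t s : sizes)
    printf("supported sub-group size: %zu\n", s);
}
```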
Lines changed: 117 additions & 0 deletions
```mlir
// RUN: gc-opt %s --convert-xevm-to-llvm --xevm-attach-target --convert-scf-to-cf --convert-cf-to-llvm --convert-arith-to-llvm --convert-gpu-to-llvm-spv --gpu-to-llvm --reconcile-unrealized-casts --cse --gpu-module-to-binary | gc-cpu-runner -e main -entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s

module @gemm attributes {gpu.container_module} {
  gpu.module @kernel {
    // The set of available `matrix_mad` intrinsics can differ based on the
    // device's *minimal* supported sub-group size; that minimum sub-group
    // size should be used when calling them.
    // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html

    gpu.func @block_dpas(%a: !llvm.ptr<1>, %b: !llvm.ptr<1>, %c: !llvm.ptr<1>) kernel attributes {intel_reqd_sub_group_size = 16 : i32} {
      %base_width_a = arith.constant 32 : i32
      %base_height_a = arith.constant 8 : i32
      %base_pitch_a = arith.constant 32 : i32
      %x = arith.constant 0 : i32
      %y = arith.constant 0 : i32
      %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>

      %base_width_b = arith.constant 32 : i32
      %base_height_b = arith.constant 16 : i32
      %base_pitch_b = arith.constant 32 : i32
      %loaded_b1 = xevm.blockload2d %b, %base_width_b, %base_height_b, %base_pitch_b, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
      %loaded_b_casted = vector.bitcast %loaded_b1 : vector<16xi16> to vector<8xi32>

      %base_width_c = arith.constant 64 : i32
      %base_height_c = arith.constant 8 : i32
      %base_pitch_c = arith.constant 64 : i32
      %loaded_c = xevm.blockload2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>

      %loaded_c_casted = vector.bitcast %loaded_c : vector<8xi32> to vector<8xf32>
      %c_result = xevm.dpas %loaded_c_casted, %loaded_a, %loaded_b_casted {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
      %c_result_casted = vector.bitcast %c_result : vector<8xf32> to vector<8xi32>

      xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
      gpu.return
    }
  }

  func.func @test(%a : memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index

    %memref_a = gpu.alloc host_shared () : memref<8x16xf16>
    memref.copy %a, %memref_a : memref<8x16xf16> to memref<8x16xf16>
    %a_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_a : memref<8x16xf16> -> index
    %a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
    %a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
    %a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>

    %memref_b = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %b, %memref_b : memref<16x16xf16> to memref<16x16xf16>
    %b_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_b : memref<16x16xf16> -> index
    %b_ptr_as_i64 = arith.index_cast %b_ptr_as_idx : index to i64
    %b_ptr = llvm.inttoptr %b_ptr_as_i64 : i64 to !llvm.ptr
    %b_ptr_casted = llvm.addrspacecast %b_ptr : !llvm.ptr to !llvm.ptr<1>

    %memref_c = gpu.alloc host_shared () : memref<8x16xf32>
    memref.copy %c, %memref_c : memref<8x16xf32> to memref<8x16xf32>
    %c_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_c : memref<8x16xf32> -> index
    %c_ptr_as_i64 = arith.index_cast %c_ptr_as_idx : index to i64
    %c_ptr = llvm.inttoptr %c_ptr_as_i64 : i64 to !llvm.ptr
    %c_ptr_casted = llvm.addrspacecast %c_ptr : !llvm.ptr to !llvm.ptr<1>

    gpu.launch_func @kernel::@block_dpas blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%a_ptr_casted : !llvm.ptr<1>, %b_ptr_casted : !llvm.ptr<1>, %c_ptr_casted : !llvm.ptr<1>)
    return %memref_c : memref<8x16xf32>
  }

  func.func @main() attributes {llvm.emit_c_interface} {
    %A = memref.alloc() : memref<8x16xf16>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index

    scf.for %i = %c0 to %c8 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %row_idx = arith.index_cast %i : index to i32
        %row = arith.sitofp %row_idx : i32 to f16
        memref.store %row, %A[%i, %j] : memref<8x16xf16>
      }
    }
    %B = memref.alloc() : memref<16x16xf16>
    scf.for %i = %c0 to %c16 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %col_idx = arith.index_cast %j : index to i32
        %col = arith.sitofp %col_idx : i32 to f16
        memref.store %col, %B[%i, %j] : memref<16x16xf16>
      }
    }

    %C = memref.alloc() : memref<8x16xf32>
    %c0_f32 = arith.constant 0.0 : f32
    scf.for %i = %c0 to %c8 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        memref.store %c0_f32, %C[%i, %j] : memref<8x16xf32>
      }
    }

    %C_res = call @test(%A, %B, %C) : (memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) -> memref<8x16xf32>
    %C_cast = memref.cast %C_res : memref<8x16xf32> to memref<*xf32>
    %A_cast = memref.cast %A : memref<8x16xf16> to memref<*xf16>
    call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()

    // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
    // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    // CHECK-NEXT: [0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]
    // CHECK-NEXT: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480]
    // CHECK-NEXT: [0, 48, 96, 144, 192, 240, 288, 336, 384, 432, 480, 528, 576, 624, 672, 720]
    // CHECK-NEXT: [0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960]
    // CHECK-NEXT: [0, 80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120, 1200]
    // CHECK-NEXT: [0, 96, 192, 288, 384, 480, 576, 672, 768, 864, 960, 1056, 1152, 1248, 1344, 1440]
    // CHECK-NEXT: [0, 112, 224, 336, 448, 560, 672, 784, 896, 1008, 1120, 1232, 1344, 1456, 1568, 1680]

    return
  }
  func.func private @printMemrefF16(%ptr : memref<*xf16>) attributes { llvm.emit_c_interface }
  func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
}
```
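
The CHECK matrix follows from the initialization in @main: A[i][k] = i, B[k][j] = j, and C starts at zero, so each output element is the sum over K = 16 of i*j, i.e. 16*i*j. A standalone host-side cross-check (not part of the commit):

```cpp
// Reproduce the FileCheck values: C[i][j] = sum_k A[i][k] * B[k][j]
// with A[i][k] = i, B[k][j] = j, K = 16, and a zero accumulator.
#include <cstdio>

int main() {
  const int M = 8, N = 16, K = 16;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int acc = 0;
      for (int k = 0; k < K; ++k)
        acc += i * j; // A[i][k] * B[k][j]
      printf("%d%s", acc, j + 1 < N ? ", " : "\n"); // one CHECK-NEXT row
    }
  }
}
```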
