Skip to content

[WIP] XeVM and XeGPU SIMT dev #429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/llvm-version-imex.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
add6b2f35f2bcf1f59a2ab2d5b3dab124fe0895a
7842374103b26933d71a8fe354cd4d8715d55b1c
2 changes: 1 addition & 1 deletion cmake/llvm-version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
add6b2f35f2bcf1f59a2ab2d5b3dab124fe0895a
3ae0f3047b5a0de8ef98c167610f6018f615b7ea
3 changes: 3 additions & 0 deletions include/gc/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

namespace mlir {

#define GEN_PASS_DECL
#include "gc/Conversion/Passes.h.inc"

/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
#include "gc/Conversion/Passes.h.inc"
Expand Down
16 changes: 16 additions & 0 deletions include/gc/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,20 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// XeGPUToXeVM
//===----------------------------------------------------------------------===//

def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
let summary = "Convert XeGPU to XeVM dialect";
let dependentDialects = [
"xegpu::XeGPUDialect",
"xevm::XeVMDialect",
"vector::VectorDialect",
"memref::MemRefDialect",
"arith::ArithDialect",
];
}


#endif // GC_CONVERSION_PASSES
28 changes: 28 additions & 0 deletions include/gc/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===-- XeGPUToXeVM.h - Convert XeVM to LLVM dialect -------------*- C++
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_

#include <memory>

namespace mlir {
class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;

#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
#include "gc/Conversion/Passes.h.inc"

void populateXeGPUToXeVMConversionPatterns(RewritePatternSet &patterns,
LLVMTypeConverter &typeConverter);

} // namespace mlir

#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
13 changes: 13 additions & 0 deletions include/gc/Dialect/LLVMIR/XeVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,17 @@

#include "gc/Dialect/LLVMIR/XeVMOpsDialect.h.inc"

namespace mlir::xevm {
/// XeVM memory space identifiers following SPIRV storage class convention
/// https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/main/docs/SPIRVRepresentationInLLVM.rst#address-spaces
///
enum class XeVMMemorySpace : uint32_t {
kFunction = 0, // OpenCL workitem address space
kCrossWorkgroup = 1, // OpenCL Global memory
kUniformConstant = 2, // OpenCL Constant memory
kWorkgroup = 3, // OpenCL Local memory
kGeneric = 4 // OpenCL Generic memory
};

} // namespace mlir::xevm
#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */
136 changes: 112 additions & 24 deletions include/gc/Dialect/LLVMIR/XeVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,44 @@ class XeVM_Op<string mnemonic, list<Trait> traits = []> :

def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;

class XeVM_LoadCacheControl<string cacheMnemonic> : I32EnumAttr<!strconcat(cacheMnemonic, "LoadCacheControl"), "XeVM load ops cache control",
def XeVM_LoadCacheControl : I32EnumAttr<"LoadCacheControl", "XeVM load ops cache control",
[
I32EnumAttrCase<"DEFAULT", 0, "Default">,
I32EnumAttrCase<"UC", 1, !strconcat(cacheMnemonic, "UC")>, // uncached
I32EnumAttrCase<"C", 2, !strconcat(cacheMnemonic, "C")>, // cached
I32EnumAttrCase<"S", 3, !strconcat(cacheMnemonic, "S")>, // streaming
I32EnumAttrCase<"IAR", 4, !strconcat(cacheMnemonic, "IAR")>, // invalidate-after-read
I32EnumAttrCase<"UC", 1, "UC">, // uncached
I32EnumAttrCase<"C", 2, "C">, // cached
I32EnumAttrCase<"S", 3, "S">, // streaming
I32EnumAttrCase<"IAR", 4, "IAR">, // invalidate-after-read
]> {
let cppNamespace = "::mlir::xevm";
let genSpecializedAttr = 0;
}

def XeVM_L1LoadCacheControl : XeVM_LoadCacheControl<"L1">;
def XeVM_L3LoadCacheControl : XeVM_LoadCacheControl<"L3">;
def XeVM_LoadCacheControlAttr:
EnumAttr<XeVM_Dialect, XeVM_LoadCacheControl, "load_cache_control"> {
let summary = [{ }];
let assemblyFormat = "$value";
}

class XeVM_StoreCacheControl<string cacheMnemonic> : I32EnumAttr<!strconcat(cacheMnemonic, "StoreCacheControl"), "XeVM store ops cache control",
def XeVM_StoreCacheControl : I32EnumAttr<"StoreCacheControl", "XeVM store ops cache control",
[
I32EnumAttrCase<"DEFAULT", 0, "Default">,
I32EnumAttrCase<"UC", 1, !strconcat(cacheMnemonic, "UC")>, // uncached
I32EnumAttrCase<"WT", 2, !strconcat(cacheMnemonic, "WT")>, // write-through
I32EnumAttrCase<"S", 3, !strconcat(cacheMnemonic, "S")>, // streaming
I32EnumAttrCase<"WB", 4, !strconcat(cacheMnemonic, "WB")>, // write back
I32EnumAttrCase<"UC", 1, "UC">, // uncached
I32EnumAttrCase<"WT", 2, "WT">, // write-through
I32EnumAttrCase<"S", 3, "S">, // streaming
I32EnumAttrCase<"WB", 4, "WB">, // write back
]> {
let cppNamespace = "::mlir::xevm";
let genSpecializedAttr = 0;
}

def XeVM_L1StoreCacheControl : XeVM_StoreCacheControl<"L1">;
def XeVM_L3StoreCacheControl : XeVM_StoreCacheControl<"L3">;
def XeVM_StoreCacheControlAttr:
EnumAttr<XeVM_Dialect, XeVM_StoreCacheControl, "store_cache_control"> {
let summary = [{ }];
let assemblyFormat = "$value";
}

def XeVM_BlockLoad2dOp : XeVM_Op<"blockload2d">,
Results<(outs FixedVectorOf<[XeVM_ElemType]>:$res)>,
Results<(outs FixedVectorOfRankAndType<[1,2,3], [XeVM_ElemType]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
I32:$base_width,
Expand All @@ -84,8 +92,8 @@ def XeVM_BlockLoad2dOp : XeVM_Op<"blockload2d">,
I32Attr:$v_blocks,
I1Attr:$transpose,
I1Attr:$vnni_transform,
DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control
)> {

let summary = "2D block load";
Expand Down Expand Up @@ -137,9 +145,9 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
I32Attr:$tile_width,
I32Attr:$tile_height,
I32Attr:$v_blocks,
FixedVectorOf<[XeVM_ElemType]>:$stored_val,
DefaultValuedAttr<XeVM_L1StoreCacheControl, "::mlir::xevm::L1StoreCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_L3StoreCacheControl, "::mlir::xevm::L3StoreCacheControl::DEFAULT">:$l3_cache_control
FixedVectorOfRankAndType<[1, 2, 3], [XeVM_ElemType]>:$stored_val,
DefaultValuedAttr<XeVM_StoreCacheControlAttr, "::mlir::xevm::StoreCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_StoreCacheControlAttr, "::mlir::xevm::StoreCacheControl::DEFAULT">:$l3_cache_control
)> {

let summary = "2D block store";
Expand Down Expand Up @@ -174,6 +182,86 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
let hasVerifier = 1;
}

def XeVM_MemoryScope : I32EnumAttr<"MemoryScope", "Memory scope for memory operations",
[
I32EnumAttrCase<"WORKGROUP", 0, "workgroup">,
I32EnumAttrCase<"CLUSTER", 1, "cluster">,
I32EnumAttrCase<"GPU", 2, "gpu">,
I32EnumAttrCase<"SYSTEM", 3, "system">
]>{
let cppNamespace = "mlir::xevm";
let genSpecializedAttr = 0;
}

def XeVM_MemoryScopeAttr:
EnumAttr<XeVM_Dialect, XeVM_MemoryScope, "mem_scope"> {
let summary = [{Describes the memory visibility scope:
"workgroup" - All work-items in the same work-group.
"cluster" - All work-items in the same cluster (a group of workgroups sharing SLM).
"gpu" - All work-items in the global NDrange.
"system" - All work-items in the global NDrange and the host program. }];
let assemblyFormat = "$value";
}

def XeVM_AddrSpace : I32EnumAttr<"AddrSpace", "Address spaces",
[
I32EnumAttrCase<"SHARED", 0, "shared">,
I32EnumAttrCase<"GLOBAL", 1, "global">,
I32EnumAttrCase<"GENERIC", 2, "generic">
]>{
let cppNamespace = "mlir::xevm";
let genSpecializedAttr = 0;
}

def XeVM_AddrSpaceAttr:
EnumAttr<XeVM_Dialect, XeVM_AddrSpace, "fence_addrspace"> {
let summary = [{Specifies the address space for memory operations affected by a fence:
"shared" - workgroup (SLM).
"global" - GPU.
"generic" - both "shared" and "global".}];
let assemblyFormat = "$value";
}

def XeVM_MemfenceOp : XeVM_Op<"memfence">,
Arguments<(ins
XeVM_MemoryScopeAttr:$scope,
DefaultValuedAttr<XeVM_AddrSpaceAttr, "mlir::xevm::AddrSpace::GENERIC"> :$addrspace
)> {
let summary = "Work-item's memory fence.";
let description = [{
This operation ensures that all prior memory accesses of this
work-item to `addrspace` are visible to all other work-items in `scope`.
Parameters description:
$scope - specify the memory scope at which all other work-items should observe
memory operations prior to the fence.
$addrspace - specify the address space of work-item's memory accesses
to be affected by the fence.
}];
let assemblyFormat = [{`addrspace` `=` `` $addrspace `,` `scope` `=` `` $scope attr-dict}];
}

def XeVM_PrefetchOp : XeVM_Op<"prefetch">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
XeVM_AddrSpaceAttr:$addrspace,
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control
)> {
let summary = "Prefetch data into a cache subsystem.";
let description = [{
Work-item issues a prefetch from global memory to L1/L3 cache:
$ptr - memory pointer.
$addrspace - address space of a pointer, must be generic or global.
$cache_control - specify caching options (e.g., L1c, L3uc).
}];
let assemblyFormat = [{
operands ` ` `{` `addrspace` `=` $addrspace `,` `l1_cc` `=` $l1_cache_control `,` `l3_cc` `=` $l3_cache_control `}`
attr-dict `:` `(` type(operands) `)`
}];

// let hasVerifier = 1;
}

def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
Expand All @@ -186,8 +274,8 @@ def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
I32Attr:$tile_width,
I32Attr:$tile_height,
I32Attr:$v_blocks,
DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l1_cache_control,
DefaultValuedAttr<XeVM_LoadCacheControlAttr, "::mlir::xevm::LoadCacheControl::DEFAULT">:$l3_cache_control
)> {

let summary = "2D block prefetch";
Expand Down Expand Up @@ -242,8 +330,8 @@ def XeVM_PrecisionTypeAttr : I32EnumAttr<"PrecisionType",
let cppNamespace = "::mlir::xevm";
}

def XeVM_DPASOp : XeVM_Op<"dpas">,
Results<(outs FixedVectorOf<[XeVM_MatrixElemType]>:$d)>,
def XeVM_DpasOp : XeVM_Op<"dpas">,
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
Arguments<(ins
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c,
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
Expand Down
8 changes: 4 additions & 4 deletions include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter,
auto context = rewriter.getContext();
auto tattr = dyn_cast_or_null<TypeAttr>(attr);
assert(tattr);
if (tattr == TypeAttr::get(FloatType::getF32(context))) {
if (tattr == TypeAttr::get(Float32Type::get(context))) {
return static_cast<int64_t>(dnnl_f32);
} else if (tattr == TypeAttr::get(FloatType::getF64(context))) {
} else if (tattr == TypeAttr::get(Float64Type::get(context))) {
return static_cast<int64_t>(dnnl_f64);
} else if (tattr == TypeAttr::get(FloatType::getBF16(context))) {
} else if (tattr == TypeAttr::get(BFloat16Type::get(context))) {
return static_cast<int64_t>(dnnl_bf16);
} else if (tattr == TypeAttr::get(FloatType::getF16(context))) {
} else if (tattr == TypeAttr::get(Float16Type::get(context))) {
return static_cast<int64_t>(dnnl_f16);
} else if (tattr == TypeAttr::get(
IntegerType::get(context, 32, IntegerType::Signed))) {
Expand Down
5 changes: 3 additions & 2 deletions include/gc/Transforms/Utils/StructuredOpMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct HasStaticStrides {
SmallVector<int64_t> strides;
if (auto memRefType = dyn_cast_or_null<MemRefType>(operandType)) {
int64_t offset;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return false;
if (llvm::any_of(strides, [](int64_t stride) {
return stride == ShapedType::kDynamic;
Expand Down Expand Up @@ -244,7 +244,8 @@ struct NumDpsInits {
// Callable object to validate number of input operands for `op`.
struct NumDpsInputs {
NumDpsInputs() = delete;
explicit NumDpsInputs(std::function<bool(size_t)> fun) : fun(std::move(fun)){};
explicit NumDpsInputs(std::function<bool(size_t)> fun)
: fun(std::move(fun)){};

bool operator()(Operation *op) {
if (auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(op))
Expand Down
1 change: 1 addition & 0 deletions lib/gc/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(XeVMToLLVM)
add_subdirectory(XeGPUToXeVM)
24 changes: 24 additions & 0 deletions lib/gc/Conversion/XeGPUToXeVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
gc_add_mlir_conversion_library(MLIRXeGPUToXeVM
XeGPUToXeVM.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/gc/Conversion/XeGPUToXeVM

DEPENDS
GCConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRFuncDialect
MLIRGPUDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRXeVMDialect
MLIRVectorDialect
MLIRArithDialect
MLIRXeGPUDialect
MLIRPass
MLIRTransforms
)
Loading
Loading