[MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for … #120566
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Petr Kurapov (kurapov-peter)

Changes: see the full PR description reproduced at the end of this conversation.

Full diff: https://github.com/llvm/llvm-project/pull/120566.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f3ffbd0f5a027d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
let hasVerifier = 1;
}
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
- AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "stores a n-D block register region back to memory, currently only supports 2D";
let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..721cba70520758 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
+// Validations for nd instruction arguments is successful if any of these are
+// true:
+// - tensor descriptor and the output vector shapes exactly match.
+// - tensor descriptor has a sg_map attribute and the distributed vector shape
+// matches the tensor descriptor shape when scaled using sg_map factors on
+// each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+ ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+ if (descShape == valShape)
+ return true;
+
+ if (!sgMap)
+ return false;
+
+ for (const auto &[factor, dim, expected] :
+ llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ if (factor * dim != expected)
+ return false;
+ }
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
return emitOpError("Invalid result, it should be a VectorType.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
auto it = tdescShape.begin();
tdescShape.insert(it, array_len);
}
+ auto sgMap = tdescTy.getSGMapAttr();
- if (tdescShape != valueShape)
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
return emitOpError() << "Result shape doesn't match TensorDesc shape."
<< "The expected shape is " << makeString(tdescShape)
<< ". But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!valTy)
- return emitOpError("Exepcting a VectorType result.\n");
+ return emitOpError("Expecting a VectorType result.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+ auto tdescShape = getShapeOf(dstTy);
+ auto valueShape = getShapeOf(valTy);
+ auto sgMap = dstTy.getSGMapAttr();
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+ return emitOpError() << "Result shape doesn't match TensorDesc shape."
+ << "The expected shape is " << makeString(tdescShape)
+ << ". But the given shape is "
+ << makeString(valueShape) << ".\n";
return success();
}
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto tdescElemTy = tdescTy.getElementType();
auto valueElemTy = getElementType();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto maskTy = getMaskType();
auto valueTy = getValueType();
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..d7174a489888a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
gpu.return
}
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+ %1 = arith.constant dense<1.0>: vector<24x2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f8a0d95bd70a27..155131ba9e6d50 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
// -----
func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
return
}
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
// -----
func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
return
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
%1 = arith.constant dense<1.0>: vector<24x32xf16>
%2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
return
}
@@ -147,7 +147,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
return
}
@@ -168,7 +168,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
-> vector<4x2xf32>
@@ -193,7 +193,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
%1 = arith.constant dense<2.9>: vector<4x2xf32>
%2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
return
Just first impressions. Currently I'm mostly leaning toward the 4th idea. A subview of a descriptor could work, or maybe the load operation itself could contain some optional offsets, e.g. mapping it to the thread ID.
Yup, I think so too.
Yes, that could be a good indication for the pass itself (e.g., as a return condition) to do the distribution.
Me too, that looks cleanest to me, so I'm suggesting we go there. For an offset, I'm not sure what it would represent in a non-distributed case. You could do a vector of offsets that describes all the offsets for each individual lane, but I think the definition of the ops and sg_map wanted to avoid such an explicit representation. Anyway, there are clearly multiple ways of doing that.
I think that's not necessary; the transformation, in a sense, just slices through the dataflow. You could theoretically do that multiple times (this is when it starts resembling tiling, which you could argue it kinda is).
Based on the nature of the tensor descriptor, I don't think we should distribute it. The tensor descriptor is a uniform value which all SIMD lanes share. There is only one copy of the tensor descriptor created for all SIMD lanes, and its creation doesn't involve the lane id. The 2D block load is a collective operation in which each lane takes the uniform tensor descriptor and loads/stores its own data fragment. The block shape inside the tensor descriptor can't be viewed exactly like a memref shape, which distributes naturally, with each individual thread computing its own address according to its lane id. Instead, the computation of the tensor descriptor involves no lane id, and all lanes compute the same value, so that at the assembly level the tensor descriptor creation is done by only one thread. If we distribute it, we need to reverse it back by merging the shape back to the original block size and getting rid of the lane id from the computation. The distribution is also non-trivial, which makes the reversal process complex: the data fragment may have strides along 2 dimensions, so each thread may generate multiple addresses. I don't see it as worth doing, since I don't see any optimization we want or missed on this kind of distributed form.
To me, SIMT doesn't mean each thread must know exactly how to compute its address using its own lane id. This is actually what the HW ISA tries to avoid, since per-lane addresses use more registers than one uniform tensor descriptor.
I don't quite understand the "memory side effects" and "violation of ownership" points in the debate. Maybe an example would help here. In the worst case, most XeGPU-level optimization is target dependent, so we will have to take care of this inconsistent-shape issue that is target specific. I don't expect many target-independent optimizations for which we would want XeGPU to be distributed in perfect SIMT flavor, but I only see problems if we go that way, as stated above. If you have an example, you may point it out.
Are you suggesting we apply the second approach, ignore that this is very confusing, and just hope that there will be no problems with type mismatches in the future? I also don't see any contradictions to the proposal (the 4th approach), by the way, which is preferable imo.
Yes, I know. The instruction does not live in a vacuum though. Since we are composing this with vector distribution, the resulting vector of a load would be consumed by some vector IR that describes what a logical thread does. It will have to consume some portion of the data that "belongs" to that thread. That portion comes from a memref as described by the tensor descriptor. So there is already information about how to retrieve specific data element(s) for a particular lane id, although right now it is implicit. I don't see it as an argument against having a view into the descriptor, since the view can make exactly that implicit assignment explicit. The information also wouldn't need to be restored down the pipeline, since you retain the original tensor descriptor.
I don't think this is necessary; you just need to restore the parameters, such as the size of the descriptor, which can be done using the sg_map. The assignment of data chunks to lane ids is implicit. Anyway, this is neither what I'm proposing nor is it related to the patch.
I have no examples. All I'm stating is that I'm not aware of the consequences of messing with the types like that.
I would not rush for option 4 either. If I understand correctly, the subview op needs to compute the offsets from the lane id, and then create the subview with sizes, offsets, and strides. I like the approach of keeping the original tensor descriptor as a whole, but it requires adding a subview op which doesn't look like other XeGPU ops (which map to concrete hardware operations). I don't know what benefit it can bring at this point other than the IR appearing less "confusing", which is a debatable point. To me, whether the IR is "confusing" actually depends on how the IR is lowered or optimized. My view is actually the reverse. The type mismatch doesn't bother me that much. But if the IR doesn't model the hardware behavior, say it introduces per-lane offset/size computation which we don't need during the lowering, that causes a different type of confusion that bothers me more. The passes on XeGPU are mostly target-specific, so people like to match the IR with the hardware behavior: each lane takes the whole shape and reads back its own data fragment implicitly, instead of computing its own offsets/sizes. If a transformation/optimization needs to know the data fragment distribution, it can refer to sg_map, which was designed to explicitly describe the data distribution. I suggest we first go with option 2. When it becomes clear that we really need a subview op, we can revisit it.
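For concreteness, here is a minimal sketch of the distributed form that sg_map describes, mirroring the store_nd test added in this patch (the module/function wrapper and SSA names are illustrative, and the optional cache hints are omitted):

gpu.module @example {
  gpu.func @store_nd_distributed(%dst: memref<24x32xf16>) {
    // Each lane holds a 24x2 fragment: wi_layout = [1, 16] splits the 32-wide
    // dimension across 16 lanes (1 * 24 == 24 and 16 * 2 == 32).
    %frag = arith.constant dense<1.0> : vector<24x2xf16>
    // The descriptor still describes the whole 24x32 block; no per-lane
    // offsets or sizes are computed in the IR.
    %desc = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16>
        -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
    xegpu.store_nd %frag, %desc : vector<24x2xf16>,
        !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
    gpu.return
  }
}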
Actually, I found this code example, as I believe this is a common issue for collective subgroup operations. nvgpu.ldmatrix chooses to live with the mismatch between the memref shape and the vector shape:
%mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false} : memref<128x32xf16, 3> -> vector<4x2xf16>
Good to know, that adds more context to the proposed solution, as I initially thought this distribution split could be more arbitrary. After a closer look and going through the detailed explanations (thanks @Jianhui-Li), the main competing options 2 and 4 revolve around making the distribution implicit(-ish, as it is still captured by sg_map).
That's been my concern; however, so far I couldn't come up with a case that would break without explicit mapping, primarily because I imagine most complex transformations would still occur at the SIMD/cooperative level, before SIMT.
+1 on taking option 2 as the first step, as it is simpler: no extra ops, and the changes don't break any existing code.
Agreed. +1 for option 2. XeGPU is primarily designed to faithfully represent HW details in block load/store. So in my view, approach 1 violates this philosophy without clearly apparent benefits.
This argument is not clear to me. Are you saying that …
49ea091 to 103db33
Seems like there's consensus on going with option 2 now. We can merge this then, unless there are additional comments.
By the way, this validation relies on …
Seems to me that …
I am OK to keep …
+// each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+ ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+ if (descShape == valShape)
what if descShape == valShape and sgMap is valid? does it mean the sgMap will be discarded?
Sorry, I didn't get the question. If the shapes are equal, that means the nd load/store is valid and not distributed (if there's an sg map, only the sg map's validity is checked by SGMapAttr::verify). If the shapes don't match, then we either have a distributed or an invalid case. For the distributed case we check that the ranks are the same, the sg map is present, and the scaled values for each dimension match.
NVM, it seems a dumb question. It seems to me that there are currently 3 stages: 1) pure SIMD code: sgMap == null and descShape == valShape. 2) a valid sgMap is attached to guide the lowering, but the code is not rewritten yet, so we have descShape == valShape and sgMap != null, but sgMap is not effective yet. 3) the code is rewritten, so descShape != valShape, sgMap != null, and sgMap is now effective.
Yup, those are all valid states
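To make the three stages concrete, here is a minimal sketch using the shapes from this patch's load_nd test; the optional cache hints are omitted, %src is assumed to be a memref<24x32xf32> function argument, and the SSA names are illustrative:

// Stage 1: pure SIMD; no sg_map, descriptor and value shapes match.
%t0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%v0 = xegpu.load_nd %t0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

// Stage 2: sg_map attached to guide the lowering; the code is not rewritten yet,
// so the shapes still match and the sg_map is not effective.
%t1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
    -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%v1 = xegpu.load_nd %t1 : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x16xf32>

// Stage 3: rewritten to SIMT; each lane keeps an 8x1 fragment, which the
// verifier now accepts because 1 * 8 == 8 and 16 * 1 == 16.
%v2 = xegpu.load_nd %t1 : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>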
@@ -210,13 +236,13 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");

if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
Thanks for fixing these typos.
It looks to me like the current implementation is following option 2, am I right?
Yes, these validation changes unblock option 2
Deriving this from an attribute likely means the map would be essentially the same for all operations. I suspect this might be too restrictive, in the presence of packing for example. Other than that, I don't see any technical problems with getting this information not from the type attribute but from somewhere else, like a kernel attribute (which would probably be the same sg map but attached to a kernel?).
It is not necessary to require all operations to have the same mapping rule, since they have different semantics. But it does require each op to have exactly one mapping rule.
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}

+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
@chencha3 do you remember what the 'vc' suffix stands for? I feel like it is not relevant for SIMT-ish examples :)
…the distributed IR case.

This patch allows nd_load and nd_store to preserve the tensor descriptor shape during distribution to SIMT. The validation now expects the distributed instruction to retain the sg_map attribute and uses it to verify the consistency.

Some background:
We've been discussing the appropriate way of distributing xegpu operations, and there are multiple approaches possible. I'm listing the three main ones, commenting on their properties:

- … sg_map), but somewhat breaks the promise of XeGPU to represent HW abstractions 1:1 (arguably, imo). Allowing distributed tensor descriptors also means users would be able to create and use them inappropriately, so some UB cases are inevitably introduced.
- … (nd_load and nd_store are ops that should have memory side effects that are not implemented at the moment. It is unclear to me what implications this violation of "ownership" may have.)

My personal opinion on this is that we should take the 1st approach and modify it in such a way that the creation of a tensor descriptor survives the distribution but the instructions use views into the descriptor, so the IR describes "what a logical thread does" and there are no type mismatches (introduce some xegpu view-into-tensor-descriptor op, so that it would work for both the ND and scattered cases in the distribution logic). That said, I'm OK to put this patch in for experiments with the second approach as it doesn't break anything.

@Jianhui-Li, @charithaintc, @chencha3, @adam-smnk, @rengolin
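As a concrete illustration of the relaxed check, here is a small sketch of what the new verifier accepts and still rejects; the shapes follow the logic of isArgShapesValid rather than a specific test case, the cache hints are omitted, and %tdesc is assumed to be an 8x16xf32 descriptor carrying the sg_map shown:

// Accepted: the 8x16 descriptor with wi_layout = [1, 16] matches a vector<8x1>
// value because 1 * 8 == 8 and 16 * 1 == 16.
%ok = xegpu.load_nd %tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>

// Still rejected: 16 * 2 == 32 != 16, so the verifier reports
// "Result shape doesn't match TensorDesc shape."
%bad = xegpu.load_nd %tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>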