[mlir][nvgpu] Remove strict verifiers on warpgroup.generate.descriptor
#69935
This PR relaxes some rules in the verifier, which I found to be overly restrictive. It is certainly possible to work around these rules, for example by generating additional subviews, but that just bloats the IR.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

This PR relaxes some rules in the verifier, which I found to be overly restrictive. It is certainly possible to work around these rules, for example by generating additional subviews, but that just bloats the IR.

Full diff: https://github.com/llvm/llvm-project/pull/69935.diff

1 Files Affected:
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..15eeba2839479d8 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -375,15 +375,9 @@ LogicalResult WarpgroupGenerateDescriptorOp::verify() {
MemRefType memrefType = getTensor().getType();
MemRefType tensorMapType = getTensorMap().getType().getTensor();
- if (memrefType != tensorMapType)
- return emitError() << "memref and tensor map type mismatch";
-
if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
return emitError() << "supports only static shapes";
- if (memrefType.getRank() != 2)
- return emitError() << "supports only 2d memref is supported for now";
-
if (getTensorMap().getType().getSwizzle() !=
TensorMapSwizzleKind::SWIZZLE_128B) {
return emitError() << "supports only "
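To make the effect of the diff above concrete, here is a minimal, self-contained sketch of the checks that remain after this change. It does not use the MLIR API: `MemRefModel`, `SwizzleKind`, `kDynamic`, and `verifyDescriptorModel` are hypothetical stand-ins invented for illustration, and the swizzle error text is a placeholder because the real message is cut off in the diff.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the MLIR types; this is not the real MLIR API.
constexpr int64_t kDynamic = -1; // marks a dynamic dimension in this sketch

enum class SwizzleKind { None, Swizzle32B, Swizzle64B, Swizzle128B };

struct MemRefModel {
  std::vector<int64_t> shape;

  // True when no dimension is dynamic.
  bool hasStaticShape() const {
    for (int64_t dim : shape)
      if (dim == kDynamic)
        return false;
    return true;
  }
};

// Models the relaxed verifier: returns an error message, or std::nullopt
// when verification passes. The type-equality and rank-2 checks removed by
// the diff are intentionally absent.
std::optional<std::string> verifyDescriptorModel(const MemRefModel &tensor,
                                                 const MemRefModel &tensorMap,
                                                 SwizzleKind swizzle) {
  if (!tensor.hasStaticShape() || !tensorMap.hasStaticShape())
    return std::string("supports only static shapes");

  // Placeholder message: the original text is truncated in the diff.
  if (swizzle != SwizzleKind::Swizzle128B)
    return std::string("unsupported swizzle (placeholder message)");

  return std::nullopt;
}
```

Under this model, a tensor whose shape differs from the tensor map's, or whose rank is not 2, now passes as long as both shapes are static and the swizzle is 128B.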
The test #69913 needs this PR. Can you fold this change into the PR that needs it?
@@ -375,15 +375,9 @@ LogicalResult WarpgroupGenerateDescriptorOp::verify() {
MemRefType memrefType = getTensor().getType();
MemRefType tensorMapType = getTensorMap().getType().getTensor();

if (memrefType != tensorMapType)
return emitError() << "memref and tensor map type mismatch";
What are the semantics of mismatched memref and tensor map types?
Could you add a test case demonstrating this?
The issue is similar to what we discussed here.
if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
return emitError() << "supports only static shapes";

if (memrefType.getRank() != 2)
return emitError() << "supports only 2d memref is supported for now";
Since this is supposed to feed into wgmma operations, why do we need to support memrefs with more than two dimensions?
(Sorry for the dumb questions x).)
I created #70923, which is a better cleanup and improvement.
…#70028)

PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with if-statement. This PR adds the same test using predicates in PTX. Predicate support is enabled using _BasicPtxBuilderInterface_ `(nvgpu.opcode ..., predicate = %pred)`.

The predicate condition is computed in `Step 2. [GPU] Elect fastest thread in CTA` inspired by cutlass. It is as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```
Depends on #70027 #69934 #69935 #69584
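The index arithmetic in the pseudocode above can be checked on the host with a small sketch. It models only the `warp_idx` and `warp_idx_in_warp_group` steps. `nvvm.elect.sync` and `__shfl_sync` are device-side operations and are not modeled; the shuffle broadcasts lane 0's value to the whole warp, which a per-thread division reproduces because every thread in a warp shares the same quotient. The names `kWarpSize`, `kWarpsPerWarpGroup`, `warpIdx`, and `warpIdxInWarpGroup` are hypothetical, chosen for this illustration.

```cpp
#include <cstdint>

// Assumed constants for this sketch: 32 threads per warp,
// 4 warps per warpgroup (matching the "/ 32" and "% 4" in the pseudocode).
constexpr uint32_t kWarpSize = 32;
constexpr uint32_t kWarpsPerWarpGroup = 4;

// Host-side model of `warp_idx = threadIdx.x / 32`.
constexpr uint32_t warpIdx(uint32_t threadIdxX) {
  return threadIdxX / kWarpSize;
}

// Host-side model of `warp_idx_in_warp_group = warp_idx % 4`.
constexpr uint32_t warpIdxInWarpGroup(uint32_t threadIdxX) {
  return warpIdx(threadIdxX) % kWarpsPerWarpGroup;
}
```

For example, thread 127 belongs to warp 3 of its warpgroup, while thread 128 starts the next warpgroup at warp index 0.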