[mlir][nvgpu] Remove strict verifiers on warpgroup.generate.descriptor #69935

Closed
wants to merge 1 commit into from

grypp
Member

@grypp grypp commented Oct 23, 2023

This PR relaxes some rules in the verifier of `warpgroup.generate.descriptor`, which I found overly restrictive. It is certainly possible to work around these rules, for example by generating additional subviews, but that just bloats the IR.

The test #69913 needs this PR.

@llvmbot
Member

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

This PR relaxes some rules in the verifier of `warpgroup.generate.descriptor`, which I found overly restrictive. It is certainly possible to work around these rules, for example by generating additional subviews, but that just bloats the IR.


Full diff: https://github.com/llvm/llvm-project/pull/69935.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (-6)
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..15eeba2839479d8 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -375,15 +375,9 @@ LogicalResult WarpgroupGenerateDescriptorOp::verify() {
   MemRefType memrefType = getTensor().getType();
   MemRefType tensorMapType = getTensorMap().getType().getTensor();
 
-  if (memrefType != tensorMapType)
-    return emitError() << "memref and tensor map type mismatch";
-
   if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
     return emitError() << "supports only static shapes";
 
-  if (memrefType.getRank() != 2)
-    return emitError() << "supports only 2d memref is supported for now";
-
   if (getTensorMap().getType().getSwizzle() !=
       TensorMapSwizzleKind::SWIZZLE_128B) {
     return emitError() << "supports only "
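
The effect of the diff can be illustrated with a small standalone model of the verifier. This is a sketch under stated assumptions, not the MLIR implementation: `ShapeModel`, `SwizzleModel`, `verifyStrict`, and `verifyRelaxed` are hypothetical names, the model compares shapes only (a real `MemRefType` comparison also covers element type, layout, and memory space), the swizzle error message is illustrative because the original string is truncated in the diff, and any checks after the truncated hunk are not modeled. The 128x128 and 128x64 shapes are made-up inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a memref type. A negative extent marks a dynamic
// dimension. Only the shape is modeled.
struct ShapeModel {
  std::vector<int64_t> dims;
  bool hasStaticShape() const {
    for (int64_t d : dims)
      if (d < 0)
        return false;
    return true;
  }
};

// Hypothetical stand-in for the tensor map swizzle kind.
enum class SwizzleModel { None, B32, B64, B128 };

// An error message on failure, std::nullopt on success.
using VerifyResult = std::optional<std::string>;

// Model of the checks before the change, in the order shown in the diff.
VerifyResult verifyStrict(const ShapeModel &memref, const ShapeModel &tensorMap,
                          SwizzleModel swizzle) {
  if (memref.dims != tensorMap.dims)
    return "memref and tensor map type mismatch";
  if (!memref.hasStaticShape() || !tensorMap.hasStaticShape())
    return "supports only static shapes";
  if (memref.dims.size() != 2)
    return "supports only 2d memref is supported for now";
  if (swizzle != SwizzleModel::B128)
    return "unsupported swizzle (illustrative message)";
  return std::nullopt;
}

// Model of the checks that remain after the change: the type-equality and
// rank-2 checks are gone, the static-shape and swizzle checks stay.
VerifyResult verifyRelaxed(const ShapeModel &memref,
                           const ShapeModel &tensorMap, SwizzleModel swizzle) {
  if (!memref.hasStaticShape() || !tensorMap.hasStaticShape())
    return "supports only static shapes";
  if (swizzle != SwizzleModel::B128)
    return "unsupported swizzle (illustrative message)";
  return std::nullopt;
}

// Usage: a buffer whose shape differs from the tensor map's shape is rejected
// by the strict model and accepted by the relaxed one.
void relaxedVerifierExample() {
  ShapeModel buffer{{128, 128}};
  ShapeModel tensorMap{{128, 64}};
  assert(verifyStrict(buffer, tensorMap, SwizzleModel::B128).has_value());
  assert(!verifyRelaxed(buffer, tensorMap, SwizzleModel::B128).has_value());
}
```

In this model, inputs with a dynamic dimension or a swizzle other than 128-byte fail both variants, which matches the two checks the diff leaves in place.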

@nicolasvasilache
Contributor

> The test #69913 needs this PR.

Can you fold this change into the PR that needs it?

@@ -375,15 +375,9 @@ LogicalResult WarpgroupGenerateDescriptorOp::verify() {
   MemRefType memrefType = getTensor().getType();
   MemRefType tensorMapType = getTensorMap().getType().getTensor();

-  if (memrefType != tensorMapType)
-    return emitError() << "memref and tensor map type mismatch";
Collaborator

What are the semantics of mismatched memref and tensor map types?

Could you add a test case demonstrating this case?

Member Author

The issue is similar to what we discussed here.

   if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
     return emitError() << "supports only static shapes";

-  if (memrefType.getRank() != 2)
-    return emitError() << "supports only 2d memref is supported for now";
Collaborator

Since this is supposed to feed into wgmma operations, why do we need to support memrefs of rank higher than 2?

(Sorry for the dumb questions x).)

@grypp
Member Author

grypp commented Nov 1, 2023

I created #70923, which is a better cleanup and improvement.

@grypp grypp closed this Nov 1, 2023
grypp added a commit that referenced this pull request Nov 10, 2023
…#70028)

PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with an
if-statement. This PR adds the same test using predicates in PTX.
Predicate support is enabled using _BasicPtxBuilderInterface_
`(nvgpu.opcode ..., predicate = %pred)`.

The predicate condition is computed in `Step 2. [GPU] Elect fastest
thread in CTA`, inspired by CUTLASS. It is as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```

Depends on #70027 #69934 #69935 #69584
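
The election described in the commit message above can be sketched as a host-side C++ model that enumerates thread ids instead of running on a GPU. Several things here are assumptions rather than facts from the commit: the elected lane is taken to be lane 0 (the hardware election only guarantees one lane per warp), the shuffle is modeled as `tid / 32` because every lane of a warp in a 1-D block holds the same quotient, and the final `&` is read as "elected lane and first warp of the warp group", i.e. `warp_idx_in_warp_group == 0`, which is assumed to be the intent. All function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Host-side model only; none of this is device code.
constexpr unsigned kWarpSize = 32;
constexpr unsigned kWarpsPerWarpGroup = 4;

// Assumption: the election picks lane 0 of each warp. The real instruction
// elects exactly one active lane, but which lane is not fixed by this model.
bool lanePredicate(unsigned tid) { return tid % kWarpSize == 0; }

// Models broadcasting threadIdx.x / 32 from lane 0 of the warp: all lanes of
// a warp in a 1-D block share the same quotient.
unsigned warpIndex(unsigned tid) { return tid / kWarpSize; }

// Assumed reading of the predicate: elected lane of the first warp in each
// warp group.
bool isElected(unsigned tid) {
  unsigned warpIdxInWarpGroup = warpIndex(tid) % kWarpsPerWarpGroup;
  return lanePredicate(tid) && warpIdxInWarpGroup == 0;
}

// Returns the thread ids for which the predicate holds in a 1-D block.
std::vector<unsigned> electedThreads(unsigned blockDim) {
  assert(blockDim % kWarpSize == 0 && "model assumes whole warps");
  std::vector<unsigned> elected;
  for (unsigned tid = 0; tid < blockDim; ++tid)
    if (isElected(tid))
      elected.push_back(tid);
  return elected;
}
```

Under these assumptions a block of 128 threads (one warp group) elects only thread 0, and a block of 256 threads elects threads 0 and 128, one per warp group.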
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…llvm#70028)

4 participants