[MLIR][NVGPU] Remove Memref Rank vs. Coordinates tma.async.load
#69584
Previously, a verifier check for mismatches between the memref rank and the number of coordinates was introduced. I noticed that it is very strict. Take the following IR snippet, where the verifier complains about a mismatch (2 coordinates != memref rank 3):
```
nvgpu.tma.async.load %map[%coord1, %coord2], %2 to %1 : ... -> memref<1x64x128xf16, ..., 3>
```
This PR relaxes the verifier.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu
Author: Guray Ozen (grypp)
Changes
Previously, a verifier check for mismatches between the memref rank and the number of coordinates was introduced. I noticed that it is very strict. For example, the verifier complains about a mismatch when 2 coordinates (%c1, %c2) are given for a memref of rank 3.
This PR relaxes the verifier.
Full diff: https://github.com/llvm/llvm-project/pull/69584.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..2a280df68371d17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -347,12 +347,6 @@ LogicalResult TmaAsyncLoadOp::verify() {
   if (getCoordinates().size() > 5) {
     return emitError() << "Maximum 5 coordinates are supported.";
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
-  }
   return success();
 }
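To make the effect of the removed hunk concrete, here is a minimal standalone sketch of the check before and after this change. It is a model, not MLIR code: `TmaLoadShape`, `verifyStrict`, and `verifyRelaxed` are hypothetical names introduced for illustration, and errors are modeled as an optional string rather than an MLIR `LogicalResult`. It assumes C++17.

```cpp
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical stand-in for the two facts the verifier looks at.
struct TmaLoadShape {
  std::size_t numCoordinates; // number of coordinate operands
  std::size_t dstRank;        // rank of the destination memref
};

// Models the check before this PR: at most 5 coordinates, and the
// coordinate count must equal the destination memref rank.
std::optional<std::string> verifyStrict(const TmaLoadShape &op) {
  if (op.numCoordinates > 5)
    return "Maximum 5 coordinates are supported.";
  if (op.numCoordinates != op.dstRank)
    return "Destination memref rank is " + std::to_string(op.dstRank) +
           " but there are " + std::to_string(op.numCoordinates) +
           " coordinates. They must match.";
  return std::nullopt; // success
}

// Models the check after this PR: only the upper bound of 5 remains.
std::optional<std::string> verifyRelaxed(const TmaLoadShape &op) {
  if (op.numCoordinates > 5)
    return "Maximum 5 coordinates are supported.";
  return std::nullopt; // success
}
```

With the shape from the PR description (2 coordinates, rank-3 memref), `verifyStrict` returns an error while `verifyRelaxed` succeeds. Note that the relaxed form accepts any count up to 5 regardless of rank, which appears to be the concern raised in the review below.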
Can you fold this change into the PR that needs it?
Instead of relaxing the check, I feel that we would need to use a In particular, what would be the semantics of:
Is I think the motivating example only works because the leading dim of the input memref is 1.
It will be For example:
I want to allow option 1. I guess you are concerned about option 2. I can improve the verifier so that it rejects option 2. Let me do that.
Actually, I could do this. But the test is large and requires a few more PRs, so I split them up for easier review :)
Well, I actually don't understand what the semantics of option 1 are :).
I created #70923, which is a better cleanup and improvement.
…#70028)
PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with if-statement. This PR adds the same test using predicates in PTX. Predicate support is enabled using _BasicPtxBuilderInterface_ `(nvgpu.opcode ..., predicate = %pred)`.
The predicate condition is computed in `Step 2. [GPU] Elect fastest thread in CTA` inspired by cutlass. It is as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```
Depends on #70027 #69934 #69935 #69584
Previously, a verifier check for mismatches between the memref rank and the number of coordinates was introduced. I noticed that it is very strict. For example, the verifier complains about a mismatch when 2 coordinates (%c1, %c2) are given for a memref of rank 3.
This PR relaxes the verifier.
The test #69913 needs this PR.