[MLIR][NVGPU] Remove Memref Rank vs. Coordinates tma.async.load #69584


Closed
wants to merge 1 commit

Conversation

@grypp (Member) commented Oct 19, 2023

Previously, a verifier was introduced to check for mismatches between the memref rank and the number of coordinates. It turned out to be too strict. Consider the following IR snippet, where the verifier complains about a mismatch (2 coordinates (`%c1`, `%c2`) != memref rank 3):

```
nvgpu.tma.async.load %0[%c1, %c2], %1 to %2 : ... -> memref<1x64x128xf16, ..., 3>
```

This PR relaxes the verifier.

The test #69913 needs this PR.

@llvmbot (Member) commented Oct 19, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

Previously, a verifier was introduced to check for mismatches between the memref rank and the number of coordinates. It turned out to be too strict. Consider the following IR snippet, where the verifier complains about a mismatch (2 coordinates (`%c1`, `%c2`) != memref rank 3):

```
nvgpu.tma.async.load %0[%c1, %c2], %1 to %2 : ... -> memref<1x64x128xf16, ..., 3>
```

This PR relaxes the verifier.


Full diff: https://github.com/llvm/llvm-project/pull/69584.diff

1 file affected:

  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (-6)
```diff
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..2a280df68371d17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -347,12 +347,6 @@ LogicalResult TmaAsyncLoadOp::verify() {
   if (getCoordinates().size() > 5) {
     return emitError() << "Maximum 5 coordinates are supported.";
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are  "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
-  }
   return success();
 }
```

@nicolasvasilache (Contributor)

> The test #69913 needs this PR.

Can you fold this change into the PR that needs it?

@qcolombet (Collaborator)

Instead of relaxing the check, I feel we would need to use a `collapse_shape` on the input memref.

In particular, what would be the semantics of:

```
nvgpu.tma.async.load %0[%c1, %c2], %1 to %2 : ... -> memref<Outerx64x128xf16, ..., 3>
```

Is `%c1` applied to the `Outer` dim or to the `64` dim?

I think the motivating example only works because the leading dim of the input memref is 1.

@grypp (Member, Author) commented Oct 24, 2023

> Instead of relaxing the check, I feel we would need to use a `collapse_shape` on the input memref.
>
> In particular, what would be the semantics of:
>
> ```
> nvgpu.tma.async.load %0[%c1, %c2], %1 to %2 : ... -> memref<Outerx64x128xf16, ..., 3>
> ```
>
> Is `%c1` applied to the `Outer` dim or to the `64` dim?
>
> I think the motivating example only works because the leading dim of the input memref is 1.

It will be `Outer` (the base pointer of the memref).

For example:

1. `Load 64x128 into memref<128x128>` -> verifier error -> this PR relaxes it
2. `Load 64x128 into memref<64x64>` -> verifier error -> this PR relaxes it

I want to allow option 1. I guess you are concerned about option 2. I can improve the verifier so it complains for option 2. Let me do that.

> The test #69913 needs this PR.
>
> Can you fold this change into the PR that needs it?

Actually I could do this, but the test is large and requires a few more PRs, so I split them up for easy review :)

@qcolombet (Collaborator)

> > Instead of relaxing the check, I feel we would need to use a `collapse_shape` on the input memref.
> >
> > In particular, what would be the semantics of:
> >
> > ```
> > nvgpu.tma.async.load %0[%c1, %c2], %1 to %2 : ... -> memref<Outerx64x128xf16, ..., 3>
> > ```
> >
> > Is `%c1` applied to the `Outer` dim or to the `64` dim?
> >
> > I think the motivating example only works because the leading dim of the input memref is 1.
>
> It will be `Outer` (the base pointer of the memref).
>
> For example:
>
> 1. `Load 64x128 into memref<128x128>` -> verifier error -> this PR relaxes it
> 2. `Load 64x128 into memref<64x64>` -> verifier error -> this PR relaxes it
>
> I want to allow option 1. I guess you are concerned about option 2. I can improve the verifier so it complains for option 2. Let me do that.

Well, I actually don't understand the semantics of option 1 :).
The thing I'd like to avoid is an instruction that is too powerful and, hence, difficult to work with.

> > The test #69913 needs this PR.
> >
> > Can you fold this change into the PR that needs it?
>
> Actually I could do this, but the test is large and requires a few more PRs, so I split them up for easy review :)

@grypp (Member, Author) commented Nov 1, 2023

I created #70923, which is a better cleanup and improvement.

@grypp grypp closed this Nov 1, 2023
grypp added a commit that referenced this pull request Nov 10, 2023
…#70028)

PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with an
if-statement. This PR adds the same test using predicates in PTX.
Predicate support is enabled using _BasicPtxBuilderInterface_
`(nvgpu.opcode ..., predicate = %pred)`.

The predicate condition is computed in `Step 2. [GPU] Elect fastest
thread in CTA`, inspired by cutlass. It is as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```

Depends on #70027 #69934 #69935 #69584
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023