Skip to content

Commit 1513d53

Browse files
maksleventalAlexAUT
authored andcommitted
[AMD] Use 1:N conversion to rewrite canonicalize pointers pass (triton-lang#5329)
### TL;DR (too long, didn't review) This PR re-enables the `tritonamdgpu-canonicalize-pointers` pass[^1]. The PR is effectively a complete rewrite of the original pass, which walked the AST and mutated IR in-place, using the new [`1:N` dialect conversion framework](llvm/llvm-project#116470). Recall a "fat pointer" is a tuple-like `(%baseptr, %offsetptr)` - the current (original) pass keeps this tuple in a global data structure while the new/rewritten pass emits this tuple into the IR as an `unrealized_cast(%baseptr, %offsetptr)`[^2]. Note, this PR also rewrites the existing lit test (see [this comment below](triton-lang#5329 (comment))). ### Pass outline The pass structure/action is roughly: 1. Perform an approximate sparse dataflow analysis to find all transitive uses for `tt.func` args that are `tt.ptr`s; legalize only these ops; 2. Rewrite all operations' `use`s and `result`s to be `(%baseptr, %offsetptr)` using `ConversionPattern`s that takes the new `OneToNOpAdaptor`, which automatically forwards both `%baseptr` and `%offsetptr` through `adaptor.getOperands()`[^3]; 3. Clean up remaining `unrealized_casts` (currently only handling one category of such remaining casts but can be extended to handle all; see bullet 1 in TODOs). ### Some pre-emptive call outs Right up front I'll say this took a long time to figure out because **a)** the conversion framework is hugely complex **b)** it's being currently rewritten to be more robust/stable. As a consequence, the implementation is complex but I've tried hard to **a)** simplify as much as possible **b)** comment/note subtleties **c)** put in ample `assert`s and checks to clarify intent and gracefully fail. So some things to call out: 1. I called the dataflow analysis approximate because it does not actually use [DataFlow/SparseAnalysis](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp) and instead computes a forward slice using the heuristic "transfer function" that users of an op with a `tt.ptr` operand should be rewritten. This heuristic works because the forward slice starts from `tt.ptr` args on a `tt.func` and ends at `tt.store`, which has no results. Note, there's no reason why this component of the pass can't be a true `SparseAnalysis` implementation, it's just that this rewrite has already taken way longer than I expected (so I leave that for a possible follow-up). 2. The pass uses no global `TypeConverter` but uses local `TypeConverter`s, in `BranchInterface`/`RegionInterface` patterns. This is because **a)** we are not actually converting operand/result types (we are converting number of operands/results) **b)** the conversion framework expects/handles this lack of a `TypeConverter` exactly [the way we want](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1179-L1185). The local type converters are used for ops that couple to basic blocks (`^bb`s) that need to have their signatures rewritten (i.e., the ops for which we need to do `rewriter.applySignatureConversion(block, *conversion, &localTypeConverter)`). That's `scf.for`, `scf.while`, `cf.br` and `cf.cond_br` (not needed for `scf.if` which has no `bb` args). 3. `tt.func` is handled differently from all of the other ops - it is not rewritten at all. Instead, for every `%arg: tt.ptr` arg, we insert into the new body `%c0 = arith.constant 0 : i32` and `%newarg = unrealized_cast(%arg, %c0) : tt.ptr` (manually, not done by the conversion framework) and replace all uses of `%arg` by `%newarg`. These are then unpacked to `(%arg, %c0)` using `replaceOpWithMultiple` so that they "magically" appear in `adaptor.getOperands()`. Then at the end, currently, these are the only unreconciled casts (because they are the only ones **not** inserted by the conversion framework) and we materialize them by just replacing uses of `%newarg` with `%arg`. 4. `scf.if` needs to be handled specially; since it has no operands but can `yield` results, we need to rewrite it only after its `yield`s have been rewritten. This is not straightforward because the dialect conversion [does a preorder walk](https://github.com/llvm/llvm-project/blob/6ab8401f53deddbd79f930ba2bec2f824c9567e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2705). To work around this we define legality for `scf.if` to be dependent on whether its `yield`s have been rewritten (using two `UnitAttr`s on those `yield`s). Thus, `scf.if` is "legal" and not rewritten until after the results of the `yield`s are known. [^1]: I haven't actually moved it out of the flag but it's now usable with the `AMDGCN_USE_BUFFER_OPS` flag whereas it wasn't prior. [^2]: In reality it's the conversion framework that materializes this tuple as `unrealized_cast(%baseptr, %offsetptr)` and then reconciles/DCEs all the casts automatically. [^3]: The `unrealized_cast`s are completely "transparent" to the patterns, see [`ConversionPatternRewriterImpl::remapValues`](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1161).
1 parent f8b1049 commit 1513d53

File tree

4 files changed

+2292
-1292
lines changed

4 files changed

+2292
-1292
lines changed

0 commit comments

Comments
 (0)