forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit 1513d53
[AMD] Use 1:N conversion to rewrite canonicalize pointers pass (triton-lang#5329)
### TL;DR (too long, didn't review)
This PR re-enables the `tritonamdgpu-canonicalize-pointers` pass[^1].
The PR is effectively a complete rewrite of the original pass, which
walked the AST and mutated IR in-place, using the new [`1:N` dialect
conversion framework](llvm/llvm-project#116470).
Recall a "fat pointer" is a tuple-like `(%baseptr, %offsetptr)` - the
current (original) pass keeps this tuple in a global data structure
while the new/rewritten pass emits this tuple into the IR as an
`unrealized_cast(%baseptr, %offsetptr)`[^2].
Note, this PR also rewrites the existing lit test (see [this comment
below](triton-lang#5329 (comment))).
### Pass outline
The pass structure/action is roughly:
1. Perform an approximate sparse dataflow analysis to find all
transitive uses for `tt.func` args that are `tt.ptr`s; legalize only
these ops;
2. Rewrite all operations' `use`s and `result`s to be `(%baseptr,
%offsetptr)` using `ConversionPattern`s that takes the new
`OneToNOpAdaptor`, which automatically forwards both `%baseptr` and
`%offsetptr` through `adaptor.getOperands()`[^3];
3. Clean up remaining `unrealized_casts` (currently only handling one
category of such remaining casts but can be extended to handle all; see
bullet 1 in TODOs).
### Some pre-emptive call outs
Right up front I'll say this took a long time to figure out because
**a)** the conversion framework is hugely complex **b)** it's being
currently rewritten to be more robust/stable. As a consequence, the
implementation is complex but I've tried hard to **a)** simplify as much
as possible **b)** comment/note subtleties **c)** put in ample `assert`s
and checks to clarify intent and gracefully fail. So some things to call
out:
1. I called the dataflow analysis approximate because it does not
actually use
[DataFlow/SparseAnalysis](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp)
and instead computes a forward slice using the heuristic "transfer
function" that users of an op with a `tt.ptr` operand should be
rewritten. This heuristic works because the forward slice starts from
`tt.ptr` args on a `tt.func` and ends at `tt.store`, which has no
results. Note, there's no reason why this component of the pass can't be
a true `SparseAnalysis` implementation, it's just that this rewrite has
already taken way longer than I expected (so I leave that for a possible
follow-up).
2. The pass uses no global `TypeConverter` but uses local
`TypeConverter`s, in `BranchInterface`/`RegionInterface` patterns. This
is because **a)** we are not actually converting operand/result types
(we are converting number of operands/results) **b)** the conversion
framework expects/handles this lack of a `TypeConverter` exactly [the
way we
want](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1179-L1185).
The local type converters are used for ops that couple to basic blocks
(`^bb`s) that need to have their signatures rewritten (i.e., the ops for
which we need to do `rewriter.applySignatureConversion(block,
*conversion, &localTypeConverter)`). That's `scf.for`, `scf.while`,
`cf.br` and `cf.cond_br` (not needed for `scf.if` which has no `bb`
args).
3. `tt.func` is handled differently from all of the other ops - it is
not rewritten at all. Instead, for every `%arg: tt.ptr` arg, we insert
into the new body `%c0 = arith.constant 0 : i32` and `%newarg =
unrealized_cast(%arg, %c0) : tt.ptr` (manually, not done by the
conversion framework) and replace all uses of `%arg` by `%newarg`. These
are then unpacked to `(%arg, %c0)` using `replaceOpWithMultiple` so that
they "magically" appear in `adaptor.getOperands()`. Then at the end,
currently, these are the only unreconciled casts (because they are the
only ones **not** inserted by the conversion framework) and we
materialize them by just replacing uses of `%newarg` with `%arg`.
4. `scf.if` needs to be handled specially; since it has no operands but
can `yield` results, we need to rewrite it only after its `yield`s have
been rewritten. This is not straightforward because the dialect
conversion [does a preorder
walk](https://github.com/llvm/llvm-project/blob/6ab8401f53deddbd79f930ba2bec2f824c9567e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2705).
To work around this we define legality for `scf.if` to be dependent on
whether its `yield`s have been rewritten (using two `UnitAttr`s on those
`yield`s). Thus, `scf.if` is "legal" and not rewritten until after the
results of the `yield`s are known.
[^1]: I haven't actually moved it out of the flag but it's now usable
with the `AMDGCN_USE_BUFFER_OPS` flag whereas it wasn't prior.
[^2]: In reality it's the conversion framework that materializes this
tuple as `unrealized_cast(%baseptr, %offsetptr)` and then
reconciles/DCEs all the casts automatically.
[^3]: The `unrealized_cast`s are completely "transparent" to the
patterns, see
[`ConversionPatternRewriterImpl::remapValues`](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1161).1 parent f8b1049 commit 1513d53Copy full SHA for 1513d53
File tree
4 files changed
+2292
-1292
lines changedFilter options
- test/TritonGPU/amd
- third_party/amd
- include/TritonAMDGPUTransforms
- lib/TritonAMDGPUTransforms
- python
4 files changed
+2292
-1292
lines changed
0 commit comments