Skip to content

Commit 991cb14

Browse files
[mlir][memref][transform] Add new alloca_to_global op. (#66511)
This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.
1 parent 59896c1 commit 991cb14

File tree

5 files changed

+236
-0
lines changed

5 files changed

+236
-0
lines changed

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,69 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
144144
}
145145

146146
def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
147+
// Transform handle type constrained to `memref.alloca` payload ops; used as
// the operand type of `MemRefAllocaToGlobalOp`.
def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">;
148+
149+
// Note: `DeclareOpInterfaceMethods<TransformOpInterface>` already attaches the
// `TransformOpInterface` trait, so the interface is not listed a second time.
def MemRefAllocaToGlobalOp :
  Op<Transform_Dialect, "memref.alloca_to_global",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let description = [{
    Inserts a new `memref.global` for each provided `memref.alloca` into the
    nearest symbol table (e.g., a `builtin.module`) and replaces it with a
    `memref.get_global`. This is useful, for example, for allocations that
    should reside in the shared memory of a GPU, which have to be declared as
    globals.

    The payload allocas are expected to be statically shaped, since a
    `memref.global` must have a statically shaped memref type.

    #### Example

    Consider the following transform op:

    ```mlir
    %get_global, %global =
        transform.memref.alloca_to_global %alloca
          : (!transform.op<"memref.alloca">)
            -> (!transform.any_op, !transform.any_op)
    ```

    and the following input payload:

    ```mlir
    module {
      func.func @func() {
        %alloca = memref.alloca() : memref<2x32xf32>
        // usages of %alloca...
      }
    }
    ```

    then applying the transform op to the payload would result in the following
    output IR:

    ```mlir
    module {
      memref.global "private" @alloca : memref<2x32xf32>
      func.func @func() {
        %alloca = memref.get_global @alloca : memref<2x32xf32>
        // usages of %alloca...
      }
    }
    ```

    Each inserted global is named `alloca`; if that symbol already exists in
    the symbol table, the name is made unique on insertion.

    #### Return modes

    Succeeds for statically shaped allocas. The returned handles refer to the
    `memref.get_global` and `memref.global` ops that were inserted by the
    transformation, in that order.
  }];

  let arguments = (ins Transform_MemRefAllocaOp:$alloca);
  let results = (outs TransformHandleTypeInterface:$getGlobal,
                      TransformHandleTypeInterface:$global);

  let assemblyFormat = [{
    $alloca attr-dict `:` functional-type(operands, results)
  }];
}
147210

148211
def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
149212
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,67 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
126126
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
127127
}
128128

129+
//===----------------------------------------------------------------------===//
130+
// AllocaToGlobalOp
131+
//===----------------------------------------------------------------------===//
132+
133+
/// Replaces every `memref.alloca` payload op associated with the operand
/// handle by a `memref.get_global` that accesses a newly created private
/// `memref.global` of the same memref type. The global is inserted into the
/// symbol table nearest to the alloca and inherits the alloca's alignment.
///
/// Produces a silenceable failure, before any payload IR is modified, if one
/// of the allocas is dynamically shaped: `memref.global` only supports
/// statically shaped memref types. On success, the `global` and `getGlobal`
/// results are mapped to the created ops, one per payload alloca.
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
  MLIRContext *ctx = rewriter.getContext();

  // Validate all payload ops up front so that a failure leaves the payload IR
  // untouched.
  SmallVector<memref::AllocaOp> allocaOps;
  for (Operation *op : state.getPayloadOps(getAlloca())) {
    auto alloca = cast<memref::AllocaOp>(op);
    if (!alloca.getType().hasStaticShape()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "cannot replace a dynamically shaped alloca with a global";
      diag.attachNote(alloca->getLoc()) << "dynamically shaped alloca";
      return diag;
    }
    allocaOps.push_back(alloca);
  }

  SmallVector<memref::GlobalOp> globalOps;
  SmallVector<memref::GetGlobalOp> getGlobalOps;
  globalOps.reserve(allocaOps.size());
  getGlobalOps.reserve(allocaOps.size());

  // Transform `memref.alloca`s.
  for (memref::AllocaOp alloca : allocaOps) {
    Location loc = alloca->getLoc();

    memref::GlobalOp globalOp;
    {
      // Find nearest symbol table.
      Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(alloca);
      assert(symbolTableOp && "expected alloca payload to be in symbol table");
      SymbolTable symbolTable(symbolTableOp);

      // Insert a `memref.global` into the symbol table. The op is created
      // detached; `SymbolTable::insert` places it into the symbol table body.
      Type resultType = alloca.getResult().getType();
      OpBuilder builder(ctx);
      // TODO: Add a better builder for this.
      globalOp = builder.create<memref::GlobalOp>(
          loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
          TypeAttr::get(resultType), /*initial_value=*/Attribute{},
          /*constant=*/UnitAttr{},
          // Keep the alignment requested on the alloca (null if unset).
          /*alignment=*/alloca.getAlignmentAttr());
      symbolTable.insert(globalOp);
    }

    // Replace the `memref.alloca` with a `memref.get_global` accessing the
    // global symbol inserted above.
    rewriter.setInsertionPoint(alloca);
    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
        alloca, globalOp.getType(), globalOp.getName());

    globalOps.push_back(globalOp);
    getGlobalOps.push_back(getGlobalOp);
  }

  // Assemble results.
  results.set(cast<OpResult>(getGlobal()), globalOps);
  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);

  return DiagnosedSilenceableFailure::success();
}
181+
182+
/// Registers the side effects of this transform op in `effects`: both result
/// handles are produced, the operand handle is consumed, and the payload IR is
/// modified.
void transform::MemRefAllocaToGlobalOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Result handles created by this op.
  producesHandle(getGlobal(), effects);
  producesHandle(getGetGlobal(), effects);

  // The `memref.alloca` handle is consumed; its payload ops are replaced.
  consumesHandle(getAlloca(), effects);

  // The payload IR itself is rewritten.
  modifiesPayload(effects);
}
189+
129190
//===----------------------------------------------------------------------===//
130191
// MemRefMultiBufferOp
131192
//===----------------------------------------------------------------------===//

mlir/python/mlir/dialects/_memref_transform_ops_ext.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,52 @@
1111
from typing import Optional, overload, Union
1212

1313

14+
class MemRefAllocaToGlobalOp:
    """Extension of the generated `MemRefAllocaToGlobalOp` builder.

    Supports two call forms:
      * typed:   `(get_global_type, global_type, alloca)`
      * compact: `(alloca)`, where both result types default to
        `transform.AnyOpType`.
    """

    @overload
    def __init__(
        self,
        get_global_type: Type,
        global_type: Type,
        alloca: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None
    ):
        ...

    @overload
    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
        global_type_or_none: Optional[Type] = None,
        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None
    ):
        # The kind of the first positional argument selects the call form.
        uses_typed_form = isinstance(get_global_type_or_alloca, Type)
        if uses_typed_form:
            result_types = (get_global_type_or_alloca, global_type_or_none)
            target = alloca_or_none
        else:
            result_types = (transform.AnyOpType.get(), transform.AnyOpType.get())
            target = get_global_type_or_alloca

        # Forward to the generated constructor:
        # (get_global type, global type, alloca handle).
        super().__init__(*result_types, target, loc=loc, ip=ip)
58+
59+
1460
class MemRefMultiBufferOp:
1561
"""Specialization for MemRefMultiBufferOp class."""
1662

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
11
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
22

3+
// Two `memref.alloca`s nested in an `scf.forall` must each be replaced by a
// `memref.get_global` that reads its own private `memref.global`. The symbol
// captures are anchored on the `alloca` base name used for the new globals.

// CHECK-DAG: memref.global "private" @[[ALLOC0:alloca.*]] : memref<2x32xf32>
// CHECK-DAG: memref.global "private" @[[ALLOC1:alloca.*]] : memref<2x32xf32>

// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index)
func.func @func(%lb: index, %ub: index) {
  // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]])
  scf.forall (%arg0, %arg1) in (%lb, %ub) {
    // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
    // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
    // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
    // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
    %cst = arith.constant 0.0 : f32
    %mr0 = memref.alloca() : memref<2x32xf32>
    %mr1 = memref.alloca() : memref<2x32xf32>
    memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32>
    memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32>
  }
  return
}

transform.sequence failures(propagate) {
^bb1(%arg0: !transform.any_op):
  %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
      : (!transform.any_op) -> !transform.op<"memref.alloca">
  %get_global, %global = transform.memref.alloca_to_global %alloca
        : (!transform.op<"memref.alloca">)
          -> (!transform.any_op, !transform.any_op)
}
31+
32+
// -----
33+
334
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
435
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
536

mlir/test/python/dialects/transform_memref_ext.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,41 @@ def run(f):
1616
return f
1717

1818

19+
@run
def testMemRefAllocaToAllocOpCompact():
    # Compact form: only the alloca handle is passed, so both result handle
    # types fall back to `!transform.any_op`.
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(sequence.body):
        memref.MemRefAllocaToGlobalOp(sequence.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
33+
34+
35+
@run
def testMemRefAllocaToAllocOpTyped():
    # Typed form: explicit result handle types are passed ahead of the alloca
    # handle.
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(sequence.body):
        get_global_handle_type = transform.OperationType.get("memref.get_global")
        global_handle_type = transform.OperationType.get("memref.global")
        memref.MemRefAllocaToGlobalOp(
            get_global_handle_type, global_handle_type, sequence.bodyTarget
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
52+
53+
1954
@run
2055
def testMemRefMultiBufferOpCompact():
2156
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)