Skip to content

Commit 5d8813d

Browse files
committed
[mlir] allow dense dataflow to customize call and region operations
Initial implementations of dense dataflow analyses feature special cases for operations that have region- or call-based control flow by leveraging the corresponding interfaces. This is not necessarily sufficient as these operations may influence the dataflow state by themselves as well we through the control flow. For example, `linalg.generic` and similar operations have region-based control flow and their proper memory effects, so any memory-related analyses such as last-writer require processing `linalg.generic` directly instead of, or in addition to, the region-based flow. Provide hooks to customize the processing of operations with region- cand call-based contol flow in forward and backward dense dataflow analysis. These hooks are trigerred when control flow is transferred between the "main" operation, i.e. the call or the region owner, and another region. Such an apporach allows the analyses to update the lattice before and/or after the regions. In the `linalg.generic` example, the reads from memory are interpreted as happening before the body region and the writes to memory are interpreted as happening after the body region. Using these hooks in generic analysis may require introducing additional interfaces, but for now assume that the specific analysis have spceial cases for the (rare) operaitons with call- and region-based control flow that need additional processing. Reviewed By: Mogball, phisiart Differential Revision: https://reviews.llvm.org/D155757
1 parent f6bdfb0 commit 5d8813d

File tree

9 files changed

+942
-113
lines changed

9 files changed

+942
-113
lines changed

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 227 additions & 14 deletions
Large diffs are not rendered by default.

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 139 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,38 @@ LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
4242
return success();
4343
}
4444

45+
void AbstractDenseDataFlowAnalysis::visitCallOperation(
46+
CallOpInterface call, AbstractDenseLattice *after) {
47+
48+
const auto *predecessors =
49+
getOrCreateFor<PredecessorState>(call.getOperation(), call);
50+
// If not all return sites are known, then conservatively assume we can't
51+
// reason about the data-flow.
52+
if (!predecessors->allPredecessorsKnown())
53+
return setToEntryState(after);
54+
55+
for (Operation *predecessor : predecessors->getKnownPredecessors()) {
56+
// Get the lattices at callee return:
57+
//
58+
// func.func @callee() {
59+
// ...
60+
// return // predecessor
61+
// // latticeAtCalleeReturn
62+
// }
63+
// func.func @caller() {
64+
// ...
65+
// call @callee
66+
// // latticeAfterCall
67+
// ...
68+
// }
69+
AbstractDenseLattice *latticeAfterCall = after;
70+
const AbstractDenseLattice *latticeAtCalleeReturn =
71+
getLatticeFor(call.getOperation(), predecessor);
72+
visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee,
73+
*latticeAtCalleeReturn, latticeAfterCall);
74+
}
75+
}
76+
4577
void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) {
4678
// If the containing block is not executable, bail out.
4779
if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
@@ -50,30 +82,22 @@ void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) {
5082
// Get the dense lattice to update.
5183
AbstractDenseLattice *after = getLattice(op);
5284

85+
// Get the dense state before the execution of the op.
86+
const AbstractDenseLattice *before;
87+
if (Operation *prev = op->getPrevNode())
88+
before = getLatticeFor(op, prev);
89+
else
90+
before = getLatticeFor(op, op->getBlock());
91+
5392
// If this op implements region control-flow, then control-flow dictates its
5493
// transfer function.
5594
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
5695
return visitRegionBranchOperation(op, branch, after);
5796

5897
// If this is a call operation, then join its lattices across known return
5998
// sites.
60-
if (auto call = dyn_cast<CallOpInterface>(op)) {
61-
const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
62-
// If not all return sites are known, then conservatively assume we can't
63-
// reason about the data-flow.
64-
if (!predecessors->allPredecessorsKnown())
65-
return setToEntryState(after);
66-
for (Operation *predecessor : predecessors->getKnownPredecessors())
67-
join(after, *getLatticeFor(op, predecessor));
68-
return;
69-
}
70-
71-
// Get the dense state before the execution of the op.
72-
const AbstractDenseLattice *before;
73-
if (Operation *prev = op->getPrevNode())
74-
before = getLatticeFor(op, prev);
75-
else
76-
before = getLatticeFor(op, op->getBlock());
99+
if (auto call = dyn_cast<CallOpInterface>(op))
100+
return visitCallOperation(call, after);
77101

78102
// Invoke the operation transfer function.
79103
visitOperationImpl(op, *before, after);
@@ -100,10 +124,15 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
100124
return setToEntryState(after);
101125
for (Operation *callsite : callsites->getKnownPredecessors()) {
102126
// Get the dense lattice before the callsite.
127+
const AbstractDenseLattice *before;
103128
if (Operation *prev = callsite->getPrevNode())
104-
join(after, *getLatticeFor(block, prev));
129+
before = getLatticeFor(block, prev);
105130
else
106-
join(after, *getLatticeFor(block, callsite->getBlock()));
131+
before = getLatticeFor(block, callsite->getBlock());
132+
133+
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
134+
CallControlFlowAction::EnterCallee,
135+
*before, after);
107136
}
108137
return;
109138
}
@@ -152,7 +181,41 @@ void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
152181
} else {
153182
before = getLatticeFor(point, op);
154183
}
155-
join(after, *before);
184+
185+
// This function is called in two cases:
186+
// 1. when visiting the block (point = block);
187+
// 2. when visiting the parent operation (point = parent op).
188+
// In both cases, we are looking for predecessor operations of the point,
189+
// 1. predecessor may be the terminator of another block from another
190+
// region (assuming that the block does belong to another region via an
191+
// assertion) or the parent (when parent can transfer control to this
192+
// region);
193+
// 2. predecessor may be the terminator of a block that exits the
194+
// region (when region transfers control to the parent) or the operation
195+
// before the parent.
196+
// In the latter case, just perform the join as it isn't the control flow
197+
// affected by the region.
198+
std::optional<unsigned> regionFrom =
199+
op == branch ? std::optional<unsigned>()
200+
: op->getBlock()->getParent()->getRegionNumber();
201+
if (auto *toBlock = point.dyn_cast<Block *>()) {
202+
assert(op == branch ||
203+
toBlock->getParent() != op->getBlock()->getParent());
204+
unsigned regionTo = toBlock->getParent()->getRegionNumber();
205+
visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
206+
*before, after);
207+
} else {
208+
assert(point.get<Operation *>() == branch &&
209+
"expected to be visiting the branch itself");
210+
// Only need to call the arc transfer when the predecessor is the region
211+
// or the op itself, not the previous op.
212+
if (op->getParentOp() == branch || op == branch) {
213+
visitRegionBranchControlFlowTransfer(
214+
branch, regionFrom, /*regionTo=*/std::nullopt, *before, after);
215+
} else {
216+
join(after, *before);
217+
}
218+
}
156219
}
157220
}
158221

@@ -194,6 +257,44 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
194257
return success();
195258
}
196259

260+
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
261+
CallOpInterface call, AbstractDenseLattice *before) {
262+
// Find the callee.
263+
Operation *callee = call.resolveCallable(&symbolTable);
264+
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
265+
if (!callable)
266+
return setToExitState(before);
267+
268+
// No region means the callee is only declared in this module and we shouldn't
269+
// assume anything about it.
270+
Region *region = callable.getCallableRegion();
271+
if (!region || region->empty())
272+
return setToExitState(before);
273+
274+
// Call-level control flow specifies the data flow here.
275+
//
276+
// func.func @callee() {
277+
// ^calleeEntryBlock:
278+
// // latticeAtCalleeEntry
279+
// ...
280+
// }
281+
// func.func @caller() {
282+
// ...
283+
// // latticeBeforeCall
284+
// call @callee
285+
// ...
286+
// }
287+
Block *calleeEntryBlock = &region->front();
288+
ProgramPoint calleeEntry = calleeEntryBlock->empty()
289+
? ProgramPoint(calleeEntryBlock)
290+
: &calleeEntryBlock->front();
291+
const AbstractDenseLattice &latticeAtCalleeEntry =
292+
*getLatticeFor(call.getOperation(), calleeEntry);
293+
AbstractDenseLattice *latticeBeforeCall = before;
294+
visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee,
295+
latticeAtCalleeEntry, latticeBeforeCall);
296+
}
297+
197298
void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
198299
// If the containing block is not executable, bail out.
199300
if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
@@ -202,46 +303,19 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
202303
// Get the dense lattice to update.
203304
AbstractDenseLattice *before = getLattice(op);
204305

205-
// If the op implements region control flow, then the interface specifies the
206-
// control function.
207-
// TODO: this is not always true, e.g. linalg.generic, but is implement this
208-
// way for consistency with the dense forward analysis.
209-
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
210-
return visitRegionBranchOperation(op, branch, std::nullopt, before);
211-
212-
// If the op is a call-like, do inter-procedural data flow as follows:
213-
//
214-
// - find the callable (resolve via the symbol table),
215-
// - get the entry block of the callable region,
216-
// - take the state before the first operation if present or at block end
217-
// otherwise,
218-
// - meet that state with the state before the call-like op.
219-
if (auto call = dyn_cast<CallOpInterface>(op)) {
220-
Operation *callee = call.resolveCallable(&symbolTable);
221-
if (auto callable = dyn_cast<CallableOpInterface>(callee)) {
222-
Region *region = callable.getCallableRegion();
223-
if (region && !region->empty()) {
224-
Block *entryBlock = &region->front();
225-
if (entryBlock->empty())
226-
meet(before, *getLatticeFor(op, entryBlock));
227-
else
228-
meet(before, *getLatticeFor(op, &entryBlock->front()));
229-
} else {
230-
setToExitState(before);
231-
}
232-
} else {
233-
setToExitState(before);
234-
}
235-
return;
236-
}
237-
238306
// Get the dense state after execution of this op.
239307
const AbstractDenseLattice *after;
240308
if (Operation *next = op->getNextNode())
241309
after = getLatticeFor(op, next);
242310
else
243311
after = getLatticeFor(op, op->getBlock());
244312

313+
// Special cases where control flow may dictate data flow.
314+
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
315+
return visitRegionBranchOperation(op, branch, std::nullopt, before);
316+
if (auto call = dyn_cast<CallOpInterface>(op))
317+
return visitCallOperation(call, before);
318+
245319
// Invoke the operation transfer function.
246320
visitOperationImpl(op, *after, before);
247321
}
@@ -280,16 +354,20 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
280354
return setToExitState(before);
281355

282356
for (Operation *callsite : callsites->getKnownPredecessors()) {
357+
const AbstractDenseLattice *after;
283358
if (Operation *next = callsite->getNextNode())
284-
meet(before, *getLatticeFor(block, next));
359+
after = getLatticeFor(block, next);
285360
else
286-
meet(before, *getLatticeFor(block, callsite->getBlock()));
361+
after = getLatticeFor(block, callsite->getBlock());
362+
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
363+
CallControlFlowAction::ExitCallee, *after,
364+
before);
287365
}
288366
return;
289367
}
290368

291369
// If this block is exiting from an operation with region-based control
292-
// flow, follow that flow.
370+
// flow, propagate the lattice back along the control flow edge.
293371
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
294372
visitRegionBranchOperation(block, branch,
295373
block->getParent()->getRegionNumber(), before);
@@ -346,7 +424,11 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
346424
else
347425
after = getLatticeFor(point, &successorBlock->front());
348426
}
349-
meet(before, *after);
427+
std::optional<unsigned> successorNo =
428+
successor.isParent() ? std::optional<unsigned>()
429+
: successor.getSuccessor()->getRegionNumber();
430+
visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
431+
before);
350432
}
351433
}
352434

mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
1+
// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 | FileCheck %s
22

33
// CHECK-LABEL: test_tag: test_callsite
44
// CHECK: operand #0
@@ -64,4 +64,84 @@ func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref<i32>
6464
func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
6565
%0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref<i32>) -> memref<i32>
6666
return {tag = "test_multiple_return_sites"} %0 : memref<i32>
67-
}
67+
}
68+
69+
// -----
70+
71+
72+
func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
73+
%2 = arith.constant 2.0 : f32
74+
memref.load %arg0[] {tag = "call_and_store_before::enter_callee"} : memref<f32>
75+
memref.store %2, %arg0[] {tag_name = "callee"} : memref<f32>
76+
memref.load %arg0[] {tag = "exit_callee"} : memref<f32>
77+
return %arg0 : memref<f32>
78+
}
79+
// In this test, the "call" operation also stores to %arg0 itself before
80+
// transferring control flow to the callee. Therefore, the order of accesses is
81+
// "pre" -> "call" -> "callee" -> "post"
82+
83+
// CHECK-LABEL: test_tag: call_and_store_before::enter_callee:
84+
// CHECK: operand #0
85+
// CHECK: - call
86+
// CHECK: test_tag: exit_callee:
87+
// CHECK: operand #0
88+
// CHECK: - callee
89+
// CHECK: test_tag: before_call:
90+
// CHECK: operand #0
91+
// CHECK: - pre
92+
// CHECK: test_tag: after_call:
93+
// CHECK: operand #0
94+
// CHECK: - callee
95+
// CHECK: test_tag: return:
96+
// CHECK: operand #0
97+
// CHECK: - post
98+
func.func @call_and_store_before(%arg0: memref<f32>) -> memref<f32> {
99+
%0 = arith.constant 0.0 : f32
100+
%1 = arith.constant 1.0 : f32
101+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
102+
memref.load %arg0[] {tag = "before_call"} : memref<f32>
103+
test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = true} : (memref<f32>, memref<f32>) -> ()
104+
memref.load %arg0[] {tag = "after_call"} : memref<f32>
105+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
106+
return {tag = "return"} %arg0 : memref<f32>
107+
}
108+
109+
// -----
110+
111+
func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
112+
%2 = arith.constant 2.0 : f32
113+
memref.load %arg0[] {tag = "call_and_store_after::enter_callee"} : memref<f32>
114+
memref.store %2, %arg0[] {tag_name = "callee"} : memref<f32>
115+
memref.load %arg0[] {tag = "exit_callee"} : memref<f32>
116+
return %arg0 : memref<f32>
117+
}
118+
119+
// In this test, the "call" operation also stores to %arg0 itself after getting
120+
// control flow back from the callee. Therefore, the order of accesses is
121+
// "pre" -> "callee" -> "call" -> "post"
122+
123+
// CHECK-LABEL: test_tag: call_and_store_after::enter_callee:
124+
// CHECK: operand #0
125+
// CHECK: - pre
126+
// CHECK: test_tag: exit_callee:
127+
// CHECK: operand #0
128+
// CHECK: - callee
129+
// CHECK: test_tag: before_call:
130+
// CHECK: operand #0
131+
// CHECK: - pre
132+
// CHECK: test_tag: after_call:
133+
// CHECK: operand #0
134+
// CHECK: - call
135+
// CHECK: test_tag: return:
136+
// CHECK: operand #0
137+
// CHECK: - post
138+
func.func @call_and_store_after(%arg0: memref<f32>) -> memref<f32> {
139+
%0 = arith.constant 0.0 : f32
140+
%1 = arith.constant 1.0 : f32
141+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
142+
memref.load %arg0[] {tag = "before_call"} : memref<f32>
143+
test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = false} : (memref<f32>, memref<f32>) -> ()
144+
memref.load %arg0[] {tag = "after_call"} : memref<f32>
145+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
146+
return {tag = "return"} %arg0 : memref<f32>
147+
}

0 commit comments

Comments
 (0)