Skip to content

Commit 6981f7e

Browse files
authored
[mlir] account for explicit affine.parallel in parallelization (#130812)
Affine parallelization should take explicitly parallel loops into account when computing loop depth for dependency analysis purposes. This was previously not the case, potentially leading to loops incorrectly being marked as parallel due to depth mismatch.
1 parent 554347b commit 6981f7e

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,8 @@ unsigned mlir::affine::getNestingDepth(Operation *op) {
19881988
while ((currOp = currOp->getParentOp())) {
19891989
if (isa<AffineForOp>(currOp))
19901990
depth++;
1991+
if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
1992+
depth += parOp.getNumDims();
19911993
}
19921994
return depth;
19931995
}

mlir/test/Dialect/Affine/parallelize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,23 @@ func.func @test_add_inv_or_terminal_symbol(%arg0: memref<9x9xi32>, %arg1: i1) {
341341
}
342342
return
343343
}
344+
345+
// Ensure that outer parallel loops are taken into account when computing the
346+
// loop depth in dependency analysis during parallelization. With correct
347+
// depth, the analysis should see the inner loop as sequential due to reads and
348+
// writes to the same address indexed by the outer (parallel) loop.
349+
//
350+
// CHECK-LABEL: @explicit_parallel
351+
func.func @explicit_parallel(%arg0: memref<1x123x194xf64>, %arg5: memref<34x99x194xf64>) {
352+
// CHECK: affine.parallel
353+
affine.parallel (%arg7, %arg8) = (0, 0) to (85, 180) {
354+
// CHECK: affine.for
355+
affine.for %arg9 = 0 to 18 {
356+
%0 = affine.load %arg0[0, %arg7 + 19, %arg8 + 7] : memref<1x123x194xf64>
357+
%1 = affine.load %arg5[%arg9 + 8, %arg7 + 7, %arg8 + 7] : memref<34x99x194xf64>
358+
%2 = arith.addf %0, %1 {fastmathFlags = #llvm.fastmath<none>} : f64
359+
affine.store %1, %arg0[0, %arg7 + 19, %arg8 + 7] : memref<1x123x194xf64>
360+
}
361+
}
362+
return
363+
}

0 commit comments

Comments
 (0)