Skip to content

Commit 47ec870

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Revisit 0-D abstraction
This revision takes advantage of the empty AffineMap to specify the 0-D edge case. This allows removing a bunch of annoying corner cases that ended up impacting users of Linalg. Differential Revision: https://reviews.llvm.org/D75831
1 parent 4a0267e commit 47ec870

File tree

11 files changed

+56
-77
lines changed

11 files changed

+56
-77
lines changed

mlir/docs/Dialects/Affine.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ affine-expr ::= `(` affine-expr `)`
9191
| bare-id
9292
| `-`? integer-literal
9393
94-
multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
94+
multi-dim-affine-expr ::= `(` `)`
95+
| `(` affine-expr (`,` affine-expr)* `)`
9596
```
9697

9798
`ceildiv` is the ceiling function which maps the result of the division of its

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
184184
MLIRContext *context = getContext();
185185
auto r_i = getAffineDimExpr(0, context);
186186
return SmallVector<AffineMap, 8>{
187-
AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()};
187+
AffineMap::get(1, 0, {r_i}),
188+
AffineMap::get(1, 0, {r_i}),
189+
AffineMap::get(1, 0, context)};
188190
}
189191
}];
190192

mlir/include/mlir/IR/AffineMap.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ class AffineMap {
4444
/// Returns a zero result affine map with no dimensions or symbols: () -> ().
4545
static AffineMap get(MLIRContext *context);
4646

47+
/// Returns a zero result affine map with `dimCount` dimensions and
48+
/// `symbolCount` symbols, e.g.: `(...) -> ()`.
49+
static AffineMap get(unsigned dimCount, unsigned symbolCount,
50+
MLIRContext *context);
51+
4752
static AffineMap get(unsigned dimCount, unsigned symbolCount,
4853
ArrayRef<AffineExpr> results);
4954

@@ -275,8 +280,7 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
275280
namespace llvm {
276281

277282
// AffineExpr hash just like pointers
278-
template <>
279-
struct DenseMapInfo<mlir::AffineMap> {
283+
template <> struct DenseMapInfo<mlir::AffineMap> {
280284
static mlir::AffineMap getEmptyKey() {
281285
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
282286
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,15 +356,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
356356
<< idx << " to have " << nLoops
357357
<< " dim(s) to match the number of loops";
358358

359-
if (m.getNumResults() == 1 && view.getRank() == 0) {
360-
auto cst = m.getResult(0).template dyn_cast<AffineConstantExpr>();
361-
if (!cst || cst.getValue() != 0)
362-
return op.emitOpError("expected indexing_map #")
363-
<< idx << " to be 0 to match 0-D view: " << view;
364-
} else if (m.getNumResults() != view.getRank()) {
359+
if (m.getNumResults() != view.getRank())
365360
return op.emitOpError("expected indexing_map #")
366361
<< idx << " results to match view rank: " << view;
367-
}
368362
}
369363

370364
auto concatMap = concatAffineMaps(indexingMaps);
@@ -886,7 +880,7 @@ AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
886880
if (maybeMap)
887881
return maybeMap.getValue();
888882
if (rank == 0)
889-
return AffineMap();
883+
return AffineMap::get(context);
890884
return AffineMap::getMultiDimIdentityMap(rank, context);
891885
}
892886

mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ using edsc::op::operator==;
3737
static SmallVector<ValueHandle, 8>
3838
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
3939
ArrayRef<Value> vals) {
40+
if (map.isEmpty())
41+
return {};
4042
assert(map.getNumSymbols() == 0);
4143
assert(map.getNumInputs() == vals.size());
4244
SmallVector<ValueHandle, 8> res;
@@ -241,26 +243,17 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
241243

242244
// 1.a. Emit std_load from input views.
243245
for (unsigned i = 0; i < nInputs; ++i) {
244-
Value input = genericOp.getInput(i);
245-
if (input.getType().cast<ShapedType>().getRank()) {
246-
ValueHandleArray indexing(makeCanonicalAffineApplies(
247-
b, loc, genericOp.getInputIndexingMap(i), allIvs));
248-
indexedValues[i] = std_load(input, indexing);
249-
} else {
250-
indexedValues[i] = std_load(input);
251-
}
246+
ValueHandleArray indexing(makeCanonicalAffineApplies(
247+
b, loc, genericOp.getInputIndexingMap(i), allIvs));
248+
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
252249
}
253250

254251
// 1.b. Emit std_load from output views.
255252
for (unsigned i = 0; i < nOutputs; ++i) {
256253
Value output = genericOp.getOutputBuffer(i);
257-
if (output.getType().cast<ShapedType>().getRank()) {
258-
ValueHandleArray indexing(makeCanonicalAffineApplies(
259-
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
260-
indexedValues[nInputs + i] = std_load(output, indexing);
261-
} else {
262-
indexedValues[nInputs + i] = std_load(output);
263-
}
254+
ValueHandleArray indexing(makeCanonicalAffineApplies(
255+
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
256+
indexedValues[nInputs + i] = std_load(output, indexing);
264257
}
265258

266259
auto funcOp = genericOp.getFunction();
@@ -272,13 +265,9 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
272265
// 3. Emit std_store.
273266
for (unsigned i = 0; i < nOutputs; ++i) {
274267
Value output = genericOp.getOutputBuffer(i);
275-
if (output.getType().cast<ShapedType>().getRank()) {
276-
ValueHandleArray indexing(makeCanonicalAffineApplies(
277-
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
278-
std_store(callOp->getResult(i), output, indexing);
279-
} else {
280-
std_store(callOp->getResult(i), output);
281-
}
268+
ValueHandleArray indexing(makeCanonicalAffineApplies(
269+
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
270+
std_store(callOp->getResult(i), output, indexing);
282271
}
283272
return;
284273
}
@@ -297,15 +286,10 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
297286
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
298287
assert(yieldOp->getNumOperands() == nOutputs);
299288
for (unsigned i = 0; i < nOutputs; ++i) {
300-
Value output = genericOp.getOutputBuffer(i);
301-
if (output.getType().cast<ShapedType>().getRank()) {
302-
ValueHandleArray indexing(makeCanonicalAffineApplies(
303-
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
304-
std_store(map.lookup(yieldOp->getOperand(i)),
305-
genericOp.getOutputBuffer(i), indexing);
306-
} else {
307-
std_store(map.lookup(yieldOp->getOperand(i)), output);
308-
}
289+
ValueHandleArray indexing(makeCanonicalAffineApplies(
290+
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
291+
std_store(map.lookup(yieldOp->getOperand(i)),
292+
genericOp.getOutputBuffer(i), indexing);
309293
}
310294
}
311295
};

mlir/lib/IR/AffineMap.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ AffineMap AffineMap::compose(AffineMap map) {
281281
exprs.reserve(getResults().size());
282282
for (auto expr : getResults())
283283
exprs.push_back(expr.compose(newMap));
284-
return AffineMap::get(numDims, numSymbols, exprs);
284+
return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext())
285+
: AffineMap::get(numDims, numSymbols, exprs);
285286
}
286287

287288
bool AffineMap::isProjectedPermutation() {
@@ -325,7 +326,7 @@ AffineMap mlir::simplifyAffineMap(AffineMap map) {
325326
}
326327

327328
AffineMap mlir::inversePermutation(AffineMap map) {
328-
if (!map)
329+
if (map.isEmpty())
329330
return map;
330331
assert(map.getNumSymbols() == 0 && "expected map without symbols");
331332
SmallVector<AffineExpr, 4> exprs(map.getNumDims());
@@ -351,18 +352,18 @@ AffineMap mlir::inversePermutation(AffineMap map) {
351352
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
352353
unsigned numResults = 0;
353354
for (auto m : maps)
354-
numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
355+
numResults += m.getNumResults();
355356
unsigned numDims = 0;
356357
SmallVector<AffineExpr, 8> results;
357358
results.reserve(numResults);
358359
for (auto m : maps) {
359-
if (!m || m.isSingleConstant())
360-
continue;
361360
assert(m.getNumSymbols() == 0 && "expected map without symbols");
362361
results.append(m.getResults().begin(), m.getResults().end());
363362
numDims = std::max(m.getNumDims(), numDims);
364363
}
365-
return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results);
364+
return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0,
365+
maps.front().getContext())
366+
: AffineMap::get(numDims, /*numSymbols=*/0, results);
366367
}
367368

368369
//===----------------------------------------------------------------------===//

mlir/lib/IR/MLIRContext.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,11 @@ AffineMap AffineMap::get(MLIRContext *context) {
611611
return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
612612
}
613613

614+
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
615+
MLIRContext *context) {
616+
return getImpl(dimCount, /*symbolCount=*/0, /*results=*/{}, context);
617+
}
618+
614619
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
615620
ArrayRef<AffineExpr> results) {
616621
// The number of results can't be zero.

mlir/lib/Parser/Parser.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,14 +3068,16 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
30683068
};
30693069

30703070
// Parse a multi-dimensional affine expression (a comma-separated list of
3071-
// 1-d affine expressions); the list cannot be empty. Grammar:
3072-
// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
3071+
// 1-d affine expressions); the list can be empty. Grammar:
3072+
// multi-dim-affine-expr ::= `(` `)`
3073+
// | `(` affine-expr (`,` affine-expr)* `)`
30733074
if (parseCommaSeparatedListUntil(rightToken, parseElt,
30743075
/*allowEmptyList=*/true))
30753076
return failure();
30763077
// Parsed a valid affine map.
30773078
if (exprs.empty())
3078-
map = AffineMap::get(getContext());
3079+
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
3080+
getContext());
30793081
else
30803082
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
30813083
exprs);
@@ -3101,13 +3103,14 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
31013103
};
31023104

31033105
// Parse a multi-dimensional affine expression (a comma-separated list of
3104-
// 1-d affine expressions); the list cannot be empty. Grammar:
3105-
// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
3106+
// 1-d affine expressions). Grammar:
3107+
// multi-dim-affine-expr ::= `(` `)`
3108+
// | `(` affine-expr (`,` affine-expr)* `)`
31063109
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
31073110
return AffineMap();
31083111

31093112
if (exprs.empty())
3110-
return AffineMap::get(getContext());
3113+
return AffineMap::get(numDims, numSymbols, getContext());
31113114

31123115
// Parsed a valid affine map.
31133116
return AffineMap::get(numDims, numSymbols, exprs);

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -170,30 +170,15 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
170170

171171
func @foo(%0: i32) -> i32 { return %0: i32 }
172172

173-
func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
173+
func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
174174
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
175175
linalg.generic {
176176
args_in = 0,
177177
args_out = 1,
178178
fun = @foo,
179179
indexing_maps = [ affine_map<() -> (0)> ],
180180
iterator_types = ["parallel"]
181-
} %arg0: memref<i32>
182-
}
183-
184-
// -----
185-
186-
func @foo(%0: i32) -> i32 { return %0: i32 }
187-
188-
func @generic_zero_d_view(%arg0: memref<i32>) {
189-
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
190-
linalg.generic {
191-
args_in = 0,
192-
args_out = 1,
193-
fun = @foo,
194-
indexing_maps = [ affine_map<() -> (1)> ],
195-
iterator_types = []
196-
} %arg0: memref<i32>
181+
} %arg0: memref<1xi32>
197182
}
198183

199184
// -----

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func @indexed_generic_region(
360360
// -----
361361

362362
#broadcast_access = [
363-
affine_map<(i, j) -> (0)>,
363+
affine_map<(i, j) -> ()>,
364364
affine_map<(i, j) -> (i, j)>
365365
]
366366

@@ -414,7 +414,7 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
414414

415415
#reduce_1D_access = [
416416
affine_map<(i) -> (i)>,
417-
affine_map<(i) -> (0)>
417+
affine_map<(i) -> ()>
418418
]
419419

420420
#trait_reduce_1D = {
@@ -446,8 +446,8 @@ func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
446446

447447
#reduce_init_1D_access = [
448448
affine_map<(i) -> (i)>,
449-
affine_map<(i) -> (0)>,
450-
affine_map<(i) -> (0)>
449+
affine_map<(i) -> ()>,
450+
affine_map<(i) -> ()>
451451
]
452452

453453
#trait_reduce_init_1D = {

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ func @indexed_generic_with_tensor_input_and_output(
346346
// -----
347347

348348
#broadcast_access = [
349-
affine_map<(i, j) -> (0)>,
349+
affine_map<(i, j) -> ()>,
350350
affine_map<(i, j) -> (i, j)>
351351
]
352352

0 commit comments

Comments
 (0)