Skip to content

Commit 4882cac

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Adapt FillOp to use a scalar operand.
Adapt the FillOp definition to use a scalar operand instead of a capture. This patch is a follow-up to https://reviews.llvm.org/D104109. As the input operands are in front of the output operands, the patch changes the internal operand order of the FillOp. The pretty-printed version of the operation remains unchanged, though. The patch also adapts the Linalg-to-Standard lowering to ensure the C signature of the FillOp remains unchanged as well. Differential Revision: https://reviews.llvm.org/D104121
1 parent f14e6e4 commit 4882cac

File tree

8 files changed

+65
-53
lines changed

8 files changed

+65
-53
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
175175
}
176176

177177
def FillOp : LinalgStructured_Op<"fill", []> {
178-
let arguments = (ins AnyShaped:$output,
179-
AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger,
180-
AnyVector]>:$value);
178+
let arguments = (ins
179+
AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, AnyVector]>:$value,
180+
AnyShaped:$output);
181181
let results = (outs Optional<AnyRankedTensor>:$result);
182182
let regions = (region AnyRegion:$region);
183183
let extraClassDeclaration = structuredOpsDecls # [{
184-
ValueRange inputs() { return {}; }
185-
ValueRange outputs() { return getOperands().take_front(); }
184+
ValueRange inputs() { return getOperands().take_front(); }
185+
ValueRange outputs() { return getOperands().take_back(); }
186186

187187
// Rank-polymorphic.
188188
// filling_value -> O(ivs) with parallel iterators.
@@ -196,6 +196,7 @@ def FillOp : LinalgStructured_Op<"fill", []> {
196196
MLIRContext *context = getContext();
197197
// filling_value -> O(ivs)
198198
return Builder(getContext()).getAffineMapArrayAttr({
199+
AffineMap::get(getNumParallelLoops(), 0, {}, getContext()),
199200
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
200201
}
201202

@@ -206,13 +207,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
206207
getRegionBuilder() {
207208
return &regionBuilder;
208209
}
209-
static unsigned getNumRegionArgs() { return 1; }
210+
static unsigned getNumRegionArgs() { return 2; }
210211
}];
211212

212213
let assemblyFormat = [{
213214
`(` $output `,` $value `)` attr-dict `:`
214215
type($output) `,` type($value) (`->` type($result)^)?
215-
custom<FillOpRegion>($region, ref(type($output)), ref($value))
216+
custom<FillOpRegion>($region, ref(type($output)), ref(type($value)))
216217
}];
217218

218219
let builders = [

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,21 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
100100
if (isa<CopyOp>(op))
101101
return failure();
102102

103+
// Swap the operand order of the FillOp to maintain the pretty printed
104+
// signature that takes an output buffer followed by the fill value.
105+
SmallVector<Value> originalOperandOrder = op->getOperands();
106+
if (auto fillOp = dyn_cast<FillOp>(op.getOperation())) {
107+
Value value = fillOp.value();
108+
Value output = fillOp.output();
109+
op->setOperands(ValueRange{output, value});
110+
}
111+
103112
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
104-
if (!libraryCallName)
113+
if (!libraryCallName) {
114+
// Restore the operand order in case it has been modified.
115+
op->setOperands(originalOperandOrder);
105116
return failure();
117+
}
106118

107119
// TODO: Add support for more complex library call signatures that include
108120
// indices or captured values.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -421,32 +421,29 @@ void CopyOp::getEffects(
421421
//===----------------------------------------------------------------------===//
422422
void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
423423
ValueRange captures) {
424-
assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
425-
b.create<linalg::YieldOp>(captures);
424+
assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
425+
b.create<linalg::YieldOp>(block.getArgument(0));
426426
}
427427

428428
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
429429
Value value) {
430-
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
431-
value);
432-
fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
433-
TypeRange{output.getType()}, value);
430+
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
431+
output);
432+
fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
433+
TypeRange{value.getType()},
434+
TypeRange{output.getType()}, {});
434435
}
435436

436437
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
437-
OpAsmParser::OperandType valueRef) {
438+
Type valueType) {
438439
OpBuilder opBuilder(parser.getBuilder().getContext());
439-
// Resolve `valueRef` into `value` at parse time so we can build the region
440-
// with captures.
441-
SmallVector<Value> value;
442-
parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
443-
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
444-
TypeRange{outputType}, value);
440+
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
441+
TypeRange{outputType});
445442
return success();
446443
}
447444

448445
/// FillOp region is elided when printing.
449-
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
446+
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
450447

451448
static LogicalResult verify(FillOp op) {
452449
OpOperand *output = op.getOutputOperand(0);

mlir/python/mlir/dialects/_linalg_ops_ext.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from _mlir.dialects.linalg import fill_builtin_region
1111

1212

13-
def isa(cls : Type, ty : Type):
13+
def isa(cls: Type, ty: Type):
1414
try:
1515
cls(ty)
1616
return True
@@ -21,23 +21,19 @@ def isa(cls : Type, ty : Type):
2121
class FillOp:
2222
"""Extends the linalg.fill op."""
2323

24-
def __init__(self,
25-
output: Value,
26-
value: Value,
27-
*,
28-
loc=None,
29-
ip=None):
24+
def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
3025
results = []
3126
if isa(RankedTensorType, output.type):
3227
results = [output.type]
33-
op = self.build_generic(results=results,
34-
operands=[output, value],
35-
attributes=None,
36-
loc=loc,
37-
ip=ip)
28+
op = self.build_generic(
29+
results=results,
30+
operands=[value, output],
31+
attributes=None,
32+
loc=loc,
33+
ip=ip)
3834
OpView.__init__(self, op)
3935
linalgDialect = Context.current.get_dialect_descriptor("linalg")
40-
fill_builtin_region(linalgDialect, self.operation, [value])
36+
fill_builtin_region(linalgDialect, self.operation, [])
4137
# TODO: self.result is None. When len(results) == 1 we expect it to be
4238
# results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
4339
# in the generator of _linalg_ops_gen.py where we have:
@@ -78,11 +74,12 @@ def __init__(self,
7874
attributes["static_sizes"] = ArrayAttr.get(
7975
[IntegerAttr.get(i64_type, s) for s in static_size_ints],
8076
context=context)
81-
op = self.build_generic(results=[result_type],
82-
operands=operands,
83-
attributes=attributes,
84-
loc=loc,
85-
ip=ip)
77+
op = self.build_generic(
78+
results=[result_type],
79+
operands=operands,
80+
attributes=attributes,
81+
loc=loc,
82+
ip=ip)
8683
OpView.__init__(self, op)
8784

8885

@@ -91,10 +88,11 @@ class StructuredOpMixin:
9188

9289
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
9390
super().__init__(
94-
self.build_generic(results=list(results),
95-
operands=[list(inputs), list(outputs)],
96-
loc=loc,
97-
ip=ip))
91+
self.build_generic(
92+
results=list(results),
93+
operands=[list(inputs), list(outputs)],
94+
loc=loc,
95+
ip=ip))
9896

9997

10098
def select_opview_mixin(parent_opview_cls):

mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,15 @@ module {
296296
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
297297
// TLOOP-SAME: step (%[[C32]], %[[C64]])
298298
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
299-
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]])
299+
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]],
300+
// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
300301
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
301302

302303
// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
303304
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
304305
// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
305306
// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
306-
// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
307+
// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32_]])
307308

308309
// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
309310
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
@@ -398,3 +399,4 @@ module {
398399
// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
399400
// TLOOP: }
400401
// TLOOP: return %[[AB]] : [[TY]]
402+

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,15 +476,17 @@ func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
476476
return
477477
}
478478

479-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
479+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
480+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
480481

481482
// CHECK: func @generalize_fill
482483
// CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)
483484

484485
// CHECK: linalg.generic
485-
// CHECK-SAME: indexing_maps = [#[[MAP0]]]
486+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
486487
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
488+
// CHECK-SAME: ins(%[[VAL]] : f32)
487489
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
488490

489-
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32)
490-
// CHECK-NEXT: linalg.yield %[[VAL]] : f32
491+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
492+
// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ func @illegal_fill_memref_with_tensor_return
668668
func @illegal_fill_tensor_with_memref_return
669669
(%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
670670
{
671-
// expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
671+
// expected-error @+1 {{expected type of operand #1 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
672672
%0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
673673
return %0 : memref<?x?xf32>
674674
}

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ static void applyPatterns(FuncOp funcOp) {
235235
patterns.add<LinalgPromotionPattern<FillOp>>(
236236
ctx,
237237
LinalgPromotionOptions()
238-
.setOperandsToPromote({0})
239-
.setUseFullTileBuffers({true})
238+
.setOperandsToPromote({1})
239+
.setUseFullTileBuffers({false, true})
240240
.setAlignment(32),
241241
LinalgTransformationFilter(
242242
Identifier::get("_promote_views_aligned_", ctx),

0 commit comments

Comments
 (0)