Skip to content

Commit 29d420e

Browse files
committed
[mlir][OpFormatGen] Add support for anchoring optional groups with types
This revision adds support for using either operand or result types to anchor an optional group. It also removes the arbitrary restriction that type directives must refer to variables in the same group, which is overly limiting for a declarative format syntax. Fixes PR#48784 Differential Revision: https://reviews.llvm.org/D95109
1 parent 975086b commit 29d420e

File tree

5 files changed

+113
-66
lines changed

5 files changed

+113
-66
lines changed

mlir/docs/OpDefinitions.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,8 @@ information. An optional group is defined by wrapping a set of elements within
777777
* The first element of the group must either be a attribute, literal, operand,
778778
or region.
779779
- This is because the first element must be optionally parsable.
780-
* Exactly one argument variable within the group must be marked as the anchor
781-
of the group.
780+
* Exactly one argument variable or type directive within the group must be
781+
marked as the anchor of the group.
782782
- The anchor is the element whose presence controls whether the group
783783
should be printed/parsed.
784784
- An element is marked as the anchor by adding a trailing `^`.
@@ -789,11 +789,9 @@ information. An optional group is defined by wrapping a set of elements within
789789
valid elements within the group.
790790
- Any attribute variable may be used, but only optional attributes can be
791791
marked as the anchor.
792-
- Only variadic or optional operand arguments can be used.
792+
- Only variadic or optional results and operand arguments and can be used.
793793
- All region variables can be used. When a non-variable length region is
794794
used, if the group is not present the region is empty.
795-
- The operands to a type directive must be defined within the optional
796-
group.
797795

798796
An example of an operation with an optional group is `std.return`, which has a
799797
variadic number of operands.

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,25 @@ def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
15711571
(`[` $variadic^ `]`)? attr-dict
15721572
}]>;
15731573

1574+
// Test optional result type formatting.
1575+
class FormatOptionalResultOpBase<string suffix, string fmt>
1576+
: TEST_Op<"format_optional_result_" # suffix # "_op",
1577+
[AttrSizedResultSegments]> {
1578+
let results = (outs Optional<I64>:$optional, Variadic<I64>:$variadic);
1579+
let assemblyFormat = fmt;
1580+
}
1581+
def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{
1582+
(`:` type($optional)^ `->` type($variadic))? attr-dict
1583+
}]>;
1584+
1585+
def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{
1586+
(`:` type($optional) `->` type($variadic)^)? attr-dict
1587+
}]>;
1588+
1589+
def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{
1590+
(`:` functional-type($optional, $variadic)^)? attr-dict
1591+
}]>;
1592+
15741593
def FormatTwoVariadicOperandsNoBuildableTypeOp
15751594
: TEST_Op<"format_two_variadic_operands_no_buildable_type_op",
15761595
[AttrSizedOperandSegments]> {

mlir/test/mlir-tblgen/op-format-spec.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
333333
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
334334
(type($operand) $operand^)? attr-dict
335335
}]>, Arguments<(ins Optional<I64>:$operand)>;
336-
// CHECK: error: type directive can only refer to variables within the optional group
336+
// CHECK: error: only literals, types, and variables can be used within an optional group
337337
def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
338338
(`,` $attr^ type(operands))? attr-dict
339339
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
@@ -349,9 +349,9 @@ def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
349349
def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
350350
($arg^) attr-dict
351351
}]>, Arguments<(ins I64:$arg)>;
352-
// CHECK: error: only variables can be used to anchor an optional group
352+
// CHECK: error: only literals, types, and variables can be used within an optional group
353353
def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{
354-
($arg type($arg)^) attr-dict
354+
(functional-type($arg, results)^)? attr-dict
355355
}]>, Arguments<(ins Variadic<I64>:$arg)>;
356356
// CHECK: error: only literals, types, and variables can be used within an optional group
357357
def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
@@ -361,11 +361,11 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
361361
def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
362362
($arg^)
363363
}]>, Arguments<(ins Variadic<I64>:$arg)>;
364-
// CHECK: error: only variables can be used to anchor an optional group
364+
// CHECK: error: only variables and types can be used to anchor an optional group
365365
def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
366366
(custom<MyDirective>($arg)^)?
367367
}]>, Arguments<(ins I64:$arg)>;
368-
// CHECK: error: only variables can be used to anchor an optional group
368+
// CHECK: error: only variables and types can be used to anchor an optional group
369369
def OptionalInvalidM : TestFormat_Op<"optional_invalid_m", [{
370370
(` `^)?
371371
}]>, Arguments<(ins)>;

mlir/test/mlir-tblgen/op-format.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,25 @@ test.format_optional_operand_result_b_op( : ) : i64
220220
// CHECK: test.format_optional_operand_result_b_op : i64
221221
test.format_optional_operand_result_b_op : i64
222222

223+
//===----------------------------------------------------------------------===//
224+
// Format optional results
225+
//===----------------------------------------------------------------------===//
226+
227+
// CHECK: test.format_optional_result_a_op
228+
test.format_optional_result_a_op
229+
230+
// CHECK: test.format_optional_result_a_op : i64 -> i64, i64
231+
test.format_optional_result_a_op : i64 -> i64, i64
232+
233+
// CHECK: test.format_optional_result_b_op
234+
test.format_optional_result_b_op
235+
236+
// CHECK: test.format_optional_result_b_op : i64 -> i64, i64
237+
test.format_optional_result_b_op : i64 -> i64, i64
238+
239+
// CHECK: test.format_optional_result_c_op : (i64) -> (i64, i64)
240+
test.format_optional_result_c_op : (i64) -> (i64, i64)
241+
223242
//===----------------------------------------------------------------------===//
224243
// Format custom directives
225244
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 67 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,33 @@ static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
17491749
" }\n";
17501750
}
17511751

1752+
/// Generate the check for the anchor of an optional group.
1753+
static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) {
1754+
TypeSwitch<Element *>(anchor)
1755+
.Case<OperandVariable, ResultVariable>([&](auto *element) {
1756+
const NamedTypeConstraint *var = element->getVar();
1757+
if (var->isOptional())
1758+
body << " if (" << var->name << "()) {\n";
1759+
else if (var->isVariadic())
1760+
body << " if (!" << var->name << "().empty()) {\n";
1761+
})
1762+
.Case<RegionVariable>([&](RegionVariable *element) {
1763+
const NamedRegion *var = element->getVar();
1764+
// TODO: Add a check for optional regions here when ODS supports it.
1765+
body << " if (!" << var->name << "().empty()) {\n";
1766+
})
1767+
.Case<TypeDirective>([&](TypeDirective *element) {
1768+
genOptionalGroupPrinterAnchor(element->getOperand(), body);
1769+
})
1770+
.Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1771+
genOptionalGroupPrinterAnchor(element->getInputs(), body);
1772+
})
1773+
.Case<AttributeVariable>([&](AttributeVariable *attr) {
1774+
body << " if ((*this)->getAttr(\"" << attr->getVar()->name
1775+
<< "\")) {\n";
1776+
});
1777+
}
1778+
17521779
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
17531780
Operator &op, bool &shouldEmitSpace,
17541781
bool &lastWasPunctuation) {
@@ -1769,21 +1796,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
17691796
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
17701797
// Emit the check for the presence of the anchor element.
17711798
Element *anchor = optional->getAnchor();
1772-
if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
1773-
const NamedTypeConstraint *var = operand->getVar();
1774-
if (var->isOptional())
1775-
body << " if (" << var->name << "()) {\n";
1776-
else if (var->isVariadic())
1777-
body << " if (!" << var->name << "().empty()) {\n";
1778-
} else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
1779-
const NamedRegion *var = region->getVar();
1780-
// TODO: Add a check for optional here when ODS supports it.
1781-
body << " if (!" << var->name << "().empty()) {\n";
1782-
1783-
} else {
1784-
body << " if ((*this)->getAttr(\""
1785-
<< cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
1786-
}
1799+
genOptionalGroupPrinterAnchor(anchor, body);
17871800

17881801
// If the anchor is a unit attribute, we don't need to print it. When
17891802
// parsing, we will add this attribute if this group is present.
@@ -2244,8 +2257,9 @@ class FormatParser {
22442257
bool isTopLevel);
22452258
LogicalResult parseOptionalChildElement(
22462259
std::vector<std::unique_ptr<Element>> &childElements,
2247-
SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
22482260
Optional<unsigned> &anchorIdx);
2261+
LogicalResult verifyOptionalChildElement(Element *element,
2262+
llvm::SMLoc childLoc, bool isAnchor);
22492263

22502264
/// Parse the various different directives.
22512265
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
@@ -2315,7 +2329,6 @@ class FormatParser {
23152329
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
23162330
llvm::DenseSet<const NamedRegion *> seenRegions;
23172331
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2318-
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
23192332
};
23202333
} // end anonymous namespace
23212334

@@ -2760,10 +2773,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
27602773

27612774
// Parse the child elements for this optional group.
27622775
std::vector<std::unique_ptr<Element>> elements;
2763-
SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
27642776
Optional<unsigned> anchorIdx;
27652777
do {
2766-
if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
2778+
if (failed(parseOptionalChildElement(elements, anchorIdx)))
27672779
return ::mlir::failure();
27682780
} while (curToken.getKind() != Token::r_paren);
27692781
consumeToken();
@@ -2787,31 +2799,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
27872799
"first parsable element of an operand group must be "
27882800
"an attribute, literal, operand, or region");
27892801

2790-
// After parsing all of the elements, ensure that all type directives refer
2791-
// only to elements within the group.
2792-
auto checkTypeOperand = [&](Element *typeEle) {
2793-
auto *opVar = dyn_cast<OperandVariable>(typeEle);
2794-
const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
2795-
if (!seenVariables.count(var))
2796-
return emitError(curLoc, "type directive can only refer to variables "
2797-
"within the optional group");
2798-
return ::mlir::success();
2799-
};
2800-
for (auto &ele : elements) {
2801-
if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2802-
if (failed(checkTypeOperand(typeEle->getOperand())))
2803-
return failure();
2804-
} else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2805-
if (failed(checkTypeOperand(typeEle->getOperand())))
2806-
return ::mlir::failure();
2807-
} else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
2808-
if (failed(checkTypeOperand(typeEle->getInputs())) ||
2809-
failed(checkTypeOperand(typeEle->getResults())))
2810-
return ::mlir::failure();
2811-
}
2812-
}
2813-
2814-
optionalVariables.insert(seenVariables.begin(), seenVariables.end());
28152802
auto parseStart = parseBegin - elements.begin();
28162803
element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
28172804
parseStart);
@@ -2820,7 +2807,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
28202807

28212808
LogicalResult FormatParser::parseOptionalChildElement(
28222809
std::vector<std::unique_ptr<Element>> &childElements,
2823-
SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
28242810
Optional<unsigned> &anchorIdx) {
28252811
llvm::SMLoc childLoc = curToken.getLoc();
28262812
childElements.push_back({});
@@ -2837,7 +2823,14 @@ LogicalResult FormatParser::parseOptionalChildElement(
28372823
consumeToken();
28382824
}
28392825

2840-
return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
2826+
return verifyOptionalChildElement(childElements.back().get(), childLoc,
2827+
isAnchor);
2828+
}
2829+
2830+
LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
2831+
llvm::SMLoc childLoc,
2832+
bool isAnchor) {
2833+
return TypeSwitch<Element *, LogicalResult>(element)
28412834
// All attributes can be within the optional group, but only optional
28422835
// attributes can be the anchor.
28432836
.Case([&](AttributeVariable *attrEle) {
@@ -2852,24 +2845,42 @@ LogicalResult FormatParser::parseOptionalChildElement(
28522845
if (!ele->getVar()->isVariableLength())
28532846
return emitError(childLoc, "only variable length operands can be "
28542847
"used within an optional group");
2855-
seenVariables.insert(ele->getVar());
2848+
return ::mlir::success();
2849+
})
2850+
// Only optional-like(i.e. variadic) results can be within an optional
2851+
// group.
2852+
.Case<ResultVariable>([&](ResultVariable *ele) {
2853+
if (!ele->getVar()->isVariableLength())
2854+
return emitError(childLoc, "only variable length results can be "
2855+
"used within an optional group");
28562856
return ::mlir::success();
28572857
})
28582858
.Case<RegionVariable>([&](RegionVariable *) {
28592859
// TODO: When ODS has proper support for marking "optional" regions, add
28602860
// a check here.
28612861
return ::mlir::success();
28622862
})
2863-
// Literals, whitespace, custom directives, and type directives may be
2864-
// used, but they can't anchor the group.
2865-
.Case<LiteralElement, WhitespaceElement, CustomDirective,
2866-
FunctionalTypeDirective, OptionalElement, TypeRefDirective,
2867-
TypeDirective>([&](Element *) {
2868-
if (isAnchor)
2869-
return emitError(childLoc, "only variables can be used to anchor "
2870-
"an optional group");
2871-
return ::mlir::success();
2863+
.Case<TypeDirective>([&](TypeDirective *ele) {
2864+
return verifyOptionalChildElement(ele->getOperand(), childLoc,
2865+
/*isAnchor=*/false);
28722866
})
2867+
.Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
2868+
if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
2869+
/*isAnchor=*/false)))
2870+
return failure();
2871+
return verifyOptionalChildElement(ele->getResults(), childLoc,
2872+
/*isAnchor=*/false);
2873+
})
2874+
// Literals, whitespace, and custom directives may be used, but they can't
2875+
// anchor the group.
2876+
.Case<LiteralElement, WhitespaceElement, CustomDirective,
2877+
FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
2878+
[&](Element *) {
2879+
if (isAnchor)
2880+
return emitError(childLoc, "only variables and types can be used "
2881+
"to anchor an optional group");
2882+
return ::mlir::success();
2883+
})
28732884
.Default([&](Element *) {
28742885
return emitError(childLoc, "only literals, types, and variables can be "
28752886
"used within an optional group");

0 commit comments

Comments
 (0)