@@ -1749,6 +1749,33 @@ static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
1749
1749
" }\n " ;
1750
1750
}
1751
1751
1752
+ // / Generate the check for the anchor of an optional group.
1753
+ static void genOptionalGroupPrinterAnchor (Element *anchor, OpMethodBody &body) {
1754
+ TypeSwitch<Element *>(anchor)
1755
+ .Case <OperandVariable, ResultVariable>([&](auto *element) {
1756
+ const NamedTypeConstraint *var = element->getVar ();
1757
+ if (var->isOptional ())
1758
+ body << " if (" << var->name << " ()) {\n " ;
1759
+ else if (var->isVariadic ())
1760
+ body << " if (!" << var->name << " ().empty()) {\n " ;
1761
+ })
1762
+ .Case <RegionVariable>([&](RegionVariable *element) {
1763
+ const NamedRegion *var = element->getVar ();
1764
+ // TODO: Add a check for optional regions here when ODS supports it.
1765
+ body << " if (!" << var->name << " ().empty()) {\n " ;
1766
+ })
1767
+ .Case <TypeDirective>([&](TypeDirective *element) {
1768
+ genOptionalGroupPrinterAnchor (element->getOperand (), body);
1769
+ })
1770
+ .Case <FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1771
+ genOptionalGroupPrinterAnchor (element->getInputs (), body);
1772
+ })
1773
+ .Case <AttributeVariable>([&](AttributeVariable *attr) {
1774
+ body << " if ((*this)->getAttr(\" " << attr->getVar ()->name
1775
+ << " \" )) {\n " ;
1776
+ });
1777
+ }
1778
+
1752
1779
void OperationFormat::genElementPrinter (Element *element, OpMethodBody &body,
1753
1780
Operator &op, bool &shouldEmitSpace,
1754
1781
bool &lastWasPunctuation) {
@@ -1769,21 +1796,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
1769
1796
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1770
1797
// Emit the check for the presence of the anchor element.
1771
1798
Element *anchor = optional->getAnchor ();
1772
- if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
1773
- const NamedTypeConstraint *var = operand->getVar ();
1774
- if (var->isOptional ())
1775
- body << " if (" << var->name << " ()) {\n " ;
1776
- else if (var->isVariadic ())
1777
- body << " if (!" << var->name << " ().empty()) {\n " ;
1778
- } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
1779
- const NamedRegion *var = region->getVar ();
1780
- // TODO: Add a check for optional here when ODS supports it.
1781
- body << " if (!" << var->name << " ().empty()) {\n " ;
1782
-
1783
- } else {
1784
- body << " if ((*this)->getAttr(\" "
1785
- << cast<AttributeVariable>(anchor)->getVar ()->name << " \" )) {\n " ;
1786
- }
1799
+ genOptionalGroupPrinterAnchor (anchor, body);
1787
1800
1788
1801
// If the anchor is a unit attribute, we don't need to print it. When
1789
1802
// parsing, we will add this attribute if this group is present.
@@ -2244,8 +2257,9 @@ class FormatParser {
2244
2257
bool isTopLevel);
2245
2258
LogicalResult parseOptionalChildElement (
2246
2259
std::vector<std::unique_ptr<Element>> &childElements,
2247
- SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
2248
2260
Optional<unsigned > &anchorIdx);
2261
+ LogicalResult verifyOptionalChildElement (Element *element,
2262
+ llvm::SMLoc childLoc, bool isAnchor);
2249
2263
2250
2264
// / Parse the various different directives.
2251
2265
LogicalResult parseAttrDictDirective (std::unique_ptr<Element> &element,
@@ -2315,7 +2329,6 @@ class FormatParser {
2315
2329
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2316
2330
llvm::DenseSet<const NamedRegion *> seenRegions;
2317
2331
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2318
- llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
2319
2332
};
2320
2333
} // end anonymous namespace
2321
2334
@@ -2760,10 +2773,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2760
2773
2761
2774
// Parse the child elements for this optional group.
2762
2775
std::vector<std::unique_ptr<Element>> elements;
2763
- SmallPtrSet<const NamedTypeConstraint *, 8 > seenVariables;
2764
2776
Optional<unsigned > anchorIdx;
2765
2777
do {
2766
- if (failed (parseOptionalChildElement (elements, seenVariables, anchorIdx)))
2778
+ if (failed (parseOptionalChildElement (elements, anchorIdx)))
2767
2779
return ::mlir::failure ();
2768
2780
} while (curToken.getKind () != Token::r_paren);
2769
2781
consumeToken ();
@@ -2787,31 +2799,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2787
2799
" first parsable element of an operand group must be "
2788
2800
" an attribute, literal, operand, or region" );
2789
2801
2790
- // After parsing all of the elements, ensure that all type directives refer
2791
- // only to elements within the group.
2792
- auto checkTypeOperand = [&](Element *typeEle) {
2793
- auto *opVar = dyn_cast<OperandVariable>(typeEle);
2794
- const NamedTypeConstraint *var = opVar ? opVar->getVar () : nullptr ;
2795
- if (!seenVariables.count (var))
2796
- return emitError (curLoc, " type directive can only refer to variables "
2797
- " within the optional group" );
2798
- return ::mlir::success ();
2799
- };
2800
- for (auto &ele : elements) {
2801
- if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get ())) {
2802
- if (failed (checkTypeOperand (typeEle->getOperand ())))
2803
- return failure ();
2804
- } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get ())) {
2805
- if (failed (checkTypeOperand (typeEle->getOperand ())))
2806
- return ::mlir::failure ();
2807
- } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get ())) {
2808
- if (failed (checkTypeOperand (typeEle->getInputs ())) ||
2809
- failed (checkTypeOperand (typeEle->getResults ())))
2810
- return ::mlir::failure ();
2811
- }
2812
- }
2813
-
2814
- optionalVariables.insert (seenVariables.begin (), seenVariables.end ());
2815
2802
auto parseStart = parseBegin - elements.begin ();
2816
2803
element = std::make_unique<OptionalElement>(std::move (elements), *anchorIdx,
2817
2804
parseStart);
@@ -2820,7 +2807,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2820
2807
2821
2808
LogicalResult FormatParser::parseOptionalChildElement (
2822
2809
std::vector<std::unique_ptr<Element>> &childElements,
2823
- SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
2824
2810
Optional<unsigned > &anchorIdx) {
2825
2811
llvm::SMLoc childLoc = curToken.getLoc ();
2826
2812
childElements.push_back ({});
@@ -2837,7 +2823,14 @@ LogicalResult FormatParser::parseOptionalChildElement(
2837
2823
consumeToken ();
2838
2824
}
2839
2825
2840
- return TypeSwitch<Element *, LogicalResult>(childElements.back ().get ())
2826
+ return verifyOptionalChildElement (childElements.back ().get (), childLoc,
2827
+ isAnchor);
2828
+ }
2829
+
2830
+ LogicalResult FormatParser::verifyOptionalChildElement (Element *element,
2831
+ llvm::SMLoc childLoc,
2832
+ bool isAnchor) {
2833
+ return TypeSwitch<Element *, LogicalResult>(element)
2841
2834
// All attributes can be within the optional group, but only optional
2842
2835
// attributes can be the anchor.
2843
2836
.Case ([&](AttributeVariable *attrEle) {
@@ -2852,24 +2845,42 @@ LogicalResult FormatParser::parseOptionalChildElement(
2852
2845
if (!ele->getVar ()->isVariableLength ())
2853
2846
return emitError (childLoc, " only variable length operands can be "
2854
2847
" used within an optional group" );
2855
- seenVariables.insert (ele->getVar ());
2848
+ return ::mlir::success ();
2849
+ })
2850
+ // Only optional-like(i.e. variadic) results can be within an optional
2851
+ // group.
2852
+ .Case <ResultVariable>([&](ResultVariable *ele) {
2853
+ if (!ele->getVar ()->isVariableLength ())
2854
+ return emitError (childLoc, " only variable length results can be "
2855
+ " used within an optional group" );
2856
2856
return ::mlir::success ();
2857
2857
})
2858
2858
.Case <RegionVariable>([&](RegionVariable *) {
2859
2859
// TODO: When ODS has proper support for marking "optional" regions, add
2860
2860
// a check here.
2861
2861
return ::mlir::success ();
2862
2862
})
2863
- // Literals, whitespace, custom directives, and type directives may be
2864
- // used, but they can't anchor the group.
2865
- .Case <LiteralElement, WhitespaceElement, CustomDirective,
2866
- FunctionalTypeDirective, OptionalElement, TypeRefDirective,
2867
- TypeDirective>([&](Element *) {
2868
- if (isAnchor)
2869
- return emitError (childLoc, " only variables can be used to anchor "
2870
- " an optional group" );
2871
- return ::mlir::success ();
2863
+ .Case <TypeDirective>([&](TypeDirective *ele) {
2864
+ return verifyOptionalChildElement (ele->getOperand (), childLoc,
2865
+ /* isAnchor=*/ false );
2872
2866
})
2867
+ .Case <FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
2868
+ if (failed (verifyOptionalChildElement (ele->getInputs (), childLoc,
2869
+ /* isAnchor=*/ false )))
2870
+ return failure ();
2871
+ return verifyOptionalChildElement (ele->getResults (), childLoc,
2872
+ /* isAnchor=*/ false );
2873
+ })
2874
+ // Literals, whitespace, and custom directives may be used, but they can't
2875
+ // anchor the group.
2876
+ .Case <LiteralElement, WhitespaceElement, CustomDirective,
2877
+ FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
2878
+ [&](Element *) {
2879
+ if (isAnchor)
2880
+ return emitError (childLoc, " only variables and types can be used "
2881
+ " to anchor an optional group" );
2882
+ return ::mlir::success ();
2883
+ })
2873
2884
.Default ([&](Element *) {
2874
2885
return emitError (childLoc, " only literals, types, and variables can be "
2875
2886
" used within an optional group" );
0 commit comments