Skip to content

FXML.2007: PDLL support for creating new ops with empty regions #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
Variadic<PDL_Attribute>:$attributeValues,
StrArrayAttr:$attributeValueNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues,
OptionalAttr<UI32Attr>:$numRegions);
let results = (outs PDL_Operation:$op);
let assemblyFormat = [{
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
Expand All @@ -361,9 +362,10 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
CArg<"ValueRange", "llvm::None">:$attrValues,
CArg<"ValueRange", "llvm::None">:$resultTypes), [{
auto nameAttr = name ? $_builder.getStringAttr(*name) : StringAttr();
IntegerAttr numRegionsAttr;
build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr,
operandValues, attrValues, $_builder.getStrArrayAttr(attrNames),
resultTypes);
resultTypes, numRegionsAttr);
}]>,
];
let extraClassDeclaration = [{
Expand Down
13 changes: 11 additions & 2 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,24 @@ def PDLInterp_CreateOperationOp
Variadic<PDL_Attribute>:$inputAttributes,
StrArrayAttr:$inputAttributeNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes,
UnitAttr:$inferredResultTypes);
UnitAttr:$inferredResultTypes,
OptionalAttr<UI32Attr>:$numRegions);
let results = (outs PDL_Operation:$resultOp);

let builders = [
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
"bool":$inferredResultTypes, "ValueRange":$operands,
"ValueRange":$attributes, "ArrayAttr":$attributeNames), [{
IntegerAttr numRegionsAttr;
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
operands, attributes, attributeNames, types, inferredResultTypes);
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
}]>,
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
"bool":$inferredResultTypes, "ValueRange":$operands,
"ValueRange":$attributes, "ArrayAttr":$attributeNames, "uint32_t":$numRegions), [{
auto numRegionsAttr = $_builder.getUI32IntegerAttr(numRegions);
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
}]>
];
let assemblyFormat = [{
Expand Down
22 changes: 12 additions & 10 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,12 +501,11 @@ class OperationExpr final
private llvm::TrailingObjects<OperationExpr, Expr *,
NamedAttributeDecl *> {
public:
static OperationExpr *create(Context &ctx, SMRange loc,
const ods::Operation *odsOp,
const OpNameDecl *nameDecl,
ArrayRef<Expr *> operands,
ArrayRef<Expr *> resultTypes,
ArrayRef<NamedAttributeDecl *> attributes);
static OperationExpr *
create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
const OpNameDecl *nameDecl, ArrayRef<Expr *> operands,
ArrayRef<Expr *> resultTypes,
ArrayRef<NamedAttributeDecl *> attributes, unsigned numRegions);

/// Return the name of the operation, or None if there isn't one.
Optional<StringRef> getName() const;
Expand Down Expand Up @@ -542,19 +541,22 @@ class OperationExpr final
return const_cast<OperationExpr *>(this)->getAttributes();
}

unsigned getNumRegions() const { return numRegions; }

private:
OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
unsigned numOperands, unsigned numResultTypes,
unsigned numAttributes, SMRange nameLoc)
unsigned numAttributes, unsigned numRegions, SMRange nameLoc)
: Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
numResultTypes(numResultTypes), numAttributes(numAttributes),
nameLoc(nameLoc) {}
numRegions(numRegions), nameLoc(nameLoc) {}

/// The name decl of this expression.
const OpNameDecl *nameDecl;

/// The number of operands, result types, and attributes of the operation.
unsigned numOperands, numResultTypes, numAttributes;
/// The number of operands, result types, attributes and regions of the
/// operation.
unsigned numOperands, numResultTypes, numAttributes, numRegions;

/// The location of the operation name in the expression if it has a name.
SMRange nameLoc;
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,10 @@ void PatternLowering::generateRewriter(

// Create the new operation.
Location loc = operationOp.getLoc();
auto numRegions = operationOp.getNumRegions().value_or(0);
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
attributes, operationOp.getAttributeValueNames());
attributes, operationOp.getAttributeValueNames(), numRegions);
rewriteValues[operationOp.getOp()] = createdOp;

// Generate accesses for any results that have their types constrained.
Expand Down
41 changes: 29 additions & 12 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,34 +141,51 @@ LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
// pdl::OperationOp
//===----------------------------------------------------------------------===//

/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
/// Also allows empty `{}`
static ParseResult parseOperationOpAttributes(
OpAsmParser &p,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();
SmallVector<Attribute, 4> attrNames;
if (succeeded(p.parseOptionalLBrace())) {
auto parseOperands = [&]() {
StringAttr nameAttr;
OpAsmParser::UnresolvedOperand operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
if (failed(p.parseOptionalRBrace())) {
auto parseOperands = [&]() {
StringAttr nameAttr;
OpAsmParser::UnresolvedOperand operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOperands.push_back(operand);
return success();
};
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
return failure();
attrNames.push_back(nameAttr);
attrOperands.push_back(operand);
return success();
};
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
return failure();
}
}
Comment on lines +153 to 167
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The github diff is not very informative here.
The existing parsing is simply surrounded by

if (failed(p.parseOptionalRBrace())) {
//[...]
}

to enable parsing {} as empty operationOpAttributes.

attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}

/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
/// Prints empty `{}` when it would not be possible to discern the attr-dict
/// otherwise.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
OperandRange attrArgs,
ArrayAttr attrNames) {
if (attrNames.empty())
/// Only omit printing empty `{}` if there are no other attributes that have
/// to be printed later because otherwise we could not discern the attr dict.
static const SmallVector<StringRef, 3> specialAttrs = {
"operand_segment_sizes", "attributeValueNames", "opName"};
bool onlySpecialAttrs =
llvm::all_of(op->getAttrs(), [&](const NamedAttribute &attr) {
return llvm::any_of(specialAttrs, [&](const StringRef &predefinedAttr) {
return attr.getName() == predefinedAttr;
});
});
if (attrNames.empty() && onlySpecialAttrs)
return;
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
Expand Down
34 changes: 22 additions & 12 deletions mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,45 @@ LogicalResult CreateOperationOp::verify() {
return success();
}

/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
/// Also allows empty `{}`
static ParseResult parseCreateOperationOpAttributes(
OpAsmParser &p,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();
SmallVector<Attribute, 4> attrNames;
if (succeeded(p.parseOptionalLBrace())) {
auto parseOperands = [&]() {
StringAttr nameAttr;
OpAsmParser::UnresolvedOperand operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
if (failed(p.parseOptionalRBrace())) {
auto parseOperands = [&]() {
StringAttr nameAttr;
OpAsmParser::UnresolvedOperand operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOperands.push_back(operand);
return success();
};
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
return failure();
attrNames.push_back(nameAttr);
attrOperands.push_back(operand);
return success();
};
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
return failure();
}
}
attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}

/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
/// Prints empty `{}` when it would not be possible to discern the attr-dict
/// otherwise.
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
CreateOperationOp op,
OperandRange attrArgs,
ArrayAttr attrNames) {
if (attrNames.empty())
/// Only omit printing empty `{}` if we have result types because otherwise we
/// could not discern the attr dict.
unsigned numResultTypes = op.getODSOperandIndexAndLength(2).second;
if (attrNames.empty() && (numResultTypes > 0 || op.getInferredResultTypes()))
return;
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,14 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
writer.append(kInferTypesMarker);
else
writer.appendPDLValueList(op.getInputResultTypes());

// Add number of regions
if (IntegerAttr attr = op.getNumRegionsAttr()) {
writer.append(ByteCodeField(attr.getUInt()));
} else {
unsigned numRegions = 0;
writer.append(ByteCodeField(numRegions));
}
}
void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
// Append the correct opcode for the range type.
Expand Down Expand Up @@ -1663,6 +1671,12 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
}
}

// handle regions:
unsigned numRegions = read();
for (unsigned i = 0; i < numRegions; i++) {
state.addRegion();
}

Operation *resultOp = rewriter.create(state);
memory[memIndex] = resultOp;

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ void NodePrinter::printImpl(const MemberAccessExpr *expr) {
void NodePrinter::printImpl(const OperationExpr *expr) {
os << "OperationExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
os << "> numRegions:" << expr->getNumRegions() << "\n";

printChildren(expr->getNameDecl());
printChildren("Operands", expr->getOperands());
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Tools/PDLL/AST/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,13 @@ MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
// OperationExpr
//===----------------------------------------------------------------------===//

OperationExpr *
OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
const OpNameDecl *name, ArrayRef<Expr *> operands,
ArrayRef<Expr *> resultTypes,
ArrayRef<NamedAttributeDecl *> attributes) {
OperationExpr *OperationExpr::create(Context &ctx, SMRange loc,
const ods::Operation *odsOp,
const OpNameDecl *name,
ArrayRef<Expr *> operands,
ArrayRef<Expr *> resultTypes,
ArrayRef<NamedAttributeDecl *> attributes,
unsigned numRegions) {
unsigned allocSize =
OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
operands.size() + resultTypes.size(), attributes.size());
Expand All @@ -315,7 +317,7 @@ OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
Type resultType = OperationType::get(ctx, name->getName(), odsOp);
OperationExpr *opExpr = new (rawData)
OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
attributes.size(), name->getLoc());
attributes.size(), numRegions, name->getLoc());
std::uninitialized_copy(operands.begin(), operands.end(),
opExpr->getOperands().begin());
std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,14 @@ Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
for (const ast::Expr *result : expr->getResultTypes())
results.push_back(genSingleExpr(result));

return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
attrValues, results);
auto operationOp = builder.create<pdl::OperationOp>(
loc, opName, operands, attrNames, attrValues, results);

// numRegions
if (expr->getNumRegions() > 0)
operationOp.setNumRegions(expr->getNumRegions());

return operationOp;
}

Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
Expand Down
24 changes: 20 additions & 4 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,8 @@ class Parser {
OpResultTypeContext resultTypeContext,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
SmallVectorImpl<ast::Expr *> &results);
SmallVectorImpl<ast::Expr *> &results,
unsigned numRegions);
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
Expand Down Expand Up @@ -2129,8 +2130,23 @@ Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
resultTypeContext = OpResultTypeContext::Interface;
}

// Parse list of regions
unsigned numRegions = 0;
if (consumeIf(Token::l_paren)) {
do {
if (failed(parseToken(Token::l_brace, "expected `{` to open region")))
return failure();
if (failed(parseToken(Token::r_brace, "expected `}` to close region")))
return failure();
numRegions++;
} while (consumeIf(Token::comma));
if (failed(parseToken(Token::r_paren, "expected `)` to close region "
"list")))
return failure();
}

return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
attributes, resultTypes);
attributes, resultTypes, numRegions);
}

FailureOr<ast::Expr *> Parser::parseTupleExpr() {
Expand Down Expand Up @@ -2807,7 +2823,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
OpResultTypeContext resultTypeContext,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
SmallVectorImpl<ast::Expr *> &results) {
SmallVectorImpl<ast::Expr *> &results, unsigned numRegions) {
Optional<StringRef> opNameRef = name->getName();
const ods::Operation *odsOp = lookupODSOperation(opNameRef);

Expand Down Expand Up @@ -2844,7 +2860,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
}

return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
attributes);
attributes, numRegions);
}

LogicalResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,19 @@ module @range_op {
}
}
}

// -----

// CHECK-LABEL: module @create_empty_region
module @create_empty_region {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter()
// CHECK: %[[UNUSED:.*]] = pdl_interp.create_operation "bar.op" {} {numRegions = 1 : ui32}
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%root = operation "foo.op"
rewrite %root {
%unused = operation "bar.op" {} {"numRegions" = 1 : ui32}
}
}
}
Loading