Skip to content

Commit 2c60b4c

Browse files
committed
[mlir][ods] Populate properties in generated builder
Previously this was only populated in the create method later.
1 parent b3ca9c3 commit 2c60b4c

File tree

3 files changed

+143
-25
lines changed

3 files changed

+143
-25
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,13 @@ def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
24012401
let regions = (region AnyRegion:$body);
24022402
}
24032403

2404+
// Two variadic args, non variadic results, with AttrSizedOperandSegments
2405+
// Test build method generation for property conversion & type inference.
2406+
def TableGenBuildOp6 : TEST_Op<"tblgen_build_6", [AttrSizedOperandSegments]> {
2407+
let arguments = (ins Variadic<AnyType>:$a, Variadic<AnyType>:$b);
2408+
let results = (outs F32:$result);
2409+
}
2410+
24042411
//===----------------------------------------------------------------------===//
24052412
// Test BufferPlacement
24062413
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@
4141

4242
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
4343

44+
#if 0
45+
#define DBG_ODS_PRINT(body, X) \
46+
body << "fprintf(stderr, \"Generated from " << X \
47+
<< " at %s:%d\\n\", __FILE__, __LINE__);\n";
48+
#else
49+
#define DBG_ODS_PRINT(body, X)
50+
#endif
51+
4452
using namespace llvm;
4553
using namespace mlir;
4654
using namespace mlir::tblgen;
@@ -1321,7 +1329,7 @@ void OpEmitter::genPropertiesSupport() {
13211329
{2};
13221330
if (!attr) {{
13231331
emitError() << "expected key entry for {1} in DictionaryAttr to set "
1324-
"Properties.";
1332+
"Properties";
13251333
return ::mlir::failure();
13261334
}
13271335
if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
@@ -1380,14 +1388,14 @@ void OpEmitter::genPropertiesSupport() {
13801388
if (attr || /*isRequired=*/{1}) {{
13811389
if (!attr) {{
13821390
emitError() << "expected key entry for {0} in DictionaryAttr to set "
1383-
"Properties.";
1391+
"Properties";
13841392
return ::mlir::failure();
13851393
}
13861394
auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
13871395
if (convertedAttr) {{
13881396
propStorage = convertedAttr;
13891397
} else {{
1390-
emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1398+
emitError() << "invalid attribute `{0}` in property conversion: " << attr;
13911399
return ::mlir::failure();
13921400
}
13931401
}
@@ -2397,6 +2405,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
23972405
if (!m)
23982406
return;
23992407
auto &body = m->body();
2408+
DBG_ODS_PRINT(body, __LINE__);
24002409
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
24012410
/*isRawValueAttr=*/attrType ==
24022411
AttrParamKind::UnwrappedValue);
@@ -2519,6 +2528,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
25192528
if (!m)
25202529
return;
25212530
auto &body = m->body();
2531+
DBG_ODS_PRINT(body, __LINE__);
25222532

25232533
// Operands
25242534
body << " " << builderOpState << ".addOperands(operands);\n";
@@ -2623,6 +2633,7 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
26232633
if (!m)
26242634
return;
26252635
auto &body = m->body();
2636+
DBG_ODS_PRINT(body, __LINE__);
26262637

26272638
int numResults = op.getNumResults();
26282639
int numVariadicResults = op.getNumVariableLengthResults();
@@ -2650,6 +2661,19 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
26502661
}
26512662

26522663
// Result types
2664+
if (emitHelper.hasProperties()) {
2665+
// Initialize the properties from Attributes before invoking the infer
2666+
// function.
2667+
body << formatv(R"(
2668+
::mlir::OpaqueProperties properties =
2669+
&{1}.getOrAddProperties<{0}::Properties>();
2670+
std::optional<::mlir::RegisteredOperationName> info =
2671+
{1}.name.getRegisteredInfo();
2672+
if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2673+
{1}.attributes.getDictionary({1}.getContext()), nullptr)))
2674+
::llvm::report_fatal_error("Property conversion failed.");)",
2675+
opClass.getClassName(), builderOpState);
2676+
}
26532677
body << formatv(R"(
26542678
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
26552679
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
@@ -2684,6 +2708,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
26842708
if (!m)
26852709
return;
26862710
auto &body = m->body();
2711+
DBG_ODS_PRINT(body, __LINE__);
26872712
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
26882713
/*isRawValueAttr=*/attrType ==
26892714
AttrParamKind::UnwrappedValue);
@@ -2721,6 +2746,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
27212746
return;
27222747

27232748
auto &body = m->body();
2749+
DBG_ODS_PRINT(body, __LINE__);
27242750

27252751
// Push all result types to the operation state
27262752
std::string resultType;
@@ -2852,6 +2878,7 @@ void OpEmitter::genCollectiveParamBuilder() {
28522878
if (!m)
28532879
return;
28542880
auto &body = m->body();
2881+
DBG_ODS_PRINT(body, __LINE__);
28552882

28562883
// Operands
28572884
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
@@ -2879,6 +2906,20 @@ void OpEmitter::genCollectiveParamBuilder() {
28792906
<< "u && \"mismatched number of return types\");\n";
28802907
body << " " << builderOpState << ".addTypes(resultTypes);\n";
28812908

2909+
if (emitHelper.hasProperties()) {
2910+
// Initialize the properties from Attributes before invoking the infer
2911+
// function.
2912+
body << formatv(R"(
2913+
::mlir::OpaqueProperties properties =
2914+
&{1}.getOrAddProperties<{0}::Properties>();
2915+
std::optional<::mlir::RegisteredOperationName> info =
2916+
{1}.name.getRegisteredInfo();
2917+
if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2918+
{1}.attributes.getDictionary({1}.getContext()), nullptr)))
2919+
::llvm::report_fatal_error("Property conversion failed.");)",
2920+
opClass.getClassName(), builderOpState);
2921+
}
2922+
28822923
// Generate builder that infers type too.
28832924
// TODO: Expand to handle successors.
28842925
if (canInferType(op) && op.getNumSuccessors() == 0)

mlir/unittests/TableGen/OpBuildGen.cpp

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,29 +66,44 @@ class OpBuildGenTest : public ::testing::Test {
6666
EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
6767
attrs[idx].getValue());
6868

69+
EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
6970
concreteOp.erase();
7071
}
7172

72-
// Helper method to test ops with inferred result types and single variadic
73-
// input.
7473
template <typename OpTy>
75-
void testSingleVariadicInputInferredType() {
76-
// Test separate arg, separate param build method.
77-
auto op = builder.create<OpTy>(loc, i32Ty, ValueRange{*cstI32, *cstI32});
78-
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
79-
80-
// Test collective params build method.
81-
op = builder.create<OpTy>(loc, TypeRange{i32Ty},
82-
ValueRange{*cstI32, *cstI32});
83-
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
84-
85-
// Test build method with no result types, default value of attributes.
86-
op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32});
87-
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
88-
89-
// Test build method with no result types and supplied attributes.
90-
op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32}, attrs);
91-
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
74+
void verifyOp(OpTy &&concreteOp, std::vector<Type> resultTypes,
75+
std::vector<Value> operands1, std::vector<Value> operands2,
76+
std::vector<NamedAttribute> attrs) {
77+
ASSERT_NE(concreteOp, nullptr);
78+
Operation *op = concreteOp.getOperation();
79+
80+
EXPECT_EQ(op->getNumResults(), resultTypes.size());
81+
for (unsigned idx : llvm::seq(0U, op->getNumResults()))
82+
EXPECT_EQ(op->getResult(idx).getType(), resultTypes[idx]);
83+
84+
auto operands = llvm::to_vector(llvm::concat<Value>(operands1, operands2));
85+
EXPECT_EQ(op->getNumOperands(), operands.size());
86+
for (unsigned idx : llvm::seq(0U, op->getNumOperands()))
87+
EXPECT_EQ(op->getOperand(idx), operands[idx]);
88+
89+
EXPECT_EQ(op->getAttrs().size(), attrs.size());
90+
if (op->getAttrs().size() != attrs.size()) {
91+
// Simple export where there is mismatch count.
92+
llvm::errs() << "Op attrs:\n";
93+
for (auto it : op->getAttrs())
94+
llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
95+
96+
llvm::errs() << "Expected attrs:\n";
97+
for (auto it : attrs)
98+
llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
99+
} else {
100+
for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size()))
101+
EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
102+
attrs[idx].getValue());
103+
}
104+
105+
EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
106+
concreteOp.erase();
92107
}
93108

94109
protected:
@@ -205,13 +220,31 @@ TEST_F(OpBuildGenTest,
205220
verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
206221
}
207222

208-
// The next test checks supression of ambiguous build methods for ops that
223+
// The next test checks suppression of ambiguous build methods for ops that
209224
// have a single variadic input, and single non-variadic result, and which
210-
// support the SameOperandsAndResultType trait and and optionally the
225+
// support the SameOperandsAndResultType trait and optionally the
211226
// InferOpTypeInterface interface. For such ops, the ODS framework generates
212227
// build methods with no result types as they are inferred from the input types.
213228
TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
214-
testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
229+
// Test separate arg, separate param build method.
230+
auto op = builder.create<test::TableGenBuildOp4>(
231+
loc, i32Ty, ValueRange{*cstI32, *cstI32});
232+
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
233+
234+
// Test collective params build method.
235+
op = builder.create<test::TableGenBuildOp4>(loc, TypeRange{i32Ty},
236+
ValueRange{*cstI32, *cstI32});
237+
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
238+
239+
// Test build method with no result types, default value of attributes.
240+
op =
241+
builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32});
242+
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
243+
244+
// Test build method with no result types and supplied attributes.
245+
op = builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32},
246+
attrs);
247+
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
215248
}
216249

217250
TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
@@ -221,4 +254,41 @@ TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
221254
verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);
222255
}
223256

257+
TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) {
258+
// Account for conversion as part of getAttrs().
259+
std::vector<NamedAttribute> noAttrsStorage;
260+
auto segmentSize = builder.getNamedAttr("operandSegmentSizes",
261+
builder.getDenseI32ArrayAttr({1, 1}));
262+
noAttrsStorage.push_back(segmentSize);
263+
ArrayRef<NamedAttribute> noAttrs(noAttrsStorage);
264+
std::vector<NamedAttribute> attrsStorage = this->attrStorage;
265+
attrsStorage.push_back(segmentSize);
266+
ArrayRef<NamedAttribute> attrs(attrsStorage);
267+
268+
// Test separate arg, separate param build method.
269+
auto op = builder.create<test::TableGenBuildOp6>(
270+
loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32});
271+
verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
272+
273+
// Test build method with no result types, default value of attributes.
274+
op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32},
275+
ValueRange{*cstI32});
276+
verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
277+
278+
// Test collective params build method.
279+
op = builder.create<test::TableGenBuildOp6>(
280+
loc, TypeRange{f32Ty}, ValueRange{*cstI32}, ValueRange{*cstI32});
281+
verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
282+
283+
// Test build method with result types, supplied attributes.
284+
op = builder.create<test::TableGenBuildOp6>(
285+
loc, TypeRange{f32Ty}, ValueRange{*cstI32, *cstI32}, attrs);
286+
verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
287+
288+
// Test build method with no result types and supplied attributes.
289+
op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32, *cstI32},
290+
attrs);
291+
verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
292+
}
293+
224294
} // namespace mlir

0 commit comments

Comments
 (0)