Skip to content

Commit 412e30b

Browse files
[mlir][Transforms] Dialect Conversion: Add 1:N op replacement test case (#121271)
This commit adds a test case that performs two back-to-back 1:N replacements: `(i16) -> (i16, i16) -> ((i16, i16), (i16, i16))`. For the moment, 3 argument materializations are inserted. In the future (when the conversion value mapping supports 1:N), a single target materialization will be inserted. Addresses a [comment](#116524 (comment)) in #116524.
1 parent e45e091 commit 412e30b

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ func.func @fold_legalization() -> i32 {
450450
// -----
451451

452452
// CHECK-LABEL: func @convert_detached_signature()
453-
// CHECK: "test.legal_op_with_region"() ({
453+
// CHECK: "test.legal_op"() ({
454454
// CHECK: ^bb0(%arg0: f64):
455455
// CHECK: "test.return"() : () -> ()
456456
// CHECK: }) : () -> ()
@@ -483,3 +483,25 @@ func.func @test_1_to_n_block_signature_conversion() {
483483
"test.return"() : () -> ()
484484
}
485485

486+
// -----
487+
488+
// CHECK: notifyOperationInserted: test.step_1
489+
// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement
490+
// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement
491+
// CHECK: notifyOperationInserted: test.legal_op
492+
// CHECK: notifyOperationReplaced: test.step_1
493+
// CHECK: notifyOperationErased: test.step_1
494+
495+
// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
496+
// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
497+
// TODO: There should be a single cast (i.e., a single target materialization).
498+
// This is currently not possible due to 1:N limitations of the conversion
499+
// mapping. Instead, we have 3 argument materializations.
500+
// CHECK: %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
501+
// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
502+
// CHECK: %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
503+
// CHECK: "test.valid"(%[[cast3]]) : (f16) -> ()
504+
func.func @test_multiple_1_to_n_replacement() {
505+
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
506+
"test.invalid"(%0) : (f16) -> ()
507+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
785785
ConversionPatternRewriter &rewriter) const final {
786786
if (op->getNumRegions() != 1)
787787
return failure();
788-
OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
788+
OperationState state(op->getLoc(), "test.legal_op", operands,
789789
op->getResultTypes(), {}, BlockRange());
790790
Region *newRegion = state.addRegion();
791791
rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,49 @@ class TestRepetitive1ToNConsumer : public ConversionPattern {
12341234
}
12351235
};
12361236

1237+
/// A pattern that tests two back-to-back 1 -> 2 op replacements.
1238+
class TestMultiple1ToNReplacement : public ConversionPattern {
1239+
public:
1240+
TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1241+
: ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1242+
ctx) {}
1243+
LogicalResult
1244+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1245+
ConversionPatternRewriter &rewriter) const final {
1246+
// Helper function that replaces the given op with a new op of the given
1247+
// name and doubles each result (1 -> 2 replacement of each result).
1248+
auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1249+
SmallVector<Type> types;
1250+
for (Type t : op->getResultTypes()) {
1251+
types.push_back(t);
1252+
types.push_back(t);
1253+
}
1254+
OperationState state(op->getLoc(), name,
1255+
/*operands=*/{}, types, op->getAttrs());
1256+
auto *newOp = rewriter.create(state);
1257+
SmallVector<ValueRange> repls;
1258+
for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1259+
repls.push_back(newOp->getResults().slice(2 * i, 2));
1260+
rewriter.replaceOpWithMultiple(op, repls);
1261+
return newOp;
1262+
};
1263+
1264+
// Replace test.multiple_1_to_n_replacement with test.step_1.
1265+
Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1266+
// Now replace test.step_1 with test.legal_op.
1267+
// TODO: Ideally, it should not be necessary to reset the insertion point
1268+
// here. Based on the API calls, it looks like test.step_1 is entirely
1269+
// erased. But that's not the case: an argument materialization will
1270+
// survive. And that argument materialization will be used by the users of
1271+
// `op`. If we don't reset the insertion point here, we get dominance
1272+
// errors. This will be fixed when we have 1:N support in the conversion
1273+
// value mapping.
1274+
rewriter.setInsertionPoint(repl1);
1275+
replaceWithDoubleResults(repl1, "test.legal_op");
1276+
return success();
1277+
}
1278+
};
1279+
12371280
} // namespace
12381281

12391282
namespace {
@@ -1319,7 +1362,8 @@ struct TestLegalizePatternDriver
13191362
TestUndoPropertiesModification, TestEraseOp,
13201363
TestRepetitive1ToNConsumer>(&getContext());
13211364
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1322-
TestPassthroughInvalidOp>(&getContext(), converter);
1365+
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1366+
&getContext(), converter);
13231367
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
13241368
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
13251369
converter);
@@ -1330,8 +1374,7 @@ struct TestLegalizePatternDriver
13301374
target.addLegalOp<ModuleOp>();
13311375
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
13321376
TerminatorOp, OneRegionOp>();
1333-
target.addLegalOp(
1334-
OperationName("test.legal_op_with_region", &getContext()));
1377+
target.addLegalOp(OperationName("test.legal_op", &getContext()));
13351378
target
13361379
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
13371380
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {

0 commit comments

Comments
 (0)