Skip to content

Commit 85cec06

Browse files
[mlir][Func] Support 1:N result type conversions in func.call conversion
1 parent 132de3a commit 85cec06

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,34 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2323
LogicalResult
2424
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
2525
ConversionPatternRewriter &rewriter) const override {
26-
// Convert the original function results.
26+
// Convert the original function results. Keep track of how many result
27+
// types an original result type is converted into.
28+
SmallVector<size_t> numResultsReplacments;
2729
SmallVector<Type, 1> convertedResults;
28-
if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
29-
convertedResults)))
30-
return failure();
31-
32-
// If this isn't a one-to-one type mapping, we don't know how to aggregate
33-
// the results.
34-
if (callOp->getNumResults() != convertedResults.size())
35-
return failure();
30+
size_t numFlattenedResults = 0;
31+
for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
32+
if (failed(typeConverter->convertTypes(type, convertedResults)))
33+
return failure();
34+
numResultsReplacments.push_back(convertedResults.size() -
35+
numFlattenedResults);
36+
numFlattenedResults = convertedResults.size();
37+
}
3638

3739
// Substitute with the new result types from the corresponding FuncType
3840
// conversion.
39-
rewriter.replaceOpWithNewOp<CallOp>(
40-
callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
41+
auto newCallOp =
42+
rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
43+
convertedResults, adaptor.getOperands());
44+
SmallVector<ValueRange> replacements;
45+
size_t offset = 0;
46+
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
47+
replacements.push_back(
48+
newCallOp->getResults().slice(offset, numResultsReplacments[i]));
49+
offset += numResultsReplacments[i];
50+
}
51+
assert(offset == convertedResults.size() &&
52+
"expected that all converted results are used");
53+
rewriter.replaceOpWithMultiple(callOp, replacements);
4154
return success();
4255
}
4356
};

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,15 +379,24 @@ builtin.module {
379379

380380
// -----
381381

382-
// expected-remark @below {{applyPartialConversion failed}}
383382
module {
384-
func.func private @callee(%0 : f32) -> f32
385-
386-
func.func @caller( %arg: f32) {
387-
// expected-error @below {{failed to legalize}}
388-
%1 = func.call @callee(%arg) : (f32) -> f32
389-
return
390-
}
383+
// CHECK-LABEL: func.func private @callee() -> (f16, f16)
384+
func.func private @callee() -> (f32, i24)
385+
386+
// CHECK: func.func @caller()
387+
func.func @caller() {
388+
// f32 is converted to (f16, f16).
389+
// i24 is converted to ().
390+
// CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
391+
%0:2 = func.call @callee() : () -> (f32, i24)
392+
393+
// CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
394+
// CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
395+
// CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
396+
// expected-remark @below{{'test.some_user' is not legalizable}}
397+
"test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
398+
"test.return"() : () -> ()
399+
}
391400
}
392401

393402
// -----

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,11 @@ struct TestTypeConverter : public TypeConverter {
12151215
return success();
12161216
}
12171217

1218+
// Drop I24 types.
1219+
if (t.isInteger(24)) {
1220+
return success();
1221+
}
1222+
12181223
// Otherwise, convert the type directly.
12191224
results.push_back(t);
12201225
return success();

0 commit comments

Comments
 (0)