Commit eaad883

Re-enable torch-adjust-calling-conventions tests (#4034)
This PR updates AdjustCallingConventionsPass for the dialect conversion framework API changes introduced in llvm/llvm-project#116470. This may not be an optimal use of the new API, but it is functional; suggestions are welcome.

Fixes #3983
1 parent caaeb21 commit eaad883
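
For context, the relevant API change is that conversion patterns can now observe a 1:N mapping from each original operand to its converted values: matchAndRewrite receives a OneToNOpAdaptor, whose getOperands() yields one range of replacement values per original operand instead of a single Value. The sketch below is not part of this commit; it only illustrates that shape with a hypothetical variadic MyOp (the adaptor typedef and the rewriter calls are the upstream MLIR dialect conversion API).

// Sketch only: a 1:N conversion pattern in the new style. `MyOp` is a
// hypothetical op with variadic operands and no results; everything else
// follows the upstream dialect conversion API.
class ExampleOneToNPattern : public OpConversionPattern<MyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MyOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> flattened;
    // Each original operand now maps to zero or more converted values.
    for (ValueRange replacements : adaptor.getOperands())
      llvm::append_range(flattened, replacements);
    // Recreate the op over the flattened operand list.
    rewriter.replaceOpWithNewOp<MyOp>(op, flattened);
    return success();
  }
};

Under the old 1:1 OpAdaptor, an operand dropped by the type converter showed up as a null Value (hence the old `if (!operand)` check in the diff below); under the 1:N adaptor it is simply an empty range, and a single tuple-typed operand can expand into several values, which is the case the updated return pattern handles.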

2 files changed: +106 -81 lines changed

lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp

Lines changed: 40 additions & 16 deletions
@@ -164,27 +164,51 @@ class AdjustCallingConventionForReturn
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     SmallVector<Value> newOperands;
-    for (auto operand : adaptor.getOperands()) {
-      if (!operand)
-        continue;
-      if (isa<Torch::NoneType>(operand.getType()))
-        continue;
-      if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
-        Location loc = op.getLoc();
-        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
-          auto i = rewriter.create<ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(en.index()));
-          newOperands.push_back(
-              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
+    for (const auto &vals : adaptor.getOperands()) {
+      if (vals.size() == 1) {
+        if (isa<Torch::NoneType>(vals[0].getType()))
+          continue;
+        newOperands.push_back(vals[0]);
+      } else if (vals.size() > 1) {
+        // The dialect conversion framework inserts unrealized conversion casts
+        // to materialize legal types from illegal types. For example, for input
+        // IR like
+        //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+        //       torch.tensor -> !torch.tuple<tensor, tensor>
+        //   return %1 : !torch.tuple<tensor, tensor>
+        // at this stage in the conversion process we'll have something like
+        //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+        //       !torch.tensor -> !torch.tuple<tensor, tensor>
+        //   %2 = builtin.unrealized_conversion_cast %1 :
+        //       !torch.tuple<tensor, tensor> to !torch.tensor
+        //   %3 = builtin.unrealized_conversion_cast %1 :
+        //       !torch.tuple<tensor, tensor> to !torch.tensor
+        //   return %2, %3 : !torch.tensor, !torch.tensor
+        //
+        // Given (%2, %3) as operands, here we map back to the original
+        // torch.prim.TupleConstruct.
+        if (vals[0].getDefiningOp() &&
+            isa<mlir::UnrealizedConversionCastOp>(vals[0].getDefiningOp())) {
+          Value operand = vals[0].getDefiningOp()->getOperand(0);
+          if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
+            Location loc = op.getLoc();
+            for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+              auto i = rewriter.create<ConstantIntOp>(
+                  loc, rewriter.getI64IntegerAttr(en.index()));
+              newOperands.push_back(rewriter.create<PrimTupleIndexOp>(
+                  loc, en.value(), operand, i));
+            }
+            continue;
+          }
         }
-        continue;
+
+        llvm::append_range(newOperands, vals);
       }
-      newOperands.push_back(operand);
     }
+
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
     return success();
   }
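
The multi-value case above only arises because the pass's type converter maps one illegal type to several legal ones; until the producing ops are rewritten, the framework bridges the type mismatch with builtin.unrealized_conversion_cast, which is what the comment in the pattern is unpacking. The converter registration itself is not shown in this diff; the following is only a rough sketch (an assumption about shape, not the pass's actual code) of how a 1:N rule expanding a torch tuple into its contained types can be registered with upstream MLIR's TypeConverter.

// Sketch only: not taken from AdjustCallingConventions.cpp. Shows the
// TypeConverter hook that produces a 1:N type mapping, which in turn makes
// OneToNOpAdaptor hand patterns several values for one original operand.
TypeConverter typeConverter;
// Leave already-legal types unchanged.
typeConverter.addConversion([](Type type) { return type; });
// Expand a tuple into its contained types (one result type per element).
typeConverter.addConversion(
    [](Torch::TupleType type,
       SmallVectorImpl<Type> &expandedTypes) -> std::optional<LogicalResult> {
      llvm::append_range(expandedTypes, type.getContainedTypes());
      return success();
    });

With a rule of this kind in place, a !torch.tuple<tensor, tensor> return value becomes two !torch.tensor values, which is the shape the re-enabled @tuple_return and @call_tuple_return tests below check for.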

test/Dialect/Torch/adjust-calling-conventions.mlir

Lines changed: 66 additions & 65 deletions
@@ -9,13 +9,17 @@ func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?]
   return %arg0 : !torch.tensor
 }

+// -----
+
 // CHECK-LABEL: func.func @no_type_bound(
 // CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: return %[[ARG]] : !torch.tensor
 func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
   return %arg0 : !torch.tensor
 }

+// -----
+
 // CHECK-LABEL: func.func @call(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
@@ -29,71 +33,68 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],
   return %arg0 : !torch.tensor
 }

-// COM: func.func @none_return() {
-// COM: %[[NONE:.*]] = torch.constant.none
-// COM: return
-// func.func @none_return() -> !torch.none {
-//   %1 = torch.constant.none
-//   return %1 : !torch.none
-// }
+// -----
+
+// CHECK-LABEL: func.func @none_return() {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: return
+func.func @none_return() -> !torch.none {
+  %1 = torch.constant.none
+  return %1 : !torch.none
+}
+
+// CHECK-LABEL: func.func @none_call_return() {
+// CHECK: call @none_return() : () -> ()
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> ()
+// CHECK: return
+func.func @none_call_return() {
+  %0 = call @none_return() : () -> !torch.none
+  "test.use"(%0) : (!torch.none) -> ()
+  return
+}

-// COM: func.func @none_call_return() {
-// COM: call @none_return() : () -> ()
-// COM: %[[NONE:.*]] = torch.constant.none
-// COM: "test.use"(%[[NONE]]) : (!torch.none) -> ()
-// COM: return
-// func.func @none_call_return() {
-//   %0 = call @none_return() : () -> !torch.none
-//   "test.use"(%0) : (!torch.none) -> ()
-//   return
-// }
+// -----

-// COM: func.func @tuple_return(
-// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
-// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
-// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
-// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
-// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
-// COM:     !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-// COM: %[[CST0:.*]] = torch.constant.int 0
-// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: %[[CST1:.*]] = torch.constant.int 1
-// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
-// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
-//                         %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
-//   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-//   return %1 : !torch.tuple<tensor, tensor>
-// }
+// CHECK-LABEL: func.func @tuple_return(
+// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
+  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+  return %1 : !torch.tuple<tensor, tensor>
+}

-// COM: func.func @call_tuple_return(
-// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
-// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
-// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
-// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
-// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
-// COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
-// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
-// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
-// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
-// COM:     (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
-// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
-// COM:     !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-// COM: %[[CST0:.*]] = torch.constant.int 0
-// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: %[[CST1:.*]] = torch.constant.int 1
-// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
-// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
-//                              %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
-//   %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
-//   return %0 : !torch.tuple<tensor, tensor>
-// }
+// CHECK-LABEL: func.func @call_tuple_return(
+// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                             %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
+  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
+  return %0 : !torch.tuple<tensor, tensor>
+}
