Commit 26cb5f8

[mlir][tosa] Support unranked input/weight tensors for convolution ops
This commit ensures that the convolution operators conv2d, conv3d, depthwise_conv2d, and transpose_conv2d can take unranked input/weight operands. To support unranked tensor operands, the TableGen definitions were relaxed. The relaxed tensor types are later checked by the validation pass, should the user wish to run it.

Change-Id: I33334909e0d4d0676daae81bfc4647e86abc063a
Signed-off-by: Luke Hutton <[email protected]>
1 parent 728320f commit 26cb5f8
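
For illustration, this is the kind of IR the relaxation admits (a conv2d with an unranked input, taken verbatim from the new test added to ops.mlir below):

```mlir
func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}
```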

5 files changed: +86, -80 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
```diff
@@ -124,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -169,7 +169,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {

   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor5D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -215,7 +215,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -429,7 +429,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
```
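
Unlike the TosaTensorRankOf<[Tosa_Weight], [rank]> constraint it replaces, Tosa_Tensor4D (and Tosa_Tensor5D) also admits an unranked tensor, which is what relaxes the weight operand here; the input operand already used this constraint, so unranked inputs only needed the verifier change below. A minimal sketch of IR that should now pass ODS verification, assuming f32 operands (the function name is illustrative, modeled on the tests added below):

```mlir
// Unranked weight: rank/shape checks are deferred until the weight is ranked.
func.func @conv2d_unranked_weight(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<*xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}
```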

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 5 deletions
```diff
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;

-// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
-def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
-
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//
```

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 60 additions & 60 deletions
```diff
@@ -282,14 +282,14 @@ static LogicalResult verifyConvOp(T op) {
   // tensors.
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
   if (!inputType) {
-    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
-    return failure();
+    // Skip following checks if input is not ranked
+    return success();
   }

   auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
   if (!weightType) {
-    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
-    return failure();
+    // Skip following checks if weight is not ranked
+    return success();
   }

   auto inputEType = inputType.getElementType();
@@ -2899,14 +2899,6 @@ LogicalResult TransposeConv2DOp::verify() {
     return emitOpError("expect all stride values to be >= 1, got [")
            << strides << "]";

-  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-
-  const auto weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
   const auto checkPadAgainstKernelDim =
       [this](int64_t pad_value, int64_t kernel_dim_size,
              llvm::StringRef pad_name,
@@ -2920,69 +2912,77 @@ LogicalResult TransposeConv2DOp::verify() {
   };

   const llvm::ArrayRef<int64_t> padding = getOutPad();
-
   const int64_t outPadTop = padding[0];
   const int64_t outPadBottom = padding[1];
+  const int64_t outPadLeft = padding[2];
+  const int64_t outPadRight = padding[3];

-  const int64_t kernelHeight = weightType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(kernelHeight)) {
-    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
-                                        "KH")))
-      return failure();
-
-    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
-                                        "out_pad_bottom", "KH")))
-      return failure();
-  }
+  const auto weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

-  const int64_t kernelWidth = weightType.getDimSize(2);
+  if (weightType) {
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    if (!ShapedType::isDynamic(kernelHeight)) {
+      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
+                                          "out_pad_top", "KH")))
+        return failure();

-  const int64_t outPadLeft = padding[2];
-  const int64_t outPadRight = padding[3];
+      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
+                                          "out_pad_bottom", "KH")))
+        return failure();
+    }

-  if (!ShapedType::isDynamic(kernelWidth)) {
-    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
-                                        "KW")))
-      return failure();
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    if (!ShapedType::isDynamic(kernelWidth)) {
+      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
+                                          "out_pad_left", "KW")))
+        return failure();

-    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
-                                        "out_pad_right", "KW")))
-      return failure();
+      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
+                                          "out_pad_right", "KW")))
+        return failure();
+    }
   }

   // Rest of the checks depend on the output type being a RankedTensorType
+  const auto outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
   if (!outputType)
     return success();

-  const int64_t inputHeight = inputType.getDimSize(1);
-  const int64_t outputHeight = outputType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(outputHeight)) {
-    if (outputHeight !=
-        (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
-      return emitOpError(
-                 "dimension mismatch: expected OH == (IH - 1) * stride_y "
-                 "+ out_pad_top + out_pad_bottom + KH, but got ")
-             << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
-             << " + " << outPadTop << " + " << outPadBottom << " + "
-             << kernelHeight;
-  }
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (inputType && weightType) {
+    const int64_t inputHeight = inputType.getDimSize(1);
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    const int64_t outputHeight = outputType.getDimSize(1);
+
+    if (!ShapedType::isDynamic(inputHeight) &&
+        !ShapedType::isDynamic(outputHeight)) {
+      if (outputHeight !=
+          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
+        return emitOpError(
+                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
+                   "+ out_pad_top + out_pad_bottom + KH, but got ")
+               << outputHeight << " != (" << inputHeight << " - 1) * "
+               << strideY << " + " << outPadTop << " + " << outPadBottom
+               << " + " << kernelHeight;
+    }

-  const int64_t inputWidth = inputType.getDimSize(2);
-  const int64_t outputWidth = outputType.getDimSize(2);
+    const int64_t inputWidth = inputType.getDimSize(2);
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    const int64_t outputWidth = outputType.getDimSize(2);

-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(outputWidth)) {
-    if (outputWidth !=
-        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
-      return emitOpError(
-                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
-                 "+ out_pad_left + out_pad_right + KW, but got ")
-             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
-             << " + " << outPadLeft << " + " << outPadRight << " + "
-             << kernelWidth;
+    if (!ShapedType::isDynamic(inputWidth) &&
+        !ShapedType::isDynamic(outputWidth)) {
+      if (outputWidth !=
+          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
+        return emitOpError(
+                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
+                   "+ out_pad_left + out_pad_right + KW, but got ")
+               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
+               << " + " << outPadLeft << " + " << outPadRight << " + "
+               << kernelWidth;
+    }
   }

   const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
```
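
With the verifier restructured this way, the pad-vs-kernel and output-size checks only run when the corresponding operand types are ranked; fully ranked ops are still checked as before. A sketch of IR that should still be rejected (hypothetical function name and shapes, chosen so that OH = 32 violates OH == (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH = (32 - 1) * 1 + 0 + 0 + 2 = 33):

```mlir
// Invalid sketch: all operands ranked, so the OH check still fires and
// the verifier reports the dimension mismatch (32 != 33).
func.func @transpose_conv2d_bad_oh(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
  return %0 : tensor<1x32x32x16xf32>
}
```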

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 11 deletions
```diff
@@ -33,19 +33,9 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>,

 // -----

-func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
```

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 21 additions & 0 deletions
```diff
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }

+// -----
+// CHECK-LABEL: conv2d_unranked_input
+func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: conv2d_quant_uniform
 func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
   return %0 : tensor<1x32x32x16xf32>
 }

+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_input
+func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_weight
+func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d_with_local_bound
 func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
```
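
The same relaxation covers depthwise_conv2d, for which this commit adds no round-trip test. A minimal sketch of IR that should now verify, modeled on the existing test_depthwise_conv2d in ops.mlir (the unranked input is the only change; the function name is illustrative):

```mlir
// Depthwise convolution with an unranked input; shape checks are deferred
// until the input is ranked.
func.func @depthwise_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}
```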
