Commit 40a5eb5

[mlir][tosa] Support unranked input/weight tensors for convolution ops
This commit ensures that the convolution operators conv2d, depthwise_conv2d, transpose_conv2d, and conv3d can accept unranked input/weight operands. To support unranked tensor operands, the TableGen definitions were relaxed. The relaxed tensor types are later checked by the validation pass, should the user wish to run it.

Change-Id: I33334909e0d4d0676daae81bfc4647e86abc063a
Signed-off-by: Luke Hutton <[email protected]>
Parent: 4e073a1
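For illustration, IR along these lines is now accepted by the op verifier (a sketch mirroring the ops.mlir tests added in this commit; the function name is invented):

func.func @example_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
  // The unranked input passes op verification; strict rank checks are
  // deferred to the optional validation pass.
  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}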

5 files changed: +86 −91 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
@@ -125,7 +125,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -172,7 +172,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor5D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -218,7 +218,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -434,7 +434,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
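Why swapping in Tosa_Tensor4D/Tosa_Tensor5D admits unranked weights: those constraints accept unranked tensors as well as tensors of the stated rank, which the unranked test cases added in ops.mlir below rely on. A paraphrased sketch of that constraint (not part of this diff; the exact definition in TosaTypesBase.td may differ in detail):

// Paraphrased sketch: an operand typed Tosa_Tensor4D may be unranked
// (tensor<*x...>) or ranked with rank exactly 4.
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor,
                               TosaTensorRankOf<[Tosa_AnyNumber], [4]>]>;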

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 5 deletions
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;
 
-// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
-def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
-
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 58 additions & 69 deletions
@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
-  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-  if (!inputType) {
-    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
-    return failure();
-  }
-
-  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-  if (!weightType) {
-    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
-    return failure();
-  }
+  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
+  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
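Why the relaxed casts above are safe: llvm::dyn_cast<TensorType> succeeds for ranked and unranked tensor operands alike, so the element-type checks that follow still run, while rank-dependent checks must now be guarded by a RankedTensorType cast, as the TransposeConv2DOp::verify changes below do. A minimal standalone sketch of this distinction, with a hypothetical helper name, not code from this patch:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Hypothetical helper illustrating the cast behavior: any tensor value,
// ranked or unranked, is a TensorType, so its element type is always
// queryable; only ranked tensors permit rank/shape checks.
static bool canCheckShapes(mlir::Value v) {
  auto tensorTy = llvm::dyn_cast<mlir::TensorType>(v.getType());
  if (!tensorTy)
    return false;                    // not a tensor operand at all
  (void)tensorTy.getElementType();   // valid even for tensor<*xf32>
  return llvm::isa<mlir::RankedTensorType>(v.getType());
}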
@@ -2998,14 +2987,6 @@ LogicalResult TransposeConv2DOp::verify() {
     return emitOpError("expect all stride values to be >= 1, got [")
            << strides << "]";
 
-  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-
-  const auto weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
   const auto checkPadAgainstKernelDim =
       [this](int64_t pad_value, int64_t kernel_dim_size,
              llvm::StringRef pad_name,
@@ -3019,69 +3000,77 @@ LogicalResult TransposeConv2DOp::verify() {
       };
 
   const llvm::ArrayRef<int64_t> padding = getOutPad();
-
   const int64_t outPadTop = padding[0];
   const int64_t outPadBottom = padding[1];
+  const int64_t outPadLeft = padding[2];
+  const int64_t outPadRight = padding[3];
 
-  const int64_t kernelHeight = weightType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(kernelHeight)) {
-    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
-                                        "KH")))
-      return failure();
-
-    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
-                                        "out_pad_bottom", "KH")))
-      return failure();
-  }
+  const auto weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
 
-  const int64_t kernelWidth = weightType.getDimSize(2);
+  if (weightType) {
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    if (!ShapedType::isDynamic(kernelHeight)) {
+      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
+                                          "out_pad_top", "KH")))
+        return failure();
 
-  const int64_t outPadLeft = padding[2];
-  const int64_t outPadRight = padding[3];
+      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
+                                          "out_pad_bottom", "KH")))
+        return failure();
+    }
 
-  if (!ShapedType::isDynamic(kernelWidth)) {
-    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
-                                        "KW")))
-      return failure();
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    if (!ShapedType::isDynamic(kernelWidth)) {
+      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
+                                          "out_pad_left", "KW")))
+        return failure();
 
-    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
-                                        "out_pad_right", "KW")))
-      return failure();
+      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
+                                          "out_pad_right", "KW")))
+        return failure();
+    }
   }
 
   // Rest of the checks depend on the output type being a RankedTensorType
+  const auto outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
   if (!outputType)
     return success();
 
-  const int64_t inputHeight = inputType.getDimSize(1);
-  const int64_t outputHeight = outputType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(outputHeight)) {
-    if (outputHeight !=
-        (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
-      return emitOpError(
-                 "dimension mismatch: expected OH == (IH - 1) * stride_y "
-                 "+ out_pad_top + out_pad_bottom + KH, but got ")
-             << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
-             << " + " << outPadTop << " + " << outPadBottom << " + "
-             << kernelHeight;
-  }
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (inputType && weightType) {
+    const int64_t inputHeight = inputType.getDimSize(1);
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    const int64_t outputHeight = outputType.getDimSize(1);
+
+    if (!ShapedType::isDynamic(inputHeight) &&
+        !ShapedType::isDynamic(outputHeight)) {
+      if (outputHeight !=
+          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
+        return emitOpError(
+                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
+                   "+ out_pad_top + out_pad_bottom + KH, but got ")
+               << outputHeight << " != (" << inputHeight << " - 1) * "
+               << strideY << " + " << outPadTop << " + " << outPadBottom
+               << " + " << kernelHeight;
+    }
 
-  const int64_t inputWidth = inputType.getDimSize(2);
-  const int64_t outputWidth = outputType.getDimSize(2);
+    const int64_t inputWidth = inputType.getDimSize(2);
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    const int64_t outputWidth = outputType.getDimSize(2);
 
-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(outputWidth)) {
-    if (outputWidth !=
-        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
-      return emitOpError(
-                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
-                 "+ out_pad_left + out_pad_right + KW, but got ")
-             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
-             << " + " << outPadLeft << " + " << outPadRight << " + "
-             << kernelWidth;
+    if (!ShapedType::isDynamic(inputWidth) &&
+        !ShapedType::isDynamic(outputWidth)) {
+      if (outputWidth !=
+          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
+        return emitOpError(
+                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
+                   "+ out_pad_left + out_pad_right + KW, but got ")
+               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
+               << " + " << outPadLeft << " + " << outPadRight << " + "
+               << kernelWidth;
+    }
   }
 
   const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 3 additions & 13 deletions
@@ -22,30 +22,20 @@ func.func @test_const_non_tensor_attr() {
 
 // -----
 
-func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
+           : (tensor<*xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
           : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 21 additions & 0 deletions
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: conv2d_unranked_input
+func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: conv2d_quant_uniform
 func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
   return %0 : tensor<1x32x32x16xf32>
 }
 
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_input
+func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_weight
+func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d_with_local_bound
 func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {