Commit 63a2b0b
[mlir][tosa] Support unranked input/weight tensors for convolution ops (#134856)
This commit allows the convolution operators (conv2d, depthwise_conv2d, transpose_conv2d, and conv3d) to accept unranked input/weight operands. To support this, their TableGen operand definitions were relaxed; the relaxed tensor types can still be checked by the validation pass, should the user wish to run it.

Signed-off-by: Luke Hutton <[email protected]>
1 parent 9147569

File tree: 5 files changed (+86, -91 lines)
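As a quick illustration before the per-file diffs (a minimal sketch mirroring the new ops.mlir tests below; function and value names are illustrative), a conv2d whose input is an unranked tensor<*xf32> now passes the op verifier, with rank checking deferred to the optional validation pass:

// Illustrative IR, not verbatim from the patch: the unranked %input no
// longer trips the tosa.conv2d verifier after this change.
func.func @conv2d_unranked_sketch(%input: tensor<*xf32>, %weight: tensor<8x1x1x4xf32>, %bias: tensor<8xf32>, %izp: tensor<1xf32>, %wzp: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
  %0 = tosa.conv2d %input, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}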

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
@@ -125,7 +125,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -172,7 +172,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor5D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -218,7 +218,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -434,7 +434,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 5 deletions
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;
 
-// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
-def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
-
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//
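Dropping Tosa_Weight works because the weight operands now reuse Tosa_Tensor4D/Tosa_Tensor5D, which in this dialect also admit unranked tensors. A paraphrased sketch of that pre-existing pattern, stated as an assumption (consult TosaTypesBase.td for the authoritative definition):

// Sketch (assumed, not part of this diff): the rank-constrained tensor
// aliases also accept unranked tensors, which is what lets $weight be
// declared as Tosa_Tensor4D while still admitting tensor<*xf32>.
def Tosa_Tensor4D : AnyTypeOf<[
    Tosa_UnrankedTensor,
    TosaTensorRankOf<[Tosa_AnyNumber], [4]>]>;

Note that Tosa_Weight limited weight element types to int4/int8/quantized/float, while Tosa_AnyNumber is broader, so element-type legality is now reported by the verifier rather than the ODS predicate (see the updated invalid.mlir expectation below).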

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 58 additions & 69 deletions
@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
-  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-  if (!inputType) {
-    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
-    return failure();
-  }
-
-  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-  if (!weightType) {
-    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
-    return failure();
-  }
+  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
+  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
@@ -3063,14 +3052,6 @@ LogicalResult TransposeConv2DOp::verify() {
     return emitOpError("expect all stride values to be >= 1, got [")
            << strides << "]";
 
-  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-
-  const auto weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
   const auto checkPadAgainstKernelDim =
       [this](int64_t pad_value, int64_t kernel_dim_size,
              llvm::StringRef pad_name,
@@ -3084,69 +3065,77 @@ LogicalResult TransposeConv2DOp::verify() {
   };
 
   const llvm::ArrayRef<int64_t> padding = getOutPad();
-
   const int64_t outPadTop = padding[0];
   const int64_t outPadBottom = padding[1];
+  const int64_t outPadLeft = padding[2];
+  const int64_t outPadRight = padding[3];
 
-  const int64_t kernelHeight = weightType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(kernelHeight)) {
-    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
-                                        "KH")))
-      return failure();
-
-    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
-                                        "out_pad_bottom", "KH")))
-      return failure();
-  }
+  const auto weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
 
-  const int64_t kernelWidth = weightType.getDimSize(2);
+  if (weightType) {
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    if (!ShapedType::isDynamic(kernelHeight)) {
+      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
+                                          "out_pad_top", "KH")))
+        return failure();
 
-  const int64_t outPadLeft = padding[2];
-  const int64_t outPadRight = padding[3];
+      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
+                                          "out_pad_bottom", "KH")))
+        return failure();
+    }
 
-  if (!ShapedType::isDynamic(kernelWidth)) {
-    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
-                                        "KW")))
-      return failure();
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    if (!ShapedType::isDynamic(kernelWidth)) {
+      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
+        return failure();
 
-    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
-                                        "out_pad_right", "KW")))
-      return failure();
+      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
+                                          "out_pad_right", "KW")))
+        return failure();
+    }
   }
 
   // Rest of the checks depend on the output type being a RankedTensorType
+  const auto outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
   if (!outputType)
     return success();
 
-  const int64_t inputHeight = inputType.getDimSize(1);
-  const int64_t outputHeight = outputType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(outputHeight)) {
-    if (outputHeight !=
-        (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
-      return emitOpError(
-                 "dimension mismatch: expected OH == (IH - 1) * stride_y "
-                 "+ out_pad_top + out_pad_bottom + KH, but got ")
-             << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
-             << " + " << outPadTop << " + " << outPadBottom << " + "
-             << kernelHeight;
-  }
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (inputType && weightType) {
+    const int64_t inputHeight = inputType.getDimSize(1);
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    const int64_t outputHeight = outputType.getDimSize(1);
+
+    if (!ShapedType::isDynamic(inputHeight) &&
+        !ShapedType::isDynamic(outputHeight)) {
+      if (outputHeight !=
+          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
+        return emitOpError(
+                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
+                   "+ out_pad_top + out_pad_bottom + KH, but got ")
+               << outputHeight << " != (" << inputHeight << " - 1) * "
+               << strideY << " + " << outPadTop << " + " << outPadBottom
+               << " + " << kernelHeight;
+    }
 
-  const int64_t inputWidth = inputType.getDimSize(2);
-  const int64_t outputWidth = outputType.getDimSize(2);
+    const int64_t inputWidth = inputType.getDimSize(2);
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    const int64_t outputWidth = outputType.getDimSize(2);
 
-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(outputWidth)) {
-    if (outputWidth !=
-        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
-      return emitOpError(
-                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
-                 "+ out_pad_left + out_pad_right + KW, but got ")
-             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
-             << " + " << outPadLeft << " + " << outPadRight << " + "
-             << kernelWidth;
+    if (!ShapedType::isDynamic(inputWidth) &&
+        !ShapedType::isDynamic(outputWidth)) {
+      if (outputWidth !=
+          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
+        return emitOpError(
+                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
+                   "+ out_pad_left + out_pad_right + KW, but got ")
+               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
+               << " + " << outPadLeft << " + " << outPadRight << " + "
+               << kernelWidth;
+    }
   }
 
   const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
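The recurring idiom in this hunk: each shape check first dyn_casts the operand type to RankedTensorType and skips the check entirely when the cast fails, leaving unranked operands to the validation pass. A distilled sketch of the idiom using a hypothetical helper (not code from the patch):

// Hypothetical helper illustrating the guard pattern used above: only
// inspect a dimension when the tensor type is actually ranked.
#include "mlir/IR/BuiltinTypes.h"
#include <optional>

static std::optional<int64_t> getDimIfRanked(mlir::Type type, unsigned dim) {
  if (auto ranked = llvm::dyn_cast<mlir::RankedTensorType>(type))
    return ranked.getDimSize(dim); // may still be ShapedType::kDynamic
  return std::nullopt; // unranked: defer the check to the validation pass
}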

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 3 additions & 13 deletions
@@ -22,30 +22,20 @@ func.func @test_const_non_tensor_attr() {
 
 // -----
 
-func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
+           : (tensor<*xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 21 additions & 0 deletions
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: conv2d_unranked_input
+func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: conv2d_quant_uniform
 func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
   return %0 : tensor<1x32x32x16xf32>
 }
 
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_input
+func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_weight
+func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d_with_local_bound
 func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {