[mlir][tosa] Support unranked input/weight tensors for convolution ops #134856

Merged · 1 commit · Apr 25, 2025
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -125,7 +125,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor4D:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -172,7 +172,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {

let arguments = (ins
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor5D:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -218,7 +218,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor4D:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -434,7 +434,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor4D:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
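The change above moves the weight operand off the rank-only `TosaTensorRankOf<[Tosa_Weight], [4]>` constraint (and its 5-D counterpart for conv3d) onto the same `Tosa_Tensor4D`/`Tosa_Tensor5D` constraints already used for the input, which, per this patch and the new tests below, also admit unranked tensors. As a minimal sketch (not part of the patch, helper name hypothetical), this is how ranked and unranked tensor operands are typically told apart in MLIR C++ code — the pattern the verifier changes below rely on:

```cpp
// Illustrative only: tensor<16x3x3x4xi8> is a RankedTensorType, tensor<*xf32>
// is an UnrankedTensorType. Both derive from TensorType, so element-type
// queries work either way, while rank/shape queries must be guarded.
#include "mlir/IR/BuiltinTypes.h"

static bool isRank4OrUnrankedTensor(mlir::Type type) {
  auto tensorTy = llvm::dyn_cast<mlir::TensorType>(type);
  if (!tensorTy)
    return false; // not a tensor at all
  mlir::Type elemTy = tensorTy.getElementType(); // valid for both forms
  (void)elemTy;
  if (auto ranked = llvm::dyn_cast<mlir::RankedTensorType>(tensorTy))
    return ranked.getRank() == 4; // shape information is available
  return true; // unranked: rank unknown, shape checks are deferred
}
```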
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;

//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
//===----------------------------------------------------------------------===//
127 changes: 58 additions & 69 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,

template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input and weight arguments which must be ranked
// tensors.
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
if (!inputType) {
op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
return failure();
}

auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (!weightType) {
op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
}
const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
@@ -3063,14 +3052,6 @@ LogicalResult TransposeConv2DOp::verify() {
return emitOpError("expect all stride values to be >= 1, got [")
<< strides << "]";

const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());

const auto outputType =
llvm::dyn_cast<RankedTensorType>(getOutput().getType());

const auto weightType =
llvm::dyn_cast<RankedTensorType>(getWeight().getType());

const auto checkPadAgainstKernelDim =
[this](int64_t pad_value, int64_t kernel_dim_size,
llvm::StringRef pad_name,
@@ -3084,69 +3065,77 @@ LogicalResult TransposeConv2DOp::verify() {
};

const llvm::ArrayRef<int64_t> padding = getOutPad();

const int64_t outPadTop = padding[0];
const int64_t outPadBottom = padding[1];
const int64_t outPadLeft = padding[2];
const int64_t outPadRight = padding[3];

const int64_t kernelHeight = weightType.getDimSize(1);

if (!ShapedType::isDynamic(kernelHeight)) {
if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
"KH")))
return failure();

if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
"out_pad_bottom", "KH")))
return failure();
}
const auto weightType =
llvm::dyn_cast<RankedTensorType>(getWeight().getType());

const int64_t kernelWidth = weightType.getDimSize(2);
if (weightType) {
const int64_t kernelHeight = weightType.getDimSize(1);
if (!ShapedType::isDynamic(kernelHeight)) {
if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
"out_pad_top", "KH")))
return failure();

const int64_t outPadLeft = padding[2];
const int64_t outPadRight = padding[3];
if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
"out_pad_bottom", "KH")))
return failure();
}

if (!ShapedType::isDynamic(kernelWidth)) {
if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
"KW")))
return failure();
const int64_t kernelWidth = weightType.getDimSize(2);
if (!ShapedType::isDynamic(kernelWidth)) {
if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
"out_pad_left", "KW")))
return failure();

if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
"out_pad_right", "KW")))
return failure();
if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
"out_pad_right", "KW")))
return failure();
}
}

// Rest of the checks depend on the output type being a RankedTensorType
const auto outputType =
llvm::dyn_cast<RankedTensorType>(getOutput().getType());
if (!outputType)
return success();

const int64_t inputHeight = inputType.getDimSize(1);
const int64_t outputHeight = outputType.getDimSize(1);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(outputHeight)) {
if (outputHeight !=
(inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
return emitOpError(
"dimension mismatch: expected OH == (IH - 1) * stride_y "
"+ out_pad_top + out_pad_bottom + KH, but got ")
<< outputHeight << " != (" << inputHeight << " - 1) * " << strideY
<< " + " << outPadTop << " + " << outPadBottom << " + "
<< kernelHeight;
}
const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
if (inputType && weightType) {
const int64_t inputHeight = inputType.getDimSize(1);
const int64_t kernelHeight = weightType.getDimSize(1);
const int64_t outputHeight = outputType.getDimSize(1);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(outputHeight)) {
if (outputHeight !=
(inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
return emitOpError(
"dimension mismatch: expected OH == (IH - 1) * stride_y "
"+ out_pad_top + out_pad_bottom + KH, but got ")
<< outputHeight << " != (" << inputHeight << " - 1) * "
<< strideY << " + " << outPadTop << " + " << outPadBottom
<< " + " << kernelHeight;
}

const int64_t inputWidth = inputType.getDimSize(2);
const int64_t outputWidth = outputType.getDimSize(2);
const int64_t inputWidth = inputType.getDimSize(2);
const int64_t kernelWidth = weightType.getDimSize(2);
const int64_t outputWidth = outputType.getDimSize(2);

if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(outputWidth)) {
if (outputWidth !=
(inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
return emitOpError(
"dimension mismatch: expected OW == (IW - 1) * stride_x "
"+ out_pad_left + out_pad_right + KW, but got ")
<< outputWidth << " != (" << inputWidth << " - 1) * " << strideX
<< " + " << outPadLeft << " + " << outPadRight << " + "
<< kernelWidth;
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(outputWidth)) {
if (outputWidth !=
(inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
return emitOpError(
"dimension mismatch: expected OW == (IW - 1) * stride_x "
"+ out_pad_left + out_pad_right + KW, but got ")
<< outputWidth << " != (" << inputWidth << " - 1) * " << strideX
<< " + " << outPadLeft << " + " << outPadRight << " + "
<< kernelWidth;
}
}

const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
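For reference, a worked instance of the output-size relations the verifier enforces above, using the shapes from the existing `test_transpose_conv2d` case in ops.mlir (1x32x32x8 input, 16x1x1x8 weight, unit strides, zero `out_pad`); when the input, weight, or output is unranked, or the relevant dimensions are dynamic, these checks are simply skipped:

```latex
\begin{aligned}
OH &= (IH - 1)\cdot \text{stride}_y + \text{out\_pad\_top} + \text{out\_pad\_bottom} + KH
    = (32 - 1)\cdot 1 + 0 + 0 + 1 = 32,\\
OW &= (IW - 1)\cdot \text{stride}_x + \text{out\_pad\_left} + \text{out\_pad\_right} + KW
    = (32 - 1)\cdot 1 + 0 + 0 + 1 = 32.
\end{aligned}
```

Both results match the declared output type `tensor<1x32x32x16xf32>` in that test.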
16 changes: 3 additions & 13 deletions mlir/test/Dialect/Tosa/invalid.mlir
@@ -22,30 +22,20 @@ func.func @test_const_non_tensor_attr() {

// -----

func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}

// -----

func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
: (tensor<*xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}

// -----

func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
// expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
return %0 : tensor<1x4x4x8xf32>
}

// -----
// CHECK-LABEL: conv2d_unranked_input
func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}

// -----
// CHECK-LABEL: conv2d_quant_uniform
func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
return %0 : tensor<1x32x32x16xf32>
}

// -----
// CHECK-LABEL: transpose_conv2d_unranked_input
func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}

// -----
// CHECK-LABEL: transpose_conv2d_unranked_weight
func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}

// -----
// CHECK-LABEL: transpose_conv2d_with_local_bound
func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {