@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
278
278
279
279
template <typename T>
280
280
static LogicalResult verifyConvOp (T op) {
281
- // All TOSA conv ops have an input and weight arguments which must be ranked
282
- // tensors.
283
- auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
284
- if (!inputType) {
285
- op.emitOpError (" expect a ranked tensor for input, got " ) << op.getInput ();
286
- return failure ();
287
- }
288
-
289
- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
290
- if (!weightType) {
291
- op.emitOpError (" expect a ranked tensor for weight, got " ) << op.getWeight ();
292
- return failure ();
293
- }
281
+ const auto inputType = llvm::dyn_cast<TensorType>(op.getInput ().getType ());
282
+ const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight ().getType ());
294
283
295
284
auto inputEType = inputType.getElementType ();
296
285
auto weightEType = weightType.getElementType ();
@@ -3063,14 +3052,6 @@ LogicalResult TransposeConv2DOp::verify() {
3063
3052
return emitOpError (" expect all stride values to be >= 1, got [" )
3064
3053
<< strides << " ]" ;
3065
3054
3066
- const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3067
-
3068
- const auto outputType =
3069
- llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
3070
-
3071
- const auto weightType =
3072
- llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
3073
-
3074
3055
const auto checkPadAgainstKernelDim =
3075
3056
[this ](int64_t pad_value, int64_t kernel_dim_size,
3076
3057
llvm::StringRef pad_name,
@@ -3084,69 +3065,77 @@ LogicalResult TransposeConv2DOp::verify() {
3084
3065
};
3085
3066
3086
3067
const llvm::ArrayRef<int64_t > padding = getOutPad ();
3087
-
3088
3068
const int64_t outPadTop = padding[0 ];
3089
3069
const int64_t outPadBottom = padding[1 ];
3070
+ const int64_t outPadLeft = padding[2 ];
3071
+ const int64_t outPadRight = padding[3 ];
3090
3072
3091
- const int64_t kernelHeight = weightType.getDimSize (1 );
3092
-
3093
- if (!ShapedType::isDynamic (kernelHeight)) {
3094
- if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight, " out_pad_top" ,
3095
- " KH" )))
3096
- return failure ();
3097
-
3098
- if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3099
- " out_pad_bottom" , " KH" )))
3100
- return failure ();
3101
- }
3073
+ const auto weightType =
3074
+ llvm::dyn_cast<RankedTensorType>(getWeight ().getType ());
3102
3075
3103
- const int64_t kernelWidth = weightType.getDimSize (2 );
3076
+ if (weightType) {
3077
+ const int64_t kernelHeight = weightType.getDimSize (1 );
3078
+ if (!ShapedType::isDynamic (kernelHeight)) {
3079
+ if (failed (checkPadAgainstKernelDim (outPadTop, kernelHeight,
3080
+ " out_pad_top" , " KH" )))
3081
+ return failure ();
3104
3082
3105
- const int64_t outPadLeft = padding[2 ];
3106
- const int64_t outPadRight = padding[3 ];
3083
+ if (failed (checkPadAgainstKernelDim (outPadBottom, kernelHeight,
3084
+ " out_pad_bottom" , " KH" )))
3085
+ return failure ();
3086
+ }
3107
3087
3108
- if (!ShapedType::isDynamic (kernelWidth)) {
3109
- if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth, " out_pad_left" ,
3110
- " KW" )))
3111
- return failure ();
3088
+ const int64_t kernelWidth = weightType.getDimSize (2 );
3089
+ if (!ShapedType::isDynamic (kernelWidth)) {
3090
+ if (failed (checkPadAgainstKernelDim (outPadLeft, kernelWidth,
3091
+ " out_pad_left" , " KW" )))
3092
+ return failure ();
3112
3093
3113
- if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3114
- " out_pad_right" , " KW" )))
3115
- return failure ();
3094
+ if (failed (checkPadAgainstKernelDim (outPadRight, kernelWidth,
3095
+ " out_pad_right" , " KW" )))
3096
+ return failure ();
3097
+ }
3116
3098
}
3117
3099
3118
3100
// Rest of the checks depend on the output type being a RankedTensorType
3101
+ const auto outputType =
3102
+ llvm::dyn_cast<RankedTensorType>(getOutput ().getType ());
3119
3103
if (!outputType)
3120
3104
return success ();
3121
3105
3122
- const int64_t inputHeight = inputType.getDimSize (1 );
3123
- const int64_t outputHeight = outputType.getDimSize (1 );
3124
-
3125
- if (!ShapedType::isDynamic (inputHeight) &&
3126
- !ShapedType::isDynamic (outputHeight)) {
3127
- if (outputHeight !=
3128
- (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3129
- return emitOpError (
3130
- " dimension mismatch: expected OH == (IH - 1) * stride_y "
3131
- " + out_pad_top + out_pad_bottom + KH, but got " )
3132
- << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
3133
- << " + " << outPadTop << " + " << outPadBottom << " + "
3134
- << kernelHeight;
3135
- }
3106
+ const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput ().getType ());
3107
+ if (inputType && weightType) {
3108
+ const int64_t inputHeight = inputType.getDimSize (1 );
3109
+ const int64_t kernelHeight = weightType.getDimSize (1 );
3110
+ const int64_t outputHeight = outputType.getDimSize (1 );
3111
+
3112
+ if (!ShapedType::isDynamic (inputHeight) &&
3113
+ !ShapedType::isDynamic (outputHeight)) {
3114
+ if (outputHeight !=
3115
+ (inputHeight - 1 ) * strideY + outPadTop + outPadBottom + kernelHeight)
3116
+ return emitOpError (
3117
+ " dimension mismatch: expected OH == (IH - 1) * stride_y "
3118
+ " + out_pad_top + out_pad_bottom + KH, but got " )
3119
+ << outputHeight << " != (" << inputHeight << " - 1) * "
3120
+ << strideY << " + " << outPadTop << " + " << outPadBottom
3121
+ << " + " << kernelHeight;
3122
+ }
3136
3123
3137
- const int64_t inputWidth = inputType.getDimSize (2 );
3138
- const int64_t outputWidth = outputType.getDimSize (2 );
3124
+ const int64_t inputWidth = inputType.getDimSize (2 );
3125
+ const int64_t kernelWidth = weightType.getDimSize (2 );
3126
+ const int64_t outputWidth = outputType.getDimSize (2 );
3139
3127
3140
- if (!ShapedType::isDynamic (inputWidth) &&
3141
- !ShapedType::isDynamic (outputWidth)) {
3142
- if (outputWidth !=
3143
- (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3144
- return emitOpError (
3145
- " dimension mismatch: expected OW == (IW - 1) * stride_x "
3146
- " + out_pad_left + out_pad_right + KW, but got " )
3147
- << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3148
- << " + " << outPadLeft << " + " << outPadRight << " + "
3149
- << kernelWidth;
3128
+ if (!ShapedType::isDynamic (inputWidth) &&
3129
+ !ShapedType::isDynamic (outputWidth)) {
3130
+ if (outputWidth !=
3131
+ (inputWidth - 1 ) * strideX + outPadLeft + outPadRight + kernelWidth)
3132
+ return emitOpError (
3133
+ " dimension mismatch: expected OW == (IW - 1) * stride_x "
3134
+ " + out_pad_left + out_pad_right + KW, but got " )
3135
+ << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3136
+ << " + " << outPadLeft << " + " << outPadRight << " + "
3137
+ << kernelWidth;
3138
+ }
3150
3139
}
3151
3140
3152
3141
const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias ().getType ());
0 commit comments