Commit 42a0fb2
[mlir][linalg] Add linalg.conv_2d_ngchw_gfchw_q to named ops (#92136)
Adds a named op: linalg.conv_2d_ngchw_gfchw_q. This op is similar to linalg.conv_2d_ngchw_gfchw, but additionally incorporates zero-point offset corrections.
1 parent 7c917e8 commit 42a0fb2
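
In scalar terms, the only difference from the unquantized op is that each operand is shifted by its zero point before the multiply-accumulate (indices follow the opdsl definition in the diff below):

O[n, g, fg, oh, ow] += (cast(U, I[n, g, c, oh * SH + kh * DH, ow * SW + kw * DW]) - cast(U, IZp))
                     * (cast(U, K[g, fg, c, kh, kw]) - cast(U, KZp))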

4 files changed: +219 -0 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 138 additions & 0 deletions
@@ -3478,6 +3478,144 @@ structured_op: !LinalgStructuredOpConfig
                   - !ScalarExpression
                     scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_ngchw_gfchw_q
+  cpp_class_name: Conv2DNgchwGfchwQOp
+  doc: |-
+    Performs 2-D grouped convolution with zero-point offsets.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s1, s11, s2, s5, s9)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s11, s3, s7)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s6, s10)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d5, d3 * s4 + d6 * s6, d4 * s8 + d7 * s10)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d1, d2, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_3d_ndhwc_dhwcf
   cpp_class_name: Conv3DNdhwcDhwcfOp
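
A note for readers decoding the affine maps above: the symbols s0-s11 bind in the order the sizes first appear in the Python opdsl definition below, i.e. s0 = N, s1 = G, s2 = C, s3 = OH, s4 = SH, s5 = KH, s6 = DH, s7 = OW, s8 = SW, s9 = KW, s10 = DW, s11 = FG. The shape_map of I, (s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10), thus reads as (N, G, C, OH * SH + KH * DH, OW * SW + KW * DW).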

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 35 additions & 0 deletions
@@ -958,6 +958,41 @@ def conv_2d_ngchw_gfchw(
     ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
 
 
+@linalg_structured_op
+def conv_2d_ngchw_gfchw_q(
+    I=TensorDef(
+        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+    ),
+    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution with zero-point offsets.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.g, D.fg, D.oh, D.ow] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+        - TypeFn.cast_signed(U, KZp)
+    )
+
+
 @linalg_structured_op
 def conv_3d_ndhwc_dhwcf(
     I=TensorDef(
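
To make the semantics concrete, here is a minimal NumPy reference sketch of the computation the opdsl definition above expresses; the function name and the explicit loop nest are illustrative only, not part of this commit:

import numpy as np

def conv_2d_ngchw_gfchw_q_ref(I, K, izp, kzp, O, strides=(1, 1), dilations=(1, 1)):
    # I: (N, G, C, IH, IW) int8, K: (G, FG, C, KH, KW) int8, O: (N, G, FG, OH, OW) int32.
    N, G, C, _, _ = I.shape
    _, FG, _, KH, KW = K.shape
    _, _, _, OH, OW = O.shape
    SH, SW = strides
    DH, DW = dilations
    for n in range(N):
        for g in range(G):
            for fg in range(FG):
                for oh in range(OH):
                    for ow in range(OW):
                        # Reduce over channels-per-group and the kernel window.
                        for c in range(C):
                            for kh in range(KH):
                                for kw in range(KW):
                                    # cast_signed: promote the i8 operands to the i32
                                    # accumulator type, then subtract the zero points.
                                    x = np.int32(I[n, g, c, oh * SH + kh * DH, ow * SW + kw * DW]) - izp
                                    w = np.int32(K[g, fg, c, kh, kw]) - kzp
                                    O[n, g, fg, oh, ow] += x * w
    return O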

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 31 additions & 0 deletions
@@ -204,6 +204,37 @@ func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>
 
 // -----
 
+func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
+  linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
+    outs (%output: memref<?x?x?x?x?xi32>)
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+
+// CHECK: func @conv_2d_ngchw_gfchw_q
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xi32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i32, %[[BBARG3:.+]]: i32, %[[BBARG4:.+]]: i32)
+// CHECK-NEXT: %[[EXTSI0:.+]] = arith.extsi %[[BBARG0]] : i8 to i32
+// CHECK-NEXT: %[[SUB0:.+]] = arith.subi %[[EXTSI0]], %[[BBARG2]] : i32
+// CHECK-NEXT: %[[EXTSI1:.+]] = arith.extsi %[[BBARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB1:.+]] = arith.subi %[[EXTSI1]], %[[BBARG3]] : i32
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[SUB0]], %[[SUB1]] : i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[BBARG4]], %[[MUL]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+
+// -----
+
 func.func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
   linalg.fill ins(%value : f32) outs(%output : memref<?x?xf32>)
   return
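
For reference, this file exercises the generalization pass; assuming its usual RUN line, a manual reproduction would look something like

  mlir-opt -split-input-file -linalg-generalize-named-ops generalize-named-ops.mlir

which rewrites linalg.conv_2d_ngchw_gfchw_q into the linalg.generic with the extsi/subi/muli/addi body matched by the CHECK lines above.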

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 15 additions & 0 deletions
@@ -441,6 +441,21 @@ func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw_q
+func.func @conv_2d_ngchw_gfchw_q(%input: tensor<1x5x3x32x32xi8>, %filter: tensor<5x2x3x3x3xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32> {
+  // CHECK: linalg.conv_2d_ngchw_gfchw_q
+  // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xi8>, tensor<5x2x3x3x3xi8>, i32, i32)
+  // CHECK-SAME: outs(%{{.+}} : tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32>
+  %0 = linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
+                                     strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: tensor<1x5x3x32x32xi8>, tensor<5x2x3x3x3xi8>, i32, i32)
+    outs (%init: tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32>
+  return %0 : tensor<1x5x2x30x30xi32>
+}
+// -----
+
 // CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
 func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
 // CHECK: %{{.+}} = linalg.conv_3d_ndhwc_dhwcf
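
A quick sanity check on the static shapes in the @conv_2d_ngchw_gfchw_q case above: with a 32x32 spatial input, a 3x3 kernel, and unit strides and dilations, the usual convolution size formula gives (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30 per spatial dimension, matching the tensor<1x5x2x30x30xi32> result (N = 1, G = 5, FG = 2).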
