Skip to content

Commit 7ce53e3

Browse files
committed
[mlir][tosa] Add tosa.conv3d lowering to Linalg
Conv3D has an existing linalg operation for floating point. Adding a quantized variant and corresponding lowering from TOSA. Numerical correctness was validated using the TOSA conformance tests. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D140919
1 parent 75d268d commit 7ce53e3

File tree

5 files changed

+334
-83
lines changed

5 files changed

+334
-83
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,6 +2001,145 @@ structured_op: !LinalgStructuredOpConfig
20012001
- !ScalarExpression
20022002
scalar_arg: K
20032003
--- !LinalgOpConfig
2004+
metadata: !LinalgOpMetadata
2005+
name: conv_3d_ndhwc_dhwcf_q
2006+
cpp_class_name: Conv3DNdhwcDhwcfQOp
2007+
doc: |-
2008+
Performs 3-D convolution with zero point offsets.
2009+
2010+
Numeric casting is performed on the operands to the inner multiply, promoting
2011+
them to the same data type as the accumulator/output. This includes the zero
2012+
point offsets common to quantized operations.
2013+
implements:
2014+
- LinalgConvolutionOpInterface
2015+
structured_op: !LinalgStructuredOpConfig
2016+
args:
2017+
- !LinalgOperandDefConfig
2018+
name: I
2019+
kind: input_tensor
2020+
type_var: T1
2021+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
2022+
s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12,
2023+
s13)>
2024+
- !LinalgOperandDefConfig
2025+
name: K
2026+
kind: input_tensor
2027+
type_var: T2
2028+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
2029+
s13, s14] -> (s3, s7, s11, s13, s14)>
2030+
- !LinalgOperandDefConfig
2031+
name: IZp
2032+
kind: scalar
2033+
type_var: I32
2034+
- !LinalgOperandDefConfig
2035+
name: KZp
2036+
kind: scalar
2037+
type_var: I32
2038+
- !LinalgOperandDefConfig
2039+
name: O
2040+
kind: output_tensor
2041+
type_var: U
2042+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
2043+
s13, s14] -> (s0, s1, s5, s9, s14)>
2044+
- !LinalgOperandDefConfig
2045+
name: strides
2046+
kind: index_attr
2047+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
2048+
s12, s13, s14] -> (s2, s6, s10)>
2049+
default_indices:
2050+
- 1
2051+
- 1
2052+
- 1
2053+
- !LinalgOperandDefConfig
2054+
name: dilations
2055+
kind: index_attr
2056+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
2057+
s12, s13, s14] -> (s4, s8, s12)>
2058+
default_indices:
2059+
- 1
2060+
- 1
2061+
- 1
2062+
indexing_maps: !LinalgIndexingMapsConfig
2063+
static_indexing_maps:
2064+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
2065+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6
2066+
* s8, d3 * s10 + d7 * s12, d8)>
2067+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
2068+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d5, d6, d7, d8, d4)>
2069+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
2070+
s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
2071+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
2072+
s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
2073+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
2074+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1, d2, d3, d4)>
2075+
iterator_types:
2076+
- parallel
2077+
- parallel
2078+
- parallel
2079+
- parallel
2080+
- parallel
2081+
- reduction
2082+
- reduction
2083+
- reduction
2084+
- reduction
2085+
assignments:
2086+
- !ScalarAssign
2087+
arg: O
2088+
value: !ScalarExpression
2089+
scalar_fn:
2090+
kind: binary
2091+
fn_name: add
2092+
operands:
2093+
- !ScalarExpression
2094+
scalar_arg: O
2095+
- !ScalarExpression
2096+
scalar_fn:
2097+
kind: binary
2098+
fn_name: mul
2099+
operands:
2100+
- !ScalarExpression
2101+
scalar_fn:
2102+
kind: binary
2103+
fn_name: sub
2104+
operands:
2105+
- !ScalarExpression
2106+
scalar_fn:
2107+
kind: type
2108+
fn_name: cast_signed
2109+
type_var: U
2110+
operands:
2111+
- !ScalarExpression
2112+
scalar_arg: I
2113+
- !ScalarExpression
2114+
scalar_fn:
2115+
kind: type
2116+
fn_name: cast_signed
2117+
type_var: U
2118+
operands:
2119+
- !ScalarExpression
2120+
scalar_arg: IZp
2121+
- !ScalarExpression
2122+
scalar_fn:
2123+
kind: binary
2124+
fn_name: sub
2125+
operands:
2126+
- !ScalarExpression
2127+
scalar_fn:
2128+
kind: type
2129+
fn_name: cast_signed
2130+
type_var: U
2131+
operands:
2132+
- !ScalarExpression
2133+
scalar_arg: K
2134+
- !ScalarExpression
2135+
scalar_fn:
2136+
kind: type
2137+
fn_name: cast_signed
2138+
type_var: U
2139+
operands:
2140+
- !ScalarExpression
2141+
scalar_arg: KZp
2142+
--- !LinalgOpConfig
20042143
metadata: !LinalgOpMetadata
20052144
name: depthwise_conv_1d_nwc_wc
20062145
cpp_class_name: DepthwiseConv1DNwcWcOp
@@ -4441,3 +4580,4 @@ structured_op: !LinalgStructuredOpConfig
44414580
scalar_const: '2.3283063999999999E-10 : f64'
44424581
- !ScalarExpression
44434582
scalar_arg: min
4583+

0 commit comments

Comments
 (0)