@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
1772
1772
into %res {test_attr } : tensor <1 x1 x?x1 xi32 > -> tensor <7 x?xi32 >
1773
1773
return %unpack : tensor <7 x?xi32 >
1774
1774
}
1775
+
1776
+ // -----
1777
+
1778
+ //===----------------------------------------------------------------------===//
1779
+ // linalg.unpack + tensor.extract_slice
1780
+ //===----------------------------------------------------------------------===//
1781
+
1782
// Positive case: a linalg.unpack whose only visible use is a
// tensor.extract_slice with all-zero offsets and all-unit strides.
// The CHECK lines after this function expect the slice to be taken from
// %dest instead, with linalg.unpack writing into that sliced destination.
// NOTE(review): the spaces inside the types (`tensor <28 x2 ...>`) look like
// rendering residue rather than real source text - confirm against the file.
+ func.func @fold_extract_slice_into_unpack (
1783
+ %src : tensor <28 x2 x?x16 x16 xf32 >, %dest : tensor <28 x32 x?xf32 >, %size : index
1784
+ ) -> tensor <28 x28 x?xf32 > {
1785
// Dims 1 and 2 of the destination are tiled by 16 (inner_dims_pos = [1, 2]).
+ %unpack = linalg.unpack %src
1786
+ outer_dims_perm = [0 , 1 , 2 ]
1787
+ inner_dims_pos = [1 , 2 ]
1788
+ inner_tiles = [16 , 16 ]
1789
+ into %dest : tensor <28 x2 x?x16 x16 xf32 > -> tensor <28 x32 x?xf32 >
1790
// The slice shrinks dim 1 from 32 to 28 and takes %size elements of the
// dynamic dim 2; the result keeps rank 3.
+ %extracted_slice = tensor.extract_slice %unpack
1791
+ [0 , 0 , 0 ] [28 , 28 , %size ] [1 , 1 , 1 ] : tensor <28 x32 x?xf32 > to tensor <28 x28 x?xf32 >
1792
+ return %extracted_slice : tensor <28 x28 x?xf32 >
1793
+ }
1794
+
1795
+ // CHECK-LABEL: func @fold_extract_slice_into_unpack
1796
+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
1797
+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
1798
+ // CHECK-SAME: %[[SIZE:.+]]: index
1799
+ // CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
1800
+ // CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
1801
+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1802
+ // CHECK-SAME: into %[[DEST_SLICE]]
1803
+ // CHECK: return %[[UNPACK]]
1804
+
1805
+ // -----
1806
+
1807
// Negative case: the tensor.extract_slice is rank-reducing (2-D input,
// 1-D result). The CHECK lines after this function expect linalg.unpack to
// keep writing into the original %dest and the extract_slice to remain.
+ func.func @no_fold_extract_slice_into_unpack_rank_reducing (
1808
+ %src : tensor <28 x2 x16 xf32 >, %dest : tensor <28 x32 xf32 >
1809
+ ) -> tensor <28 xf32 > {
1810
+ %unpack = linalg.unpack %src
1811
+ outer_dims_perm = [0 , 1 ]
1812
+ inner_dims_pos = [1 ]
1813
+ inner_tiles = [16 ]
1814
+ into %dest : tensor <28 x2 x16 xf32 > -> tensor <28 x32 xf32 >
1815
// Sizes [1, 28] with a tensor<28xf32> result: the unit-sized dim 0 is
// dropped from the result type, which makes this slice rank-reducing.
+ %extracted_slice = tensor.extract_slice %unpack
1816
+ [0 , 0 ] [1 , 28 ] [1 , 1 ] : tensor <28 x32 xf32 > to tensor <28 xf32 >
1817
+ return %extracted_slice : tensor <28 xf32 >
1818
+ }
1819
+
1820
+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
1821
+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1822
+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1823
+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1824
+ // CHECK-SAME: into %[[DEST]]
1825
+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1826
+ // CHECK: return %[[SLICE]]
1827
+
1828
+ // -----
1829
+
1830
// Negative case: the tensor.extract_slice starts at a non-zero offset.
// The CHECK lines after this function expect linalg.unpack to keep writing
// into the original %dest and the extract_slice to remain.
+ func.func @no_fold_extract_slice_into_unpack_non_zero_offset (
1831
+ %src : tensor <28 x2 x16 xf32 >, %dest : tensor <28 x32 xf32 >
1832
+ ) -> tensor <28 x28 xf32 > {
1833
+ %unpack = linalg.unpack %src
1834
+ outer_dims_perm = [0 , 1 ]
1835
+ inner_dims_pos = [1 ]
1836
+ inner_tiles = [16 ]
1837
+ into %dest : tensor <28 x2 x16 xf32 > -> tensor <28 x32 xf32 >
1838
// Offsets [0, 1]: the slice begins at index 1 along dim 1, so it is not
// anchored at the origin of the unpacked tensor.
+ %extracted_slice = tensor.extract_slice %unpack
1839
+ [0 , 1 ] [28 , 28 ] [1 , 1 ] : tensor <28 x32 xf32 > to tensor <28 x28 xf32 >
1840
+ return %extracted_slice : tensor <28 x28 xf32 >
1841
+ }
1842
+
1843
+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
1844
+ // CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1845
+ // CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1846
+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1847
+ // CHECK-SAME: into %[[DEST]]
1848
+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1849
+ // CHECK: return %[[SLICE]]
0 commit comments