@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
+{
+  %init0 = tensor.empty() : tensor<56xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
+    : tensor<56x1xf32> into tensor<56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%reshaped_tensor : tensor<56xf32>)
+    outs(%init0 : tensor<56xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56xf32>
+
+  return %bias : tensor<56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
+func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
+{
+  %c1 = arith.constant 1 : index
+  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%cast : tensor<56x?xf32>)
+    outs(%init0 : tensor<56x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56x?xf32>
+
+  return %bias : tensor<56x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}