Skip to content

Commit db1d881

Browse files
authored
[mlir][spirv] Fix bug for vector.broadcast op in convert-vector-to-spirv pass (llvm#99928)
This PR addresses [!17976](iree-org/iree#17976) by using converted `resultType` instead of the original result type obtained from `castOp.getResultVectorType`. A new LIT test is also included.
1 parent 7467f41 commit db1d881

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ struct VectorBroadcastConvert final
144144

145145
SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
146146
adaptor.getSource());
147-
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
148-
castOp, castOp.getResultVectorType(), source);
147+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
148+
source);
149149
return success();
150150
}
151151
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
150150

151151
// -----
152152

153+
// CHECK-LABEL: @broadcast_index
154+
// CHECK-SAME: %[[ARG0:.*]]: index
155+
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : index to i32
156+
// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST0]], %[[CAST0]], %[[CAST0]] : (i32, i32, i32, i32) -> vector<4xi32>
157+
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
158+
// CHECK: return %[[CAST1]] : vector<4xindex>
159+
func.func @broadcast_index(%a: index) -> vector<4xindex> {
160+
%0 = vector.broadcast %a : index to vector<4xindex>
161+
return %0 : vector<4xindex>
162+
}
163+
164+
// -----
165+
153166
// CHECK-LABEL: @extract
154167
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
155168
// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>

0 commit comments

Comments
 (0)