Skip to content

Commit d3846ec

Browse files
authored
[mlir] Guard sccp pass from crashing with different source type (#120656)
Vector::BroadCastOp expects the identical element type in folding. It causes the crash if the different source type is given to the SCCP pass. We need to guard the pass from crashing if the nonidentical element type is given, but still compatible. (e.g. index vs integer type) #120193
1 parent 34f7000 commit d3846ec

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
25232523
if (!adaptor.getSource())
25242524
return {};
25252525
auto vectorType = getResultVectorType();
2526-
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2527-
return DenseElementsAttr::get(vectorType, adaptor.getSource());
2526+
if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2527+
if (vectorType.getElementType() != attr.getType())
2528+
return {};
2529+
return DenseElementsAttr::get(vectorType, attr);
2530+
}
2531+
if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2532+
if (vectorType.getElementType() != attr.getType())
2533+
return {};
2534+
return DenseElementsAttr::get(vectorType, attr);
2535+
}
25282536
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
25292537
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
25302538
return {};

mlir/test/Transforms/sccp.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,12 @@ func.func @op_with_region() -> (i32) {
246246
^b:
247247
return %1 : i32
248248
}
249+
250+
// CHECK-LABEL: no_crash_with_different_source_type
251+
func.func @no_crash_with_different_source_type() {
252+
// CHECK: llvm.mlir.constant(0 : index) : i64
253+
%0 = llvm.mlir.constant(0 : index) : i64
254+
// CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
255+
%1 = vector.broadcast %0 : i64 to vector<128xi64>
256+
llvm.return
257+
}

0 commit comments

Comments
 (0)