Skip to content

Commit 2964c5d

Browse files
committed
Revert "[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (llvm#92934)"
This reverts commit 2c06fb8.
1 parent 0cc3fe4 commit 2964c5d

File tree

2 files changed

+26
-65
lines changed

2 files changed

+26
-65
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,27 +1623,7 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16231623
}
16241624
};
16251625

1626-
// Scalable unit dimensions are not supported. Folding such dimensions would
1627-
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1628-
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1629-
// future.
1630-
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1631-
auto inVecShape = inVecTy.getShape();
1632-
SmallVector<int64_t> newShape;
1633-
SmallVector<bool> newScalableDims;
1634-
for (auto [dim, isScalable] :
1635-
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1636-
if (dim == 1 && !isScalable)
1637-
continue;
1638-
1639-
newShape.push_back(dim);
1640-
newScalableDims.push_back(isScalable);
1641-
}
1642-
1643-
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1644-
}
1645-
1646-
/// For vectors with at least an unit dim, replaces:
1626+
/// For vectors with either leading or trailing unit dim, replaces:
16471627
/// elementwise(a, b)
16481628
/// with:
16491629
/// sc_a = shape_cast(a)
@@ -1655,16 +1635,20 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
16551635
/// required to be rank > 1.
16561636
///
16571637
/// Ex:
1638+
/// ```
16581639
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16591640
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1641+
/// ```
16601642
///
16611643
/// gets converted to:
16621644
///
1645+
/// ```
16631646
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16641647
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16651648
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16661649
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16671650
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1651+
/// ```
16681652
///
16691653
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16701654
/// `%cast`.
@@ -1684,29 +1668,42 @@ struct DropUnitDimFromElementwiseOps final
16841668
// guaranteed to have identical shapes (with some exceptions such as
16851669
// `arith.select`) and it suffices to only check one of them.
16861670
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1687-
if (!sourceVectorType || sourceVectorType.getRank() < 2)
1671+
if (!sourceVectorType)
1672+
return failure();
1673+
if (sourceVectorType.getRank() < 2)
1674+
return failure();
1675+
1676+
bool hasTrailingDimUnitFixed =
1677+
((sourceVectorType.getShape().back() == 1) &&
1678+
(!sourceVectorType.getScalableDims().back()));
1679+
bool hasLeadingDimUnitFixed =
1680+
((sourceVectorType.getShape().front() == 1) &&
1681+
(!sourceVectorType.getScalableDims().front()));
1682+
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
16881683
return failure();
16891684

1685+
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1686+
// operands
1687+
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1688+
16901689
SmallVector<Value> newOperands;
16911690
auto loc = op->getLoc();
16921691
for (auto operand : op->getOperands()) {
16931692
auto opVectorType = cast<VectorType>(operand.getType());
1694-
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1695-
if (newVType == opVectorType)
1696-
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1697-
1693+
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
16981694
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16991695
newOperands.push_back(opSC);
17001696
}
17011697

17021698
VectorType newResultVectorType =
1703-
dropNonScalableUnitDimFromType(resultVectorType);
1704-
// Create an updated elementwise Op without unit dim.
1699+
VectorType::Builder(resultVectorType).dropDim(dim);
1700+
// Create an updated elementwise Op without leading/trailing unit dim
17051701
Operation *elementwiseOp =
17061702
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
17071703
newResultVectorType, op->getAttrs());
17081704

1709-
// Restore the unit dim by applying vector.shape_cast to the result.
1705+
// Restore the leading/trailing unit dim by applying vector.shape_cast
1706+
// to the result
17101707
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
17111708
elementwiseOp->getResult(0));
17121709

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -604,42 +604,6 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
604604

605605
// -----
606606

607-
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
608-
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
609-
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
610-
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
611-
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
612-
return %res : vector<8x3xf128>
613-
}
614-
615-
// CHECK-LABEL: func.func @fold_inner_unit_dim(
616-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
617-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
618-
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
619-
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
620-
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
621-
// CHECK: return %[[VAL_4]] : vector<8x3xf128>
622-
623-
// -----
624-
625-
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
626-
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
627-
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
628-
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
629-
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
630-
return %res : vector<8x[1]x3xf128>
631-
}
632-
633-
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
634-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
635-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
636-
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
637-
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
638-
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
639-
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
640-
641-
// -----
642-
643607
func.func @negative_out_of_bound_transfer_read(
644608
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
645609
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)