Skip to content

Commit 84cd18b

Browse files
Max191qedawkins
authored andcommitted
Revert "[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (llvm#92934)"
This reverts commit 2c06fb8.
1 parent 3101524 commit 84cd18b

File tree

2 files changed

+26
-65
lines changed

2 files changed

+26
-65
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,27 +1622,7 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16221622
}
16231623
};
16241624

1625-
// Scalable unit dimensions are not supported. Folding such dimensions would
1626-
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1627-
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1628-
// future.
1629-
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1630-
auto inVecShape = inVecTy.getShape();
1631-
SmallVector<int64_t> newShape;
1632-
SmallVector<bool> newScalableDims;
1633-
for (auto [dim, isScalable] :
1634-
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1635-
if (dim == 1 && !isScalable)
1636-
continue;
1637-
1638-
newShape.push_back(dim);
1639-
newScalableDims.push_back(isScalable);
1640-
}
1641-
1642-
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1643-
}
1644-
1645-
/// For vectors with at least an unit dim, replaces:
1625+
/// For vectors with either leading or trailing unit dim, replaces:
16461626
/// elementwise(a, b)
16471627
/// with:
16481628
/// sc_a = shape_cast(a)
@@ -1654,16 +1634,20 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
16541634
/// required to be rank > 1.
16551635
///
16561636
/// Ex:
1637+
/// ```
16571638
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16581639
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1640+
/// ```
16591641
///
16601642
/// gets converted to:
16611643
///
1644+
/// ```
16621645
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16631646
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16641647
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16651648
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16661649
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1650+
/// ```
16671651
///
16681652
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16691653
/// `%cast`.
@@ -1683,29 +1667,42 @@ struct DropUnitDimFromElementwiseOps final
16831667
// guaranteed to have identical shapes (with some exceptions such as
16841668
// `arith.select`) and it suffices to only check one of them.
16851669
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1686-
if (!sourceVectorType || sourceVectorType.getRank() < 2)
1670+
if (!sourceVectorType)
1671+
return failure();
1672+
if (sourceVectorType.getRank() < 2)
1673+
return failure();
1674+
1675+
bool hasTrailingDimUnitFixed =
1676+
((sourceVectorType.getShape().back() == 1) &&
1677+
(!sourceVectorType.getScalableDims().back()));
1678+
bool hasLeadingDimUnitFixed =
1679+
((sourceVectorType.getShape().front() == 1) &&
1680+
(!sourceVectorType.getScalableDims().front()));
1681+
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
16871682
return failure();
16881683

1684+
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1685+
// operands
1686+
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1687+
16891688
SmallVector<Value> newOperands;
16901689
auto loc = op->getLoc();
16911690
for (auto operand : op->getOperands()) {
16921691
auto opVectorType = cast<VectorType>(operand.getType());
1693-
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1694-
if (newVType == opVectorType)
1695-
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1696-
1692+
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
16971693
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16981694
newOperands.push_back(opSC);
16991695
}
17001696

17011697
VectorType newResultVectorType =
1702-
dropNonScalableUnitDimFromType(resultVectorType);
1703-
// Create an updated elementwise Op without unit dim.
1698+
VectorType::Builder(resultVectorType).dropDim(dim);
1699+
// Create an updated elementwise Op without leading/trailing unit dim
17041700
Operation *elementwiseOp =
17051701
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
17061702
newResultVectorType, op->getAttrs());
17071703

1708-
// Restore the unit dim by applying vector.shape_cast to the result.
1704+
// Restore the leading/trailing unit dim by applying vector.shape_cast
1705+
// to the result
17091706
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
17101707
elementwiseOp->getResult(0));
17111708

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -604,42 +604,6 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
604604

605605
// -----
606606

607-
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
608-
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
609-
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
610-
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
611-
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
612-
return %res : vector<8x3xf128>
613-
}
614-
615-
// CHECK-LABEL: func.func @fold_inner_unit_dim(
616-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
617-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
618-
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
619-
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
620-
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
621-
// CHECK: return %[[VAL_4]] : vector<8x3xf128>
622-
623-
// -----
624-
625-
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
626-
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
627-
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
628-
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
629-
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
630-
return %res : vector<8x[1]x3xf128>
631-
}
632-
633-
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
634-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
635-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
636-
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
637-
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
638-
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
639-
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
640-
641-
// -----
642-
643607
func.func @negative_out_of_bound_transfer_read(
644608
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
645609
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)