Commit 2a82dfd

[mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToSCF)
This allows the lowering of > rank 1 transfer_reads/writes to equivalent lower-rank ones when the trailing dimension is scalable. The resulting ops still cannot be completely lowered, as they depend on arrays of scalable vectors being enabled, and a few related fixes (see D158517).

This patch also explicitly disables lowering transfer_reads/writes with a leading scalable dimension, as more changes would be needed to handle that correctly, and it is unclear if it is required.

Examples of ops that can now be further lowered:

  %vec = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>

Reviewed By: c-rhodes, awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D158753
1 parent eebf8fa commit 2a82dfd
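
For orientation, here is a minimal C++ sketch of the policy this commit implements; the helper name canPeelLeadingDim is hypothetical and not part of the patch (the real checks live in unpackOneDim() and checkPrepareXferOp() in the diff below). Peeling off the leading dimension is only attempted when that dimension is fixed; trailing scalable dimensions are allowed.

#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Hypothetical helper mirroring the new guard, for illustration only.
static bool canPeelLeadingDim(VectorType vecType) {
  // vector<3x[4]xf32> -> true  (leading dim fixed, trailing dim scalable)
  // vector<[4]x4xf32> -> false (leading dim scalable)
  return vecType.getRank() > 1 && !vecType.getScalableDims().front();
}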

2 files changed: +129, -8 lines
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 26 additions & 8 deletions
@@ -314,15 +314,18 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
 /// the VectorType into the MemRefType.
 ///
 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
+static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   auto vectorType = dyn_cast<VectorType>(type.getElementType());
+  // Vectors with leading scalable dims are not supported.
+  // It may be possible to support these in future by using dynamic memref dims.
+  if (vectorType.getScalableDims().front())
+    return failure();
   auto memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
   newMemrefShape.push_back(vectorType.getDimSize(0));
   return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
+                         VectorType::Builder(vectorType).dropDim(0));
 }

 /// Given a transfer op, find the memref from which the mask is loaded. This

@@ -542,6 +545,10 @@ LogicalResult checkPrepareXferOp(OpTy xferOp,
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
+  // Currently the unpacking of the leading dimension into the memref is not
+  // supported for scalable dimensions.
+  if (xferOp.getVectorType().getScalableDims().front())
+    return failure();
   if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.

@@ -866,8 +873,11 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
     auto castedDataType = unpackOneDim(dataBufferType);
+    if (failed(castedDataType))
+      return failure();
+
     auto castedDataBuffer =
-        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
+        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;

@@ -882,7 +892,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       // be broadcasted.)
       castedMaskBuffer = maskBuffer;
     } else {
-      auto castedMaskType = unpackOneDim(maskBufferType);
+      // It's safe to assume the mask buffer can be unpacked if the data
+      // buffer was unpacked.
+      auto castedMaskType = *unpackOneDim(maskBufferType);
       castedMaskBuffer =
           locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
     }

@@ -891,7 +903,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // Loop bounds and step.
     auto lb = locB.create<arith::ConstantIndexOp>(0);
     auto ub = locB.create<arith::ConstantIndexOp>(
-        castedDataType.getDimSize(castedDataType.getRank() - 1));
+        castedDataType->getDimSize(castedDataType->getRank() - 1));
     auto step = locB.create<arith::ConstantIndexOp>(1);
     // TransferWriteOps that operate on tensors return the modified tensor and
     // require a loop state.

@@ -1074,8 +1086,14 @@ struct UnrollTransferReadConversion
     auto vec = getResultVector(xferOp, rewriter);
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto xferVecType = xferOp.getVectorType();
-    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
-                                          xferVecType.getElementType());
+
+    if (xferVecType.getScalableDims()[0]) {
+      // Cannot unroll a scalable dimension at compile time.
+      return failure();
+    }
+
+    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
+
     int64_t dimSize = xferVecType.getShape()[0];

     // Generate fully unrolled loop of transfer ops.
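
The switch from VectorType::get to VectorType::Builder above is what keeps the scalable flag on the remaining dimension. A minimal standalone sketch contrasting the two approaches (illustration only, not code from the patch):

#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Old approach: rebuilding from the raw shape drops the scalable-dims mask,
// so vector<3x[4]xf32> would become vector<4xf32>.
static VectorType dropLeadingDimLossy(VectorType vecType) {
  return VectorType::get(vecType.getShape().drop_front(),
                         vecType.getElementType());
}

// New approach: the builder copies the scalable-dims mask, so
// vector<3x[4]xf32> becomes vector<[4]xf32> as intended.
static VectorType dropLeadingDim(VectorType vecType) {
  return VectorType::Builder(vecType).dropDim(0);
}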

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 103 additions & 0 deletions
@@ -635,3 +635,106 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
 // CHECK: vector.print
 // CHECK: return
 // CHECK: }
+
+// -----
+
+func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
+  return %read : vector<3x[4]xf32>
+}
+// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
+// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
+// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK: %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK: memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK: return %[[RESULT]] : vector<3x[4]xf32>
+// CHECK: }
+
+// -----
+
+func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>
+  return
+}
+// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
+// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
+// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK: memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK: %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK: vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// -----
+
+/// The following two tests currently cannot be lowered via unpacking the leading dim since it is scalable.
+/// It may be possible to special case this via a dynamic dim in future.
+
+func.func @cannot_lower_transfer_write_with_leading_scalable(%vec: vector<[4]x4xf32>, %arg0: memref<?x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+  return
+}
+// CHECK-LABEL: func.func @cannot_lower_transfer_write_with_leading_scalable(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]x4xf32>,
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK: vector.transfer_write %[[VEC]], %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+
+// -----
+
+func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+  return %read : vector<[4]x4xf32>
+}
+// CHECK-LABEL: func.func @cannot_lower_transfer_read_with_leading_scalable(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK: %{{.*}} = vector.transfer_read %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
