Skip to content

Commit 67b302c

Browse files
authored
[mlir][vector] Add vector.step operation (#96776)
This patch adds a new vector.step operation to the Vector dialect. It produces a linear sequence of index values from 0 to N, where N is the number of elements in the result vector, and can be used to create vectors of indices. It supports both fixed-width and scalable vectors. For fixed the canonical representation is `arith.constant dense<[0, .., N]>`. A scalable step cannot be represented as a constant and is lowered to the `llvm.experimental.stepvector` intrinsic [1]. This op enables scalable vectorization of linalg.index ops, see #96778. It can also be used in the SparseVectorizer in-place of lower-level stepvector intrinsic, see [2] (patch to follow). [1] https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic [2] https://github.com/llvm/llvm-project/blob/acf675b63f9426e61aac2155e29280f7d21f9421/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp#L385-L388
1 parent 927def4 commit 67b302c

File tree

7 files changed

+101
-3
lines changed

7 files changed

+101
-3
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3017,6 +3017,31 @@ def Vector_ScanOp :
30173017
let hasVerifier = 1;
30183018
}
30193019

3020+
//===----------------------------------------------------------------------===//
3021+
// VectorStepOp
3022+
//===----------------------------------------------------------------------===//
3023+
3024+
def Vector_StepOp : Vector_Op<"step", [Pure]> {
3025+
let summary = "A linear sequence of values from 0 to N";
3026+
let description = [{
3027+
A `step` operation produces an index vector, i.e. a 1-D vector of values of
3028+
index type that represents a linear sequence from 0 to N-1, where N is the
3029+
number of elements in the `result` vector.
3030+
3031+
Supports fixed-width and scalable vectors.
3032+
3033+
Examples:
3034+
3035+
```mlir
3036+
%0 = vector.step : vector<4xindex> ; [0, 1, 2, 3]
3037+
%1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
3038+
```
3039+
}];
3040+
let hasFolder = 1;
3041+
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
3042+
let assemblyFormat = "attr-dict `:` type($result)";
3043+
}
3044+
30203045
def Vector_YieldOp : Vector_Op<"yield", [
30213046
Pure, ReturnLike, Terminator]> {
30223047
let summary = "Terminates and yields values from vector regions.";

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,6 +1860,19 @@ struct VectorFromElementsLowering
18601860
}
18611861
};
18621862

1863+
/// Conversion pattern for vector.step.
1864+
struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
1865+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1866+
1867+
LogicalResult
1868+
matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1869+
ConversionPatternRewriter &rewriter) const override {
1870+
Type llvmType = typeConverter->convertType(stepOp.getType());
1871+
rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1872+
return success();
1873+
}
1874+
};
1875+
18631876
} // namespace
18641877

18651878
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1885,8 +1898,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
18851898
VectorSplatOpLowering, VectorSplatNdOpLowering,
18861899
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
18871900
MaskedReductionOpConversion, VectorInterleaveOpLowering,
1888-
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
1889-
converter);
1901+
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1902+
VectorStepOpLowering>(converter);
18901903
// Transfer ops with rank > 1 are handled by VectorToSCF.
18911904
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
18921905
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6312,6 +6312,20 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
63126312
return SplatElementsAttr::get(getType(), {constOperand});
63136313
}
63146314

6315+
//===----------------------------------------------------------------------===//
6316+
// StepOp
6317+
//===----------------------------------------------------------------------===//
6318+
6319+
OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
6320+
auto resultType = cast<VectorType>(getType());
6321+
if (resultType.isScalable())
6322+
return nullptr;
6323+
SmallVector<APInt> indices;
6324+
for (unsigned i = 0; i < resultType.getNumElements(); i++)
6325+
indices.push_back(APInt(/*width=*/64, i));
6326+
return DenseElementsAttr::get(resultType, indices);
6327+
}
6328+
63156329
//===----------------------------------------------------------------------===//
63166330
// WarpExecuteOnLane0Op
63176331
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,3 +2621,14 @@ func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
26212621
%0 = vector.from_elements %a : vector<f32>
26222622
return %0 : vector<f32>
26232623
}
2624+
2625+
// -----
2626+
2627+
// CHECK-LABEL: @vector_step_scalable
2628+
// CHECK: %[[STEPVECTOR:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi64>
2629+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[STEPVECTOR]] : vector<[4]xi64> to vector<[4]xindex>
2630+
// CHECK: return %[[CAST]] : vector<[4]xindex>
2631+
func.func @vector_step_scalable() -> vector<[4]xindex> {
2632+
%0 = vector.step : vector<[4]xindex>
2633+
return %0 : vector<[4]xindex>
2634+
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2719,3 +2719,13 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
27192719
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
27202720
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
27212721
}
2722+
2723+
// -----
2724+
2725+
// CHECK-LABEL: @fold_vector_step_to_constant
2726+
// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
2727+
// CHECK: return %[[CONSTANT]] : vector<4xindex>
2728+
func.func @fold_vector_step_to_constant() -> vector<4xindex> {
2729+
%0 = vector.step : vector<4xindex>
2730+
return %0 : vector<4xindex>
2731+
}

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1871,3 +1871,19 @@ func.func @invalid_from_elements(%a: f32, %b: i32) {
18711871
vector.from_elements %a, %b : vector<2xf32>
18721872
return
18731873
}
1874+
1875+
// -----
1876+
1877+
func.func @invalid_step_0d() {
1878+
// expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<f32>'}}
1879+
vector.step : vector<f32>
1880+
return
1881+
}
1882+
1883+
// -----
1884+
1885+
func.func @invalid_step_2d() {
1886+
// expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<2x4xf32>'}}
1887+
vector.step : vector<2x4xf32>
1888+
return
1889+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,4 +1171,13 @@ func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vecto
11711171
// CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
11721172
%3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
11731173
return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
1174-
}
1174+
}
1175+
1176+
// CHECK-LABEL: @step
1177+
func.func @step() {
1178+
// CHECK: vector.step : vector<2xindex>
1179+
%0 = vector.step : vector<2xindex>
1180+
// CHECK: vector.step : vector<[4]xindex>
1181+
%1 = vector.step : vector<[4]xindex>
1182+
return
1183+
}

0 commit comments

Comments
 (0)