Skip to content

[mlir][vector] Allow multi dim vectors in vector.scatter #132217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2034,19 +2034,19 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$valueToStore)> {
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore)> {

let summary = [{
scatters elements from a vector into memory as defined by an index vector
and a mask vector
}];

let description = [{
The scatter operation stores elements from a 1-D vector into memory as
defined by a base with indices and an additional 1-D index vector, but
only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
The scatter operation stores elements from a n-D vector into memory as
defined by a base with indices and an additional n-D index vector, but
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
```
if (mask[0]) base[index[0]] = value[0]
Expand Down
32 changes: 21 additions & 11 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,22 +263,25 @@ class VectorGatherOpConversion
LogicalResult
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = gather->getLoc();
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
assert(memRefType && "The base should be bufferized");

if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
return rewriter.notifyMatchFailure(gather, "memref type not supported");

VectorType vType = gather.getVectorType();
if (vType.getRank() > 1)
return failure();

Location loc = gather->getLoc();
if (vType.getRank() > 1) {
return rewriter.notifyMatchFailure(
gather, "only 1-D vectors can be lowered to LLVM");
}

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
return rewriter.notifyMatchFailure(gather,
"could not resolve memref alignment");
}

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
Expand Down Expand Up @@ -309,15 +312,22 @@ class VectorScatterOpConversion
MemRefType memRefType = scatter.getMemRefType();

if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
return rewriter.notifyMatchFailure(scatter, "memref type not supported");

VectorType vType = scatter.getVectorType();
if (vType.getRank() > 1) {
return rewriter.notifyMatchFailure(
scatter, "only 1-D vectors can be lowered to LLVM");
}

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
return rewriter.notifyMatchFailure(scatter,
"could not resolve memref alignment");
}

// Resolve address.
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptrs =
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != indVType.getDimSize(0))
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x

// -----

//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//

// Multi-Dimensional scatters are not supported yet. Check that we do not lower
// them.

func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
%0 = arith.constant 0: index
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
return
}

// CHECK-LABEL: func @scatter_with_mask
// CHECK: vector.scatter

// -----

func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
%0 = arith.constant 0: index
// vector.constant_mask only supports 'none set' or 'all set' scalable
// dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
// width vectors above.
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
return
}

// CHECK-LABEL: func @scatter_with_mask_scalable
// CHECK: vector.scatter

// -----

//===----------------------------------------------------------------------===//
// vector.interleave
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
// expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
Expand Down
18 changes: 10 additions & 8 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
return
}

// CHECK-LABEL: @gather_and_scatter_multi_dims
func.func @gather_and_scatter_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
%c0 = arith.constant 0 : index
// CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
// CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
return %0 : vector<2x16xf32>
}

// CHECK-LABEL: @gather_on_tensor
func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
Expand All @@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
return %0 : vector<16xf32>
}

// CHECK-LABEL: @gather_multi_dims
func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
%c0 = arith.constant 0 : index
// CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
return %0 : vector<2x16xf32>
}

// CHECK-LABEL: @expand_and_compress
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
Expand Down
Loading