Commit ee47454

[mlir][vector] Refactor createWriteOrMaskedWrite (llvm#138137)
This patch updates `createWriteOrMaskedWrite` to make it consistent with `createReadOrMaskedRead`.

Before diving into the details, note that these utilities are currently implemented in different files: "VectorUtils.cpp" (Vector) and "Vectorization.cpp" (Linalg). In a subsequent patch, I plan to move `createWriteOrMaskedWrite` into "VectorUtils.cpp".

SUMMARY OF CHANGES:

The main change is to remove the logic that creates the destination tensor, which previously looked like:

```cpp
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                             inputType.getElementType());
```

With this patch, `createWriteOrMaskedWrite` now simply generates:

```mlir
%res = vector.transfer_write %vectorToStore into %dest
```

This replaces the previous form:

```mlir
%dest = tensor.empty(%destSizes)
%res = vector.transfer_write %vectorToStore into %dest
```

In other words, the destination value `%dest` is now passed as an input parameter (sketched below). This makes `createWriteOrMaskedWrite` reusable in contexts where the destination tensor is already known, for example in `vectorizeAsInsertSliceOp`, which I will update in a follow-up patch.

OTHER CHANGES:

* Added comments and clarified TODOs.
* Updated tests: since destination sizes are now computed independently inside `createWriteOrMaskedWrite`, some additional `tensor.dim` ops appear. These are cleaned up by CSE + canonicalization.
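To make the new contract concrete, here is a minimal IR sketch of the masked path after this patch. All value names, shapes, and sizes (`%d0`, `%d1`, `%vec`, `2x4`) are hypothetical, chosen for illustration: the caller materializes the destination, and `createWriteOrMaskedWrite` contributes only the mask and the write.

```mlir
// Caller-created destination (previously emitted inside the utility).
%dest = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%c0 = arith.constant 0 : index
// Emitted by createWriteOrMaskedWrite: a mask sized by the destination,
// wrapping the transfer_write (all write offsets are 0).
%mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<2x4xf32>, tensor<?x?xf32>
} : vector<2x4xi1> -> tensor<?x?xf32>
```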
Parent: c632ac3

2 files changed: +62 additions, -46 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 56 additions & 44 deletions
```diff
@@ -1506,84 +1506,86 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Creates a TransferWriteOp to write `input` into a newly initialized
-/// output tensor.
+/// Creates an optionally masked TransferWriteOp
 ///
-/// Given:
-/// - an input vector to write,
-/// - the mixed destination sizes for the output tensor,
-/// - and the vector sizes used for vectorization (i.e., the leading N dims,
-///   for some value of N),
-///
-/// this function generates the following sequence of ops:
-///
-///   %dest = tensor.empty(%destSizes)
-///   %res = vector.transfer_write %input into %dest
+/// Generates the following operation:
+///   %res = vector.transfer_write %vectorToStore into %dest
 ///
 /// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims` (where N =
-/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
-/// correctness:
+/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
+/// masking is applied to ensure correctness:
 ///
-///   %dest = tensor.empty(%destSizes)
-///   %write = vector.transfer_write %input into %dest
-///   %mask = vector.create_mask(%destSizes)
-///   %res = vector.mask %mask { %write }
+///   %mask = vector.create_mask(%destShape)
+///   %res = vector.mask %mask {
+///     vector.transfer_write %vectorToStore into %dest
+///   }
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %dest = tensor.empty(%destSizes)
+///   %write = vector.transfer_write %vectorToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: all write offsets are set to 0.
+/// NOTE: All write offsets are set to 0.
+/// TODO: Allow specyfying write offsets.
 /// NOTE: When N < rank(input), the missing vector sizes are effectively
 /// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static. Supporting dynamic sizes will require the user to specify
-/// the remaining vector sizes. This is left as a TODO.
+/// must be static.
+/// TODO: Support cases where an arbitrary dim is dynamic - this will require
+/// specifying all the vector sizes.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
-                         SmallVector<OpFoldResult> destSizes,
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
+                         Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
                          bool useInBoundsInsteadOfMasking = false) {
 
-  auto inputType = cast<VectorType>(input.getType());
-  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
+             static_cast<int64_t>(destType.getRank()) &&
          "Rank mismatch!");
 
-  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
-                                               inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
+
+  // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
-    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+    assert(inputVecSizesForLeadingDims.size() ==
+               static_cast<size_t>(destType.getRank()) &&
           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
+
+  // Generate the xfer_write Op
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
-      /*vector=*/input,
+      /*vector=*/vectorToStore,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
              destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
+  // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
+
+  // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
                    destShape.take_front(inputVecSizesForLeadingDims.size()));
+
+  // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
                               inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite =
-        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    Value maskForWrite = builder.create<vector::CreateMaskOp>(
+        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
+
   return write;
 }
 
```
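Since the mask sizes are now recovered from `%dest` via `tensor::getMixedSizes` rather than taken from a `destSizes` argument, dynamic dimensions are re-queried with `tensor.dim`. A hypothetical sketch of the IR this produces for a 2-D dynamic destination (names and shapes illustrative only):

```mlir
// Sizes are read back from the destination tensor; these tensor.dim ops
// are the "additional ops" mentioned in the commit message and later fold
// away under CSE + canonicalization.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %dim0, %dim1 : vector<2x4xi1>
```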
```diff
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0],
+      transposeOp.getResult().getType().getElementType());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               /*destSizes=*/reifiedReturnShapes[0],
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
                                /*inputVecSizesForLeadingDims=*/inputVectorSizes,
                                /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
```
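For the pack path, the destination is now created by `vectorizeAsTensorPackOp` itself, right before the write. Sketched in IR with hypothetical names and a fully static shape (so no mask is needed and `in_bounds` is all-true):

```mlir
// Destination materialized at the call site from the reified return shape.
%dest = tensor.empty() : tensor<4x1x16x2xf32>
%c0 = arith.constant 0 : index
// The utility now only emits the write; %transposed is the vector produced
// by the preceding vector.transpose.
%res = vector.transfer_write %transposed, %dest[%c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true]}
    : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
```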
```diff
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       unpackOp.getDestType().hasStaticShape()
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-      useInBoundsInsteadOfMasking);
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0],
+      shapeCastOp.getResult().getType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
+                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+                               useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
```
```diff
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0],
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-      /*useInBoundsInsteadOfMasking=*/false);
+
+  // Create Xfer write Op
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
```
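Putting the pad path together, the vectorized IR now reads roughly as below. This is a sketch with hypothetical names (`%src`, `%pad_val`, `%rd0`, `%rd1`, the masks) and an illustrative 2x4 vector shape: the masked read produced by `createReadOrMaskedRead` feeds the write into the `tensor.empty` that `vectorizeAsTensorPadOp` now creates itself.

```mlir
%c0 = arith.constant 0 : index
// Masked read of the pad source, padded with the pad value.
%read = vector.mask %mask_r {
  vector.transfer_read %src[%c0, %c0], %pad_val
      : tensor<?x?xf32>, vector<2x4xf32>
} : vector<2x4xi1> -> vector<2x4xf32>
// Destination created at the call site, not inside the utility.
%dest = tensor.empty(%rd0, %rd1) : tensor<?x?xf32>
%res = vector.mask %mask_w {
  vector.transfer_write %read, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<2x4xf32>, tensor<?x?xf32>
} : vector<2x4xi1> -> tensor<?x?xf32>
```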

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 6 additions & 2 deletions
```diff
@@ -641,7 +641,9 @@ func.func @test_masked_vectorize_dynamic_pad(
 // CHECK-SAME:     } : vector<2x4xi1> -> vector<2x4xf32>
 // CHECK-DAG:      %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
 // CHECK-DAG:      %[[c0_3:.*]] = arith.constant 0 : index
-// CHECK:          %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+// CHECK-DAG:      %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+// CHECK-DAG:      %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+// CHECK:          %[[mask_2:.*]] = vector.create_mask %[[d2]], %[[d3]] : vector<2x4xi1>
 // CHECK:          %[[masked_write:.*]] = vector.mask %[[mask_2]] {
 // CHECK-SAME:       vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
 // CHECK-SAME:       {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
```
```diff
@@ -800,7 +802,9 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?
 // CHECK-DAG:  %[[c16:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[c2:.*]] = arith.constant 2 : index
 // CHECK-DAG:  %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
-// CHECK:      %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+// CHECK-DAG:  %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK-DAG:  %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK:      %[[mask_0:.*]] = vector.create_mask %[[d2]], %[[d3]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
 // CHECK:      %[[masked_write:.*]] = vector.mask %[[mask_0]] {
 // CHECK-SAME:   vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
```
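The extra `%[[d2]]`/`%[[d3]]` values checked above are transient. `tensor.dim` of a `tensor.empty` canonicalizes to the corresponding size operand, so after CSE + canonicalization the pre-cleanup pattern (hypothetical values, matching the dynamic-pad test above) is expected to fold:

```mlir
// Before cleanup: sizes re-queried from the freshly created empty tensor.
%empty = tensor.empty(%res_d0, %res_d1) : tensor<?x?xf32>
%d2 = tensor.dim %empty, %c0 : tensor<?x?xf32>
%d3 = tensor.dim %empty, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %d2, %d3 : vector<2x4xi1>
```

Here `%d2`/`%d3` fold to `%res_d0`/`%res_d1`, recovering the original `vector.create_mask %res_d0, %res_d1` form that the tests previously checked for.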
