14
14
#include " mlir/Dialect/Affine/IR/AffineOps.h"
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17
18
#include " mlir/Dialect/Tensor/IR/Tensor.h"
18
19
#include " mlir/Dialect/Utils/IndexingUtils.h"
19
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -104,10 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104
105
<< " \n " );
105
106
llvm::SmallVector<Operation *, 8 > blockingAccesses;
106
107
Operation *firstOverwriteCandidate = nullptr ;
107
- Value source = write .getSource ();
108
- // Skip subview ops.
109
- while (auto subView = source.getDefiningOp <memref::SubViewOp>())
110
- source = subView.getSource ();
108
+ Value source =
109
+ memref::skipSubViewsAndCasts (cast<MemrefValue>(write .getSource ()));
111
110
llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
112
111
source.getUsers ().end ());
113
112
llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -116,8 +115,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
116
115
// If the user has already been processed skip.
117
116
if (!processed.insert (user).second )
118
117
continue ;
119
- if (auto subView = dyn_cast <memref::SubViewOp>(user)) {
120
- users.append (subView ->getUsers ().begin (), subView ->getUsers ().end ());
118
+ if (isa <memref::SubViewOp, memref::CastOp >(user)) {
119
+ users.append (user ->getUsers ().begin (), user ->getUsers ().end ());
121
120
continue ;
122
121
}
123
122
if (isMemoryEffectFree (user))
@@ -126,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
126
125
continue ;
127
126
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128
127
// Check candidate that can override the store.
129
- if (write .getSource () == nextWrite.getSource () &&
128
+ if (memref::isSameViewOrTrivialAlias (
129
+ cast<MemrefValue>(nextWrite.getSource ()),
130
+ cast<MemrefValue>(write .getSource ())) &&
130
131
checkSameValueWAW (nextWrite, write ) &&
131
132
postDominators.postDominates (nextWrite, write )) {
132
133
if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
191
192
<< " \n " );
192
193
SmallVector<Operation *, 8 > blockingWrites;
193
194
vector::TransferWriteOp lastwrite = nullptr ;
194
- Value source = read .getSource ();
195
- // Skip subview ops.
196
- while (auto subView = source.getDefiningOp <memref::SubViewOp>())
197
- source = subView.getSource ();
195
+ Value source =
196
+ memref::skipSubViewsAndCasts (cast<MemrefValue>(read .getSource ()));
198
197
llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
199
198
source.getUsers ().end ());
200
199
llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -203,12 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
203
202
// If the user has already been processed skip.
204
203
if (!processed.insert (user).second )
205
204
continue ;
206
- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
207
- users.append (subView->getUsers ().begin (), subView->getUsers ().end ());
208
- continue ;
209
- }
210
- if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
211
- users.append (collapsed->getUsers ().begin (), collapsed->getUsers ().end ());
205
+ if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
206
+ users.append (user->getUsers ().begin (), user->getUsers ().end ());
212
207
continue ;
213
208
}
214
209
if (isMemoryEffectFree (user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
221
216
cast<VectorTransferOpInterface>(read .getOperation ()),
222
217
/* testDynamicValueUsingBounds=*/ true ))
223
218
continue ;
224
- if (write .getSource () == read .getSource () &&
219
+ if (memref::isSameViewOrTrivialAlias (
220
+ cast<MemrefValue>(read .getSource ()),
221
+ cast<MemrefValue>(write .getSource ())) &&
225
222
dominators.dominates (write , read ) && checkSameValueRAW (write , read )) {
226
223
if (lastwrite == nullptr || dominators.dominates (lastwrite, write ))
227
224
lastwrite = write ;
0 commit comments