@@ -606,7 +606,7 @@ struct MmaSyncBuilder {
606
606
/// IndexCalculator callback.
607
607
SmallVector<Value> buildMemRefLoads (OpBuilder &b, Location loc,
608
608
OpFoldResult laneId, Value memref,
609
- IndexCalculator indexFn);
609
+ const IndexCalculator & indexFn);
610
610
611
611
/// Perform a distributed load of a vector operand of `vectorShape` for a
612
612
/// particular MMA instruction whose `(row, col)` indices are specified via
@@ -625,7 +625,7 @@ struct MmaSyncBuilder {
625
625
SmallVector<Operation *> buildMemRefStores (OpBuilder &b, Location loc,
626
626
ValueRange toStore,
627
627
OpFoldResult laneId, Value memref,
628
- IndexCalculator indexFn);
628
+ const IndexCalculator & indexFn);
629
629
630
630
/// Perform a distributed store of a vector operand of `vectorShape` for a
631
631
/// particular MMA instruction whose `(row, col)` indices are specified via
@@ -660,10 +660,10 @@ static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
660
660
}
661
661
}
662
662
663
- SmallVector<Value> MmaSyncBuilder::buildMemRefLoads (OpBuilder &b, Location loc,
664
- OpFoldResult laneId ,
665
- Value memref,
666
- IndexCalculator indexFn) {
663
+ SmallVector<Value>
664
+ MmaSyncBuilder::buildMemRefLoads (OpBuilder &b, Location loc ,
665
+ OpFoldResult laneId, Value memref,
666
+ const IndexCalculator & indexFn) {
667
667
auto aff = [&](AffineExpr e) {
668
668
return affine::makeComposedFoldedAffineApply (b, loc, e, laneId);
669
669
};
@@ -681,7 +681,7 @@ SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
681
681
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand (
682
682
OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
683
683
IndexCalculator indexFn, ArrayRef<int64_t > vectorShape) {
684
- auto loads = buildMemRefLoads (b, loc, laneId, memref, indexFn);
684
+ auto loads = buildMemRefLoads (b, loc, laneId, memref, std::move ( indexFn) );
685
685
686
686
Type elementType = getElementTypeOrSelf (memref.getType ());
687
687
auto vt = VectorType::get (vectorShape, elementType);
@@ -700,10 +700,9 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
700
700
return res;
701
701
}
702
702
703
- SmallVector<Operation *>
704
- MmaSyncBuilder::buildMemRefStores (OpBuilder &b, Location loc,
705
- ValueRange toStore, OpFoldResult laneId,
706
- Value memref, IndexCalculator indexFn) {
703
+ SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores (
704
+ OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
705
+ Value memref, const IndexCalculator &indexFn) {
707
706
auto aff = [&](AffineExpr e) {
708
707
return affine::makeComposedFoldedAffineApply (b, loc, e, laneId);
709
708
};
@@ -734,7 +733,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
734
733
[&](Value v, int64_t linearIdx, ArrayRef<int64_t > indices) {
735
734
toStore.push_back (v);
736
735
});
737
- return buildMemRefStores (b, loc, toStore, laneId, memref, indexFn);
736
+ return buildMemRefStores (b, loc, toStore, laneId, memref, std::move ( indexFn) );
738
737
}
739
738
740
739
static std::tuple<SmallVector<int64_t >, SmallVector<int64_t >,
0 commit comments