Skip to content

Commit c194bc7

Browse files
authored
[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (#98620)
This adds a workaround rewrite that allows stores of unsupported SVE transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>, memref<?x?xf32>
```

to instead use SME tiles, which are possible to lower (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>

// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
  {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
  : vector<[4]x[4]xf32>, memref<?x?xf32>
```
1 parent 99bb9a7 commit c194bc7

File tree

5 files changed

+278
-12
lines changed

5 files changed

+278
-12
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def VectorLegalization
202202
"func::FuncDialect",
203203
"arm_sme::ArmSMEDialect",
204204
"vector::VectorDialect",
205-
"arith::ArithDialect"
205+
"arith::ArithDialect",
206+
"index::IndexDialect"
206207
];
207208
}
208209

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1010
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1111

12+
#include "mlir/Dialect/Arith/IR/Arith.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1415
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -101,6 +102,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
101102
std::optional<StaticTileOffsetRange>
102103
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
103104

105+
/// Returns a functor (int64_t -> Value) which returns a constant vscale
106+
/// multiple.
107+
///
108+
/// Example:
109+
/// ```c++
110+
/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
111+
/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
112+
/// ```
113+
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
114+
Value vscale = nullptr;
115+
return [loc, vscale, &rewriter](int64_t multiplier) mutable {
116+
if (!vscale)
117+
vscale = rewriter.create<vector::VectorScaleOp>(loc);
118+
return rewriter.create<arith::MulIOp>(
119+
loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
120+
};
121+
}
122+
104123
/// A wrapper for getMixedSizes for vector.transfer_read and
105124
/// vector.transfer_write Ops (for source and destination, respectively).
106125
///

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
1616
MLIRFuncDialect
1717
MLIRLLVMCommonConversion
1818
MLIRVectorDialect
19+
MLIRIndexDialect
1920
MLIRSCFDialect
2021
MLIRSCFTransforms
2122
MLIRFuncTransforms

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 154 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
21+
#include "mlir/Dialect/Index/IR/IndexDialect.h"
22+
#include "mlir/Dialect/Index/IR/IndexOps.h"
2123
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2224
#include "mlir/Dialect/SCF/IR/SCF.h"
2325
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2426
#include "mlir/Dialect/Utils/IndexingUtils.h"
27+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2528
#include "mlir/Transforms/OneToNTypeConversion.h"
2629

2730
#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -140,11 +143,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
140143
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
141144
VectorType smeTileType,
142145
bool transposeIndices = false) {
143-
assert(isMultipleOfSMETileVectorType(type) &&
144-
"`type` not multiple of SME tiles");
145146
return llvm::map_range(
146-
StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
147-
smeTileType.getDimSize(1)}),
147+
StaticTileOffsetRange(
148+
type.getShape(),
149+
{std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
150+
std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
148151
[=](auto indices) {
149152
int row = int(indices[0]);
150153
int col = int(indices[1]);
@@ -440,12 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
440443
kMatchFailureUnsupportedMaskOp);
441444

442445
auto loc = writeOp.getLoc();
443-
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
444-
auto createVscaleMultiple = [&](int64_t multiplier) {
445-
return rewriter.create<arith::MulIOp>(
446-
loc, vscale,
447-
rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
448-
};
446+
auto createVscaleMultiple =
447+
vector::makeVscaleConstantBuilder(rewriter, loc);
449448

450449
// Get SME tile and slice types.
451450
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +774,149 @@ struct ConvertIllegalShapeCastOpsToTransposes
775774
}
776775
};
777776

777+
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
778+
/// the ZA state. This is a workaround rewrite to support these transposes when
779+
/// ZA is available.
780+
///
781+
/// Example:
782+
///
783+
/// BEFORE:
784+
/// ```mlir
785+
/// %transpose = vector.transpose %vec, [1, 0]
786+
/// : vector<2x[4]xf32> to vector<[4]x2xf32>
787+
/// vector.transfer_write %transpose, %dest[%y, %x]
788+
/// : vector<[4]x2xf32>, memref<?x?xf32>
789+
/// ```
790+
///
791+
/// AFTER:
792+
/// ```mlir
793+
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
794+
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
795+
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
796+
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
797+
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
798+
/// %c4_vscale = arith.muli %vscale, %c4 : index
799+
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
800+
/// vector.transfer_write %4, %dest[%y, %x], %mask
801+
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
802+
/// : vector<[4]x[4]xf32>, memref<?x?xf32>
803+
/// ```
804+
///
805+
/// Values larger than a single tile are supported via decomposition.
806+
struct LowerIllegalTransposeStoreViaZA
807+
: public OpRewritePattern<vector::TransferWriteOp> {
808+
using OpRewritePattern::OpRewritePattern;
809+
810+
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
811+
PatternRewriter &rewriter) const override {
812+
if (!isSupportedMaskOp(writeOp.getMask()))
813+
return rewriter.notifyMatchFailure(writeOp,
814+
kMatchFailureUnsupportedMaskOp);
815+
816+
auto permutationMap = writeOp.getPermutationMap();
817+
if (!permutationMap.isIdentity())
818+
return rewriter.notifyMatchFailure(writeOp,
819+
kMatchFailureNonPermutationMap);
820+
821+
auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
822+
if (!transposeOp)
823+
return failure();
824+
825+
auto sourceType = transposeOp.getSourceVectorType();
826+
auto resultType = transposeOp.getResultVectorType();
827+
828+
if (resultType.getRank() != 2)
829+
return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
830+
831+
if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
832+
return rewriter.notifyMatchFailure(
833+
transposeOp, "not illegal/unsupported SVE transpose");
834+
835+
auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
836+
VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
837+
838+
if (sourceType.getDimSize(0) <= 1 ||
839+
sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
840+
return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
841+
842+
auto loc = writeOp.getLoc();
843+
auto createVscaleMultiple =
844+
vector::makeVscaleConstantBuilder(rewriter, loc);
845+
846+
auto transposeMap = AffineMapAttr::get(
847+
AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
848+
849+
// Note: We need to use `get_tile` as there's no vector-level `undef`.
850+
Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
851+
Value destTensorOrMemref = writeOp.getSource();
852+
auto numSlicesPerTile =
853+
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
854+
auto numSlices =
855+
rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
856+
for (auto [index, smeTile] : llvm::enumerate(
857+
decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
858+
// 1. _Deliberately_ drop a scalable dimension and insert a fixed number
859+
// of slices from the source type into the SME tile. Without checking
860+
// vscale (and emitting multiple implementations) we can't make use of the
861+
// rows of the tile after 1*vscale rows.
862+
Value tile = undefTile;
863+
for (int d = 0; d < numSlicesPerTile; ++d) {
864+
Value vector = rewriter.create<vector::ExtractOp>(
865+
loc, transposeOp.getVector(),
866+
rewriter.getIndexAttr(d + smeTile.row));
867+
if (vector.getType() != smeSliceType) {
868+
vector = rewriter.create<vector::ScalableExtractOp>(
869+
loc, smeSliceType, vector, smeTile.col);
870+
}
871+
tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
872+
}
873+
874+
// 2. Transpose the tile position.
875+
auto transposedRow = createVscaleMultiple(smeTile.col);
876+
auto transposedCol =
877+
rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
878+
879+
// 3. Compute mask for tile store.
880+
Value maskRows;
881+
Value maskCols;
882+
if (auto mask = writeOp.getMask()) {
883+
auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
884+
maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
885+
transposedRow);
886+
maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
887+
transposedCol);
888+
maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
889+
} else {
890+
maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
891+
maskCols = numSlices;
892+
}
893+
auto subMask = rewriter.create<vector::CreateMaskOp>(
894+
loc, smeTileType.clone(rewriter.getI1Type()),
895+
ValueRange{maskRows, maskCols});
896+
897+
// 4. Emit a transposed tile write.
898+
auto writeIndices = writeOp.getIndices();
899+
Value destRow =
900+
rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
901+
Value destCol =
902+
rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
903+
auto smeWrite = rewriter.create<vector::TransferWriteOp>(
904+
loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
905+
transposeMap, subMask, writeOp.getInBounds());
906+
907+
if (writeOp.hasPureTensorSemantics())
908+
destTensorOrMemref = smeWrite.getResult();
909+
}
910+
911+
if (writeOp.hasPureTensorSemantics())
912+
rewriter.replaceOp(writeOp, destTensorOrMemref);
913+
else
914+
rewriter.eraseOp(writeOp);
915+
916+
return success();
917+
}
918+
};
919+
778920
struct VectorLegalizationPass
779921
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
780922
void runOnOperation() override {
@@ -796,7 +938,8 @@ struct VectorLegalizationPass
796938

797939
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
798940
LiftIllegalVectorTransposeToMemory,
799-
ConvertIllegalShapeCastOpsToTransposes>(context);
941+
ConvertIllegalShapeCastOpsToTransposes,
942+
LowerIllegalTransposeStoreViaZA>(context);
800943
// Note: These two patterns are added with a high benefit to ensure:
801944
// - Masked outer products are handled before unmasked ones
802945
// - Multi-tile writes are lowered as a store loop (if possible)

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,3 +544,105 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
544544
%0 = arith.constant dense<42> : vector<[8]x[8]xi32>
545545
return %0 : vector<[8]x[8]xi32>
546546
}
547+
548+
// -----
549+
550+
// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
551+
552+
// CHECK-LABEL: @transpose_store_scalable_via_za(
553+
// CHECK-SAME: %[[VEC:.*]]: vector<2x[4]xf32>
554+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
555+
// CHECK-SAME: %[[I:.*]]: index,
556+
// CHECK-SAME: %[[J:.*]]: index)
557+
func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
558+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
559+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
560+
// CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
561+
// CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
562+
// CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
563+
// CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
564+
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
565+
// CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale
566+
// CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
567+
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
568+
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
569+
%tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
570+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
571+
return
572+
}
573+
574+
// -----
575+
576+
// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
577+
// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
578+
// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
579+
func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
580+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
581+
// CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
582+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
583+
// CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
584+
%c0 = arith.constant 0 : index
585+
%mask = vector.create_mask %a, %b : vector<[4]x2xi1>
586+
%tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
587+
vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
588+
return
589+
}
590+
591+
// -----
592+
593+
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
594+
// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
595+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
596+
// CHECK-SAME: %[[I:.*]]: index,
597+
// CHECK-SAME: %[[J:.*]]: index)
598+
func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
599+
// CHECK: %[[C4:.*]] = arith.constant 4 : index
600+
601+
// <skip 3x other extract+insert chain>
602+
// CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
603+
// CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
604+
// CHECK: %[[VSCALE:.*]] = vector.vscale
605+
// CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
606+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C4]] : vector<[4]x[4]xi1>
607+
// CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
608+
609+
// <skip 3x other extract+insert chain>
610+
// CHECK: %[[V7:.*]] = vector.extract %[[VEC]][7] : vector<[4]xf32> from vector<8x[4]xf32>
611+
// CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
612+
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
613+
// CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
614+
%tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
615+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32>
616+
return
617+
}
618+
619+
// -----
620+
621+
// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_wide
622+
func.func @transpose_store_scalable_via_za_multi_tile_wide(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
623+
// <check extracts from lower 4 x vscale of %vec>
624+
// CHECK: vector.scalable.extract
625+
// CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>
626+
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
627+
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:.[a-z0-9]+]], %[[J:[a-z0-9]+]]]
628+
629+
// <check extracts from upper 4 x vscale of %vec>
630+
// CHECK: vector.scalable.extract
631+
// CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32>
632+
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
633+
// CHECK: %[[I_OFFSET:.*]] = arith.addi %c4_vscale, %[[I]] : index
634+
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I_OFFSET]], %[[J]]]
635+
%tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32>
636+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>, memref<?x?xf32>
637+
return
638+
}
639+
640+
// -----
641+
642+
// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape
643+
// CHECK-NOT: arm_sme.get_tile
644+
func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
645+
%tr = vector.transpose %vec, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32>
646+
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
647+
return
648+
}

0 commit comments

Comments
 (0)