@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
88
88
// Inspect constant mask index. If the index exceeds the
89
89
// dimension size, all bits are set. If the index is zero
90
90
// or less, no bits are set.
91
- ArrayAttr masks = m.getMaskDimSizes ();
91
+ ArrayRef< int64_t > masks = m.getMaskDimSizes ();
92
92
auto shape = m.getType ().getShape ();
93
93
bool allTrue = true ;
94
94
bool allFalse = true ;
95
95
for (auto [maskIdx, dimSize] : llvm::zip_equal (masks, shape)) {
96
- int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt ();
97
- if (i < dimSize)
96
+ if (maskIdx < dimSize)
98
97
allTrue = false ;
99
- if (i > 0 )
98
+ if (maskIdx > 0 )
100
99
allFalse = false ;
101
100
}
102
101
if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
3593
3592
if (extractStridedSliceOp.hasNonUnitStrides ())
3594
3593
return failure ();
3595
3594
// Gather constant mask dimension sizes.
3596
- SmallVector<int64_t , 4 > maskDimSizes;
3597
- populateFromInt64AttrArray (constantMaskOp.getMaskDimSizes (), maskDimSizes);
3595
+ ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
3598
3596
// Gather strided slice offsets and sizes.
3599
3597
SmallVector<int64_t , 4 > sliceOffsets;
3600
3598
populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
3625
3623
// region.
3626
3624
rewriter.replaceOpWithNewOp <ConstantMaskOp>(
3627
3625
extractStridedSliceOp, extractStridedSliceOp.getResult ().getType (),
3628
- vector::getVectorSubscriptAttr (rewriter, sliceMaskDimSizes) );
3626
+ sliceMaskDimSizes);
3629
3627
return success ();
3630
3628
}
3631
3629
};
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
5410
5408
}
5411
5409
5412
5410
if (constantMaskOp) {
5413
- auto maskDimSizes = constantMaskOp.getMaskDimSizes (). getValue () ;
5411
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes ();
5414
5412
auto numMaskOperands = maskDimSizes.size ();
5415
5413
5416
5414
// Check every mask dim size to see whether it can be dropped
5417
5415
for (size_t i = numMaskOperands - 1 ; i >= numMaskOperands - numDimsToDrop;
5418
5416
--i) {
5419
- if (cast<IntegerAttr>( maskDimSizes[i]). getValue () != 1 )
5417
+ if (maskDimSizes[i] != 1 )
5420
5418
return failure ();
5421
5419
}
5422
5420
5423
5421
auto newMaskOperands = maskDimSizes.drop_back (numDimsToDrop);
5424
- ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr (newMaskOperands);
5425
-
5426
5422
rewriter.replaceOpWithNewOp <vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5427
- newMaskOperandsAttr );
5423
+ newMaskOperands );
5428
5424
return success ();
5429
5425
}
5430
5426
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5804
5800
5805
5801
// ConstantMaskOp case.
5806
5802
auto maskDimSizes = constantMaskOp.getMaskDimSizes ();
5807
- SmallVector<Attribute> newMaskDimSizes (maskDimSizes.getValue ());
5808
- applyPermutationToVector (newMaskDimSizes, permutation);
5803
+ auto newMaskDimSizes = applyPermutation (maskDimSizes, permutation);
5809
5804
5810
5805
rewriter.replaceOpWithNewOp <vector::ConstantMaskOp>(
5811
- transpOp, transpOp.getResultVectorType (),
5812
- ArrayAttr::get (transpOp.getContext (), newMaskDimSizes));
5806
+ transpOp, transpOp.getResultVectorType (), newMaskDimSizes);
5813
5807
return success ();
5814
5808
}
5815
5809
};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
5832
5826
if (resultType.getRank () == 0 ) {
5833
5827
if (getMaskDimSizes ().size () != 1 )
5834
5828
return emitError (" array attr must have length 1 for 0-D vectors" );
5835
- auto dim = llvm::cast<IntegerAttr>( getMaskDimSizes ()[0 ]). getInt () ;
5829
+ auto dim = getMaskDimSizes ()[0 ];
5836
5830
if (dim != 0 && dim != 1 )
5837
5831
return emitError (" mask dim size must be either 0 or 1 for 0-D vectors" );
5838
5832
return success ();
@@ -5846,17 +5840,15 @@ LogicalResult ConstantMaskOp::verify() {
5846
5840
// result dimension size.
5847
5841
auto resultShape = resultType.getShape ();
5848
5842
auto resultScalableDims = resultType.getScalableDims ();
5849
- SmallVector<int64_t , 4 > maskDimSizes;
5850
- for (const auto [index , intAttr] : llvm::enumerate (getMaskDimSizes ())) {
5851
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt ();
5843
+ ArrayRef<int64_t > maskDimSizes = getMaskDimSizes ();
5844
+ for (const auto [index , maskDimSize] : llvm::enumerate (maskDimSizes)) {
5852
5845
if (maskDimSize < 0 || maskDimSize > resultShape[index ])
5853
5846
return emitOpError (
5854
5847
" array attr of size out of bounds of vector result dimension size" );
5855
5848
if (resultScalableDims[index ] && maskDimSize != 0 &&
5856
5849
maskDimSize != resultShape[index ])
5857
5850
return emitOpError (
5858
5851
" only supports 'none set' or 'all set' scalable dimensions" );
5859
- maskDimSizes.push_back (maskDimSize);
5860
5852
}
5861
5853
// Verify that if one mask dim size is zero, they all should be zero (because
5862
5854
// the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
5873
5865
// Check the corner case of 0-D vectors first.
5874
5866
if (resultType.getRank () == 0 ) {
5875
5867
assert (getMaskDimSizes ().size () == 1 && " invalid sizes for zero rank mask" );
5876
- return llvm::cast<IntegerAttr>( getMaskDimSizes ()[0 ]). getInt () == 1 ;
5868
+ return getMaskDimSizes ()[0 ] == 1 ;
5877
5869
}
5878
- for (const auto [resultSize, intAttr ] :
5870
+ for (const auto [resultSize, maskDimSize ] :
5879
5871
llvm::zip_equal (resultType.getShape (), getMaskDimSizes ())) {
5880
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt ();
5881
5872
if (maskDimSize < resultSize)
5882
5873
return false ;
5883
5874
}
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6007
5998
}
6008
5999
6009
6000
// Replace 'createMaskOp' with ConstantMaskOp.
6010
- rewriter.replaceOpWithNewOp <ConstantMaskOp>(
6011
- createMaskOp, retTy,
6012
- vector::getVectorSubscriptAttr (rewriter, maskDimSizes));
6001
+ rewriter.replaceOpWithNewOp <ConstantMaskOp>(createMaskOp, retTy,
6002
+ maskDimSizes);
6013
6003
return success ();
6014
6004
}
6015
6005
};
0 commit comments