@@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
37
37
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
38
38
::llvm::cast<VectorType>($_self).isScalable()}]>;
39
39
40
+ // Whether a type is a scalable VectorType, with a single trailing scalable dimension.
41
+ // Examples:
42
+ // Valid:
43
+ // - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
44
+ // Invalid
45
+ // - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
46
+ def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
47
+ CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
48
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
49
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
50
+ CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
51
+ ]>;
52
+
40
53
// Whether a type is a VectorType and all dimensions are scalable.
41
54
def allDimsScalableVectorTypePred : And<[
42
55
IsVectorTypePred,
@@ -404,6 +417,15 @@ class ScalableVectorOf<list<Type> allowedTypes> :
404
417
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
405
418
"scalable vector", "::mlir::VectorType">;
406
419
420
+ // Any vector with a single trailing scalable dimension, with an element type in
421
+ // the `allowedTypes` list.
422
+ //
423
+ // Note: This Similar to ScalableVectorOf, with the extra requirement that only
424
+ // the trailing dim is scalable.
425
+ class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
426
+ ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred,
427
+ "trailing scalable vector", "::mlir::VectorType">;
428
+
407
429
// Whether the number of elements of a vector is from the given
408
430
// `allowedRanks` list
409
431
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
481
503
== }]
482
504
# allowedlength>)>]>;
483
505
506
+ // Normalizes an index so the indices in both directions have the same value.
507
+ // For example, when indexing forwards index 2 is the third element. When
508
+ // indexing in reverse the third element is -3. This helper would map both of
509
+ // these to the "normalized" index of 3. This makes the bounds checking in
510
+ // IsNthDimSizeIsOneOfPred simpler (see first CPred).
511
+ class NormalizeIndex<int value> {
512
+ int ret = !if(!lt(value, 0),
513
+ !sub(0, value) /* -value if negative */,
514
+ !add(value, 1) /* value + 1 if positive*/);
515
+ }
516
+
517
+ // Whether the n-th dim of the shape is contained within `allowedSizes`.
518
+ // Negative values for `n` index in reverse.
519
+ //
520
+ // Examples:
521
+ // IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
522
+ // - Accepts any shape where the first dim is 2, 3, or 4.
523
+ // * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
524
+ // IsNthDimSizeIsOneOfPred<-1, {16}>
525
+ // - Accepts any shape where the last dim is 16.
526
+ // * This means shapes like 2x16, 16, 1x2x3x4x16, etc
527
+ // IsNthDimSizeIsOneOfPred<-2, {10, 5}>
528
+ // - Accepts any shape where the second to last dim is 10 or 5.
529
+ // * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
530
+ class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
531
+ : And<[
532
+ CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
533
+ CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
534
+ # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
535
+ # !if(!lt(n, 0),
536
+ "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
537
+ "" # n)
538
+ # "))">]>;
539
+
484
540
// Whether the shape of a vector matches the given `shape` list.
485
541
class IsVectorOfShape<list<int> shape>
486
542
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
@@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
546
602
ScalableVectorOfLength<allowedLengths>.summary,
547
603
"::mlir::VectorType">;
548
604
605
+ // Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
606
+ // Negative values for `n` index in reverse.
607
+ class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
608
+ IsNthDimSizeIsOneOfPred<n, allowedSizes>,
609
+ " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
610
+ "::mlir::ShapedType">;
611
+
612
+ // Any scalable vector with a single trailing scalable dimensions, where the
613
+ // size of the trailing dimension is in `allowedTrailingSizes` list, and the
614
+ // type is in the `allowedTypes` list.
615
+ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
616
+ list<Type> allowedTypes> : AllOfType<
617
+ [VectorWithTrailingDimScalableOf<allowedTypes>,
618
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
619
+ VectorWithTrailingDimScalableOf<allowedTypes>.summary #
620
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
621
+ "::mlir::VectorType">;
622
+
549
623
def AnyVector : VectorOf<[AnyType]>;
550
624
// Temporary vector type clone that allows gradual transition to 0-D vectors.
551
625
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
0 commit comments