Skip to content

Commit 3ce8095

Browse files
committed
[mlir][VectorOps] Add ShapeCastOp to the vector ops dialect.
Summary: Add ShapeCastOp to the vector ops dialect. The shape_cast operation
casts between an n-D source vector shape and a k-D result vector shape (the
element type remains the same).

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache

Subscribers: Joonsoo, merge_guards_bot, mehdi_amini, rriddle, jpienaar,
burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb,
llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73635
1 parent 801857c commit 3ce8095

File tree

4 files changed

+225
-0
lines changed

4 files changed

+225
-0
lines changed

mlir/include/mlir/Dialect/VectorOps/VectorOps.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,49 @@ def Vector_TransferWriteOp :
963963
}];
964964
}
965965

966+
def Vector_ShapeCastOp :
967+
Vector_Op<"shape_cast", [NoSideEffect]>,
968+
Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
969+
Results<(outs AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$result)> {
970+
let summary = "shape_cast casts between vector shapes";
971+
let description = [{
972+
The shape_cast operation casts between an n-D source vector shape and
973+
a k-D result vector shape (the element type remains the same).
974+
975+
If reducing rank (n > k), result dimension sizes must be a product
976+
of contiguous source dimension sizes.
977+
If expanding rank (n < k), source dimensions must factor into a
978+
contiguous sequence of destination dimension sizes.
979+
Each source dim is expanded (or contiguous sequence of source dims combined)
980+
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
981+
sequence of result dims (or a single result dim), in result dimension list
982+
order (i.e. 0 <= j < k). The product of all source dimension sizes and all
983+
result dimension sizes must match.
984+
985+
If the source/result types are a tuple of vectors, the casting operation
986+
described above is applied to each source/result tuple element pair.
987+
988+
It is currently assumed that this operation does not require moving data,
989+
and that it will be canonicalized away before lowering vector operations.
990+
991+
Examples:
992+
993+
```mlir
994+
// Example casting to a lower vector rank.
995+
%1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
996+
997+
// Example casting to a higher vector rank.
998+
%3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
999+
1000+
// Example casting a tuple of vectors of same rank, where tuple elements
1001+
// may have different shapes.
1002+
%5 = vector.shape_cast %4 : tuple<vector<3x4x2xf32>, vector<3x3x2xf32>> to
1003+
tuple<vector<12x2xf32>, vector<9x2xf32>>
1004+
```
1005+
}];
1006+
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
1007+
}
1008+
9661009
def Vector_TypeCastOp :
9671010
Vector_Op<"type_cast", [NoSideEffect]>,
9681011
Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,

mlir/lib/Dialect/VectorOps/VectorOps.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Support/MathExtras.h"
2727
#include "mlir/Support/STLExtras.h"
2828
#include "llvm/ADT/StringSet.h"
29+
#include <numeric>
2930

3031
using namespace mlir;
3132
using namespace mlir::vector;
@@ -1389,6 +1390,90 @@ static LogicalResult verify(TransferWriteOp op) {
13891390
[&op](Twine t) { return op.emitOpError(t); });
13901391
}
13911392

1393+
//===----------------------------------------------------------------------===//
1394+
// ShapeCastOp
1395+
//===----------------------------------------------------------------------===//
1396+
1397+
/// Returns true if each element of 'a' is equal to the product of a contiguous
1398+
/// sequence of the elements of 'b'. Returns false otherwise.
1399+
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1400+
unsigned rankA = a.size();
1401+
unsigned rankB = b.size();
1402+
assert(rankA < rankB);
1403+
1404+
unsigned i = 0;
1405+
unsigned j = 0;
1406+
while (i < rankA && j < rankB) {
1407+
int64_t dimA = a[i];
1408+
int64_t dimB = 1;
1409+
while (dimB < dimA && j < rankB)
1410+
dimB *= b[j++];
1411+
if (dimA != dimB)
1412+
break;
1413+
++i;
1414+
}
1415+
1416+
return i == rankA && j == rankB;
1417+
}
1418+
1419+
static LogicalResult verifyVectorShapeCast(Operation *op,
1420+
VectorType sourceVectorType,
1421+
VectorType resultVectorType) {
1422+
// Check that element type is the same.
1423+
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
1424+
return op->emitOpError("source/result vectors must have same element type");
1425+
auto sourceShape = sourceVectorType.getShape();
1426+
auto resultShape = resultVectorType.getShape();
1427+
1428+
// Check that product of source dim sizes matches product of result dim sizes.
1429+
int64_t sourceDimProduct = std::accumulate(
1430+
sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
1431+
int64_t resultDimProduct = std::accumulate(
1432+
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
1433+
if (sourceDimProduct != resultDimProduct)
1434+
return op->emitOpError("source/result number of elements must match");
1435+
1436+
// Check that expanding/contracting rank cases.
1437+
unsigned sourceRank = sourceVectorType.getRank();
1438+
unsigned resultRank = resultVectorType.getRank();
1439+
if (sourceRank < resultRank) {
1440+
if (!isValidShapeCast(sourceShape, resultShape))
1441+
return op->emitOpError("invalid shape cast");
1442+
} else if (sourceRank > resultRank) {
1443+
if (!isValidShapeCast(resultShape, sourceShape))
1444+
return op->emitOpError("invalid shape cast");
1445+
}
1446+
return success();
1447+
}
1448+
1449+
static LogicalResult verify(ShapeCastOp op) {
1450+
auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
1451+
auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
1452+
1453+
// Check if source/result are of vector type.
1454+
if (sourceVectorType && resultVectorType)
1455+
return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
1456+
1457+
// Check if source/result are "tuple of vectors" type.
1458+
auto sourceTupleType = op.source().getType().dyn_cast_or_null<TupleType>();
1459+
auto resultTupleType = op.result().getType().dyn_cast_or_null<TupleType>();
1460+
if (!sourceTupleType || !resultTupleType)
1461+
return op.emitOpError("source/result must be of same type");
1462+
1463+
// Check that source/result tuple sizes are the same.
1464+
if (sourceTupleType.size() != resultTupleType.size())
1465+
return op.emitOpError("source/result tuples must be the same size");
1466+
1467+
// Check each source/result tuple element pair.
1468+
for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i)
1469+
if (failed(verifyVectorShapeCast(
1470+
op, sourceTupleType.getType(i).cast<VectorType>(),
1471+
resultTupleType.getType(i).cast<VectorType>())))
1472+
return failure();
1473+
1474+
return success();
1475+
}
1476+
13921477
//===----------------------------------------------------------------------===//
13931478
// TypeCastOp
13941479
//===----------------------------------------------------------------------===//

mlir/test/Dialect/VectorOps/invalid.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,85 @@ func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
889889
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
890890
: vector<3x2x4xf32> to vector<2x3x5xf32>
891891
}
892+
893+
// -----
894+
895+
func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
896+
// expected-error@+1 {{op source/result vectors must have same element type}}
897+
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
898+
}
899+
900+
// -----
901+
902+
func @shape_cast_wrong_element_type_tuple(%arg0 : tuple<vector<5x4x2xf32>,
903+
vector<3x4x2xf32>>) {
904+
// expected-error@+1 {{op source/result vectors must have same element type}}
905+
%0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
906+
tuple<vector<20x2xi32>, vector<12x2xi32>>
907+
}
908+
909+
// -----
910+
911+
func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
912+
// expected-error@+1 {{op source/result number of elements must match}}
913+
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
914+
}
915+
916+
// -----
917+
918+
func @shape_cast_wrong_num_elements_tuple(%arg0 : tuple<vector<5x4x2xf32>,
919+
vector<3x4x2xf32>>) {
920+
// expected-error@+1 {{op source/result number of elements must match}}
921+
%0 = vector.shape_cast %arg0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
922+
tuple<vector<21x2xf32>, vector<13x2xf32>>
923+
}
924+
925+
// -----
926+
927+
func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
928+
// expected-error@+1 {{invalid shape cast}}
929+
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
930+
}
931+
932+
// -----
933+
934+
func @shape_cast_invalid_rank_reduction_tuple(%arg0
935+
: tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
936+
// expected-error@+1 {{invalid shape cast}}
937+
%0 = vector.shape_cast %arg0: tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
938+
tuple<vector<10x4xf32>, vector<6x4xf32>>
939+
}
940+
941+
// -----
942+
943+
func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
944+
// expected-error@+1 {{invalid shape cast}}
945+
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
946+
}
947+
948+
// -----
949+
950+
func @shape_cast_invalid_rank_expansion_tuple(%arg0 : tuple<vector<20x2xf32>,
951+
vector<12x2xf32>>) {
952+
// expected-error@+1 {{invalid shape cast}}
953+
%0 = vector.shape_cast %arg0 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
954+
tuple<vector<5x2x4xf32>, vector<4x3x2xf32>>
955+
}
956+
957+
// -----
958+
959+
func @shape_cast_source_result_different_types(
960+
%arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>>) {
961+
// expected-error@+1 {{source/result must be of same type}}
962+
%1 = vector.shape_cast %arg1 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
963+
vector<5x2x4xf32>
964+
}
965+
966+
// -----
967+
968+
func @shape_cast_different_tuple_sizes(
969+
%arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>) {
970+
// expected-error@+1 {{op source/result tuples must be the same size}}
971+
%1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
972+
tuple<vector<20x2xf32>>
973+
}

mlir/test/Dialect/VectorOps/ops.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,18 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
233233

234234
return %1 : vector<2x3x4xf32>
235235
}
236+
237+
// CHECK-LABEL: shape_cast
238+
func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
239+
%arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
240+
-> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {
241+
242+
// CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32>
243+
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
244+
245+
// CHECK-NEXT: vector.shape_cast %{{.*}} : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to tuple<vector<20x2xf32>, vector<12x2xf32>>
246+
%1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
247+
tuple<vector<20x2xf32>, vector<12x2xf32>>
248+
249+
return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
250+
}

0 commit comments

Comments
 (0)