@@ -5546,124 +5546,56 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5546
5546
setResultRanges (getResult (), argRanges.front ());
5547
5547
}
5548
5548
5549
- // / Returns true if each element of 'a' is equal to the product of a contiguous
5550
- // / sequence of the elements of 'b'. Returns false otherwise.
5551
- static bool isValidShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
5552
- unsigned rankA = a.size ();
5553
- unsigned rankB = b.size ();
5554
- assert (rankA < rankB);
5555
-
5556
- auto isOne = [](int64_t v) { return v == 1 ; };
5557
-
5558
- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5559
- // casted to a 0-d vector.
5560
- if (rankA == 0 && llvm::all_of (b, isOne))
5561
- return true ;
5549
+ LogicalResult ShapeCastOp::verify () {
5562
5550
5563
- unsigned i = 0 ;
5564
- unsigned j = 0 ;
5565
- while (i < rankA && j < rankB) {
5566
- int64_t dimA = a[i];
5567
- int64_t dimB = 1 ;
5568
- while (dimB < dimA && j < rankB)
5569
- dimB *= b[j++];
5570
- if (dimA != dimB)
5571
- break ;
5572
- ++i;
5551
+ VectorType sourceType = getSourceVectorType ();
5552
+ VectorType resultType = getResultVectorType ();
5573
5553
5574
- // Handle the case when trailing dimensions are of size 1.
5575
- // Include them into the contiguous sequence.
5576
- if (i < rankA && llvm::all_of (a.slice (i), isOne))
5577
- i = rankA;
5578
- if (j < rankB && llvm::all_of (b.slice (j), isOne))
5579
- j = rankB;
5580
- }
5554
+ // Check that element type is preserved
5555
+ if (sourceType.getElementType () != resultType.getElementType ())
5556
+ return emitOpError (" has different source and result element types" );
5581
5557
5582
- return i == rankA && j == rankB;
5583
- }
5584
-
5585
- static LogicalResult verifyVectorShapeCast (Operation *op,
5586
- VectorType sourceVectorType,
5587
- VectorType resultVectorType) {
5588
- // Check that element type is the same.
5589
- if (sourceVectorType.getElementType () != resultVectorType.getElementType ())
5590
- return op->emitOpError (" source/result vectors must have same element type" );
5591
- auto sourceShape = sourceVectorType.getShape ();
5592
- auto resultShape = resultVectorType.getShape ();
5593
-
5594
- // Check that product of source dim sizes matches product of result dim sizes.
5595
- int64_t sourceDimProduct = std::accumulate (
5596
- sourceShape.begin (), sourceShape.end (), 1LL , std::multiplies<int64_t >{});
5597
- int64_t resultDimProduct = std::accumulate (
5598
- resultShape.begin (), resultShape.end (), 1LL , std::multiplies<int64_t >{});
5599
- if (sourceDimProduct != resultDimProduct)
5600
- return op->emitOpError (" source/result number of elements must match" );
5601
-
5602
- // Check that expanding/contracting rank cases.
5603
- unsigned sourceRank = sourceVectorType.getRank ();
5604
- unsigned resultRank = resultVectorType.getRank ();
5605
- if (sourceRank < resultRank) {
5606
- if (!isValidShapeCast (sourceShape, resultShape))
5607
- return op->emitOpError (" invalid shape cast" );
5608
- } else if (sourceRank > resultRank) {
5609
- if (!isValidShapeCast (resultShape, sourceShape))
5610
- return op->emitOpError (" invalid shape cast" );
5558
+ // Check that number of elements is preserved
5559
+ int64_t sourceNElms = sourceType.getNumElements ();
5560
+ int64_t resultNElms = resultType.getNumElements ();
5561
+ if (sourceNElms != resultNElms) {
5562
+ return emitOpError () << " has different number of elements at source ("
5563
+ << sourceNElms << " ) and result (" << resultNElms
5564
+ << " )" ;
5611
5565
}
5612
5566
5613
5567
// Check that (non-)scalability is preserved
5614
- int64_t sourceNScalableDims = sourceVectorType .getNumScalableDims ();
5615
- int64_t resultNScalableDims = resultVectorType .getNumScalableDims ();
5568
+ int64_t sourceNScalableDims = sourceType .getNumScalableDims ();
5569
+ int64_t resultNScalableDims = resultType .getNumScalableDims ();
5616
5570
if (sourceNScalableDims != resultNScalableDims)
5617
- return op->emitOpError (" different number of scalable dims at source (" )
5618
- << sourceNScalableDims << " ) and result (" << resultNScalableDims
5619
- << " )" ;
5620
- sourceVectorType.getNumDynamicDims ();
5621
-
5622
- return success ();
5623
- }
5624
-
5625
- LogicalResult ShapeCastOp::verify () {
5626
- auto sourceVectorType =
5627
- llvm::dyn_cast_or_null<VectorType>(getSource ().getType ());
5628
- auto resultVectorType =
5629
- llvm::dyn_cast_or_null<VectorType>(getResult ().getType ());
5630
-
5631
- // Check if source/result are of vector type.
5632
- if (sourceVectorType && resultVectorType)
5633
- return verifyVectorShapeCast (*this , sourceVectorType, resultVectorType);
5571
+ return emitOpError () << " has different number of scalable dims at source ("
5572
+ << sourceNScalableDims << " ) and result ("
5573
+ << resultNScalableDims << " )" ;
5634
5574
5635
5575
return success ();
5636
5576
}
5637
5577
5638
5578
OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5639
5579
5580
+ VectorType resultType = getType ();
5581
+
5640
5582
// No-op shape cast.
5641
- if (getSource ().getType () == getType () )
5583
+ if (getSource ().getType () == resultType )
5642
5584
return getSource ();
5643
5585
5644
- VectorType resultType = getType ();
5645
-
5646
- // Canceling shape casts .
5586
+ // Y = shape_cast(shape_cast(X)))
5587
+ // -> X, if X and Y have same type
5588
+ // -> shape_cast(X) otherwise .
5647
5589
if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5648
-
5649
- // Only allows valid transitive folding (expand/collapse dimensions).
5650
5590
VectorType srcType = otherOp.getSource ().getType ();
5651
5591
if (resultType == srcType)
5652
5592
return otherOp.getSource ();
5653
- if (srcType.getRank () < resultType.getRank ()) {
5654
- if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5655
- return {};
5656
- } else if (srcType.getRank () > resultType.getRank ()) {
5657
- if (!isValidShapeCast (resultType.getShape (), srcType.getShape ()))
5658
- return {};
5659
- } else {
5660
- return {};
5661
- }
5662
5593
setOperand (otherOp.getSource ());
5663
5594
return getResult ();
5664
5595
}
5665
5596
5666
- // Cancelling broadcast and shape cast ops.
5597
+ // Y = shape_cast(broadcast(X))
5598
+ // -> X, if X and Y have same type
5667
5599
if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5668
5600
if (bcastOp.getSourceType () == resultType)
5669
5601
return bcastOp.getSource ();
0 commit comments