@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
356
356
357
357
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack (RewriterBase &rewriter,
358
358
tensor::UnPackOp unPackOp) {
359
- // 1. Filter out NYI cases.
360
- if (!unPackOp.getOuterDimsPerm ().empty () &&
361
- !isIdentityPermutation (unPackOp.getOuterDimsPerm ())) {
362
- return rewriter.notifyMatchFailure (unPackOp,
363
- " non-identity outer dims perm NYI" );
364
- }
365
-
366
359
Location loc = unPackOp->getLoc ();
367
360
OpBuilder::InsertionGuard g (rewriter);
368
361
rewriter.setInsertionPoint (unPackOp);
@@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
391
384
return LowerUnPackOpResult{/* emptyOp=*/ nullptr , /* transposeOp=*/ nullptr ,
392
385
/* reshapeOp=*/ nullptr , extractSliceOp};
393
386
}
394
- // 2. Compute the permutation vector to move the last `numPackedDims` into
395
- // the `innerPosDims` of a shape of rank `packedRank`.
396
- int64_t numPackedDims = unPackOp.getInnerDimsPos ().size ();
397
- auto lastDims = llvm::to_vector (
398
- llvm::seq<int64_t >(packedRank - numPackedDims, packedRank));
399
- PackingMetadata packingMetadata =
400
- computePackingMetadata (packedRank, unPackOp.getInnerDimsPos ());
401
- SmallVector<int64_t > lastDimsToInsertPositionsPerm = computePermutationVector (
402
- packedRank, lastDims, packingMetadata.insertPositions );
403
-
404
- // 3. Compute the stripMinedShape: this is the packed shape without outer and
387
+
388
+ // 1. Compute the permutation vector to shuffle packed shape into the shape
389
+ // before any outer or inner permutations have been applied.
390
+ PackingMetadata packingMetadata;
391
+ SmallVector<int64_t > packedToStripMinedShapePerm =
392
+ tensor::getUnPackInverseSrcPerm (unPackOp, packingMetadata);
393
+
394
+ // 2. Compute the stripMinedShape: this is the packed shape without outer and
405
395
// inner permutations.
406
396
SmallVector<int64_t > stripMinedShape (packedTensorType.getShape ());
407
- applyPermutationToVector (stripMinedShape, lastDimsToInsertPositionsPerm );
397
+ applyPermutationToVector (stripMinedShape, packedToStripMinedShapePerm );
408
398
409
- // 4 . Transpose packedShape to stripMinedShape.
399
+ // 3. Transpose packedShape to stripMinedShape.
410
400
RankedTensorType stripMinedTensorType =
411
401
RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
412
402
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
413
403
stripMinedTensorType, packingMetadata.reassociations );
414
404
415
- // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
405
+ // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
416
406
// permutation.
417
407
SmallVector<OpFoldResult, 4 > dims =
418
408
tensor::getMixedSizes (rewriter, loc, unPackOp.getSource ());
419
- applyPermutationToVector (dims, lastDimsToInsertPositionsPerm );
409
+ applyPermutationToVector (dims, packedToStripMinedShapePerm );
420
410
auto emptyOp = rewriter.create <tensor::EmptyOp>(
421
411
loc, dims, stripMinedTensorType.getElementType ());
422
412
auto transposeOp = rewriter.create <linalg::TransposeOp>(
423
- loc, unPackOp.getSource (), emptyOp, lastDimsToInsertPositionsPerm );
413
+ loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm );
424
414
425
415
LLVM_DEBUG (
426
416
DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
427
417
DBGS () << " insertPositions: " );
428
418
DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
429
419
DBGS () << " packedShape: " );
430
420
DBGSNL ();
431
- llvm::interleaveComma (lastDimsToInsertPositionsPerm ,
432
- DBGS () << " lastDimsToInsertPositionsPerm : " );
421
+ llvm::interleaveComma (packedToStripMinedShapePerm ,
422
+ DBGS () << " packedToStripMinedShapePerm : " );
433
423
DBGSNL (); llvm::interleaveComma (
434
424
packingMetadata.reassociations , DBGS () << " reassociations: " ,
435
425
[&](ReassociationIndices ri) {
@@ -439,24 +429,24 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
439
429
llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
440
430
DBGSNL (); DBGS () << " collapsed type: " << collapsedType; DBGSNL (););
441
431
442
- // 5 . Collapse from the stripMinedShape to the padded result.
432
+ // 4. Collapse from the stripMinedShape to the padded result.
443
433
auto reshapeOp = rewriter.create <tensor::CollapseShapeOp>(
444
434
loc, collapsedType, transposeOp->getResult (0 ),
445
435
packingMetadata.reassociations );
446
436
447
- // 6 . ExtractSlice.
437
+ // 5. ExtractSlice.
448
438
int64_t destRank = destTensorType.getRank ();
449
439
auto extractSliceOp = rewriter.create <tensor::ExtractSliceOp>(
450
440
loc, destTensorType, reshapeOp->getResult (0 ),
451
441
SmallVector<OpFoldResult>(destRank, zero),
452
442
tensor::getMixedSizes (rewriter, loc, unPackOp.getDest ()),
453
443
SmallVector<OpFoldResult>(destRank, one));
454
444
455
- // 7 . Inject a copy to preserve DPS.
445
+ // 6. Inject a copy to preserve DPS.
456
446
auto copyOp = rewriter.create <linalg::CopyOp>(
457
447
loc, extractSliceOp->getResult (0 ), unPackOp.getDest ());
458
448
459
- // 8 . Replace unPackOp by extractSliceOp .
449
+ // 7. Replace unPackOp by copyOp.
460
450
rewriter.replaceOp (unPackOp, copyOp->getResults ());
461
451
462
452
return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
0 commit comments