@@ -385,6 +385,106 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
 }
 };
 
+/// Sparse rewriting rule for sparse-to-sparse reshape operator.
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+public:
+  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSource();
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+        !dstTp.hasStaticDimShape())
+      return failure();
+
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    for (Dimension d : dstTp.getDimShape())
+      dstSizes.push_back(constantIndex(rewriter, loc, d));
+
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+    // Only need an unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
+            : getUnorderedCOOFromType(dstTp);
+    SmallVector<Value> dynSizes;
+    Value buffer = rewriter
+                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+                                              nnz, Attribute())
+                       .getResult();
+
423
+    // Convert src coordinates to dst coordinates by first collapsing them to
+    // 1D and then expanding them to match the rank of the destination tensor.
+    // Implemented as follows:
+    //   foreach srcCoords %srcTensor
+    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
+    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
+    //     insert expandedCoords, %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
434
+    const auto encSrc = srcTp.getEncoding();
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, buffer,
+        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+            ValueRange reduc) {
+          const Dimension srcRank = srcTp.getDimRank();
+          SmallVector<Value> srcDcvs;
+          srcDcvs.reserve(srcRank);
+          for (Dimension d = 0; d < srcRank; d++) {
+            // FIXME: `toStoredDim` is deprecated
+            Level lvl = toStoredDim(encSrc, d);
+            srcDcvs.push_back(srcLcvs[lvl]);
+          }
+
+          Value collapsed_size = constantIndex(builder, loc, 1);
+          for (Dimension d = 0; d < srcRank; d++)
+            collapsed_size =
+                builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]);
+          SmallVector<Value, 1> collapsedSizes = {collapsed_size};
+
+          ReassociationIndices collapse_indices;
+          for (Dimension i = 0; i < srcRank; i++)
+            collapse_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> collapse_reassociation = {
+              collapse_indices};
+          SmallVector<Value, 1> collapsedDcvs;
+          reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs,
+                     collapsedSizes, collapsedDcvs);
+
+          ReassociationIndices expand_indices;
+          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+            expand_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> expand_reassociation = {
+              expand_indices};
+          SmallVector<Value> dstDcvs;
+          reshapeCvs(builder, loc, expand_reassociation, collapsedSizes,
+                     collapsedDcvs, dstSizes, dstDcvs);
+
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
+
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
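
Conceptually, the collapse/expand of coordinates in TensorReshapeRewriter above linearizes each source coordinate tuple into a flat offset over the collapsed 1-D shape and then delinearizes that offset against the destination shape. The standalone sketch below illustrates only that index arithmetic; the helpers linearize/delinearize are hypothetical names for this illustration and are not the MLIR reshapeCvs utility, which emits the equivalent arith ops at compile time.

// Illustrative only: the index arithmetic behind collapsing source
// coordinates to 1-D and re-expanding them to the destination rank.
// `linearize`/`delinearize` are hypothetical helpers, not MLIR APIs.
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major linearization of a coordinate tuple against `sizes`.
static uint64_t linearize(const std::vector<uint64_t> &coords,
                          const std::vector<uint64_t> &sizes) {
  uint64_t offset = 0;
  for (size_t d = 0; d < sizes.size(); ++d)
    offset = offset * sizes[d] + coords[d];
  return offset;
}

// Row-major delinearization of a flat offset against `sizes`.
static std::vector<uint64_t> delinearize(uint64_t offset,
                                         const std::vector<uint64_t> &sizes) {
  std::vector<uint64_t> coords(sizes.size());
  for (size_t d = sizes.size(); d-- > 0;) {
    coords[d] = offset % sizes[d];
    offset /= sizes[d];
  }
  return coords;
}

int main() {
  // Reshaping a 3x4 tensor into 2x6: source coordinate (1, 2) collapses
  // to flat offset 6, which expands to destination coordinate (1, 0).
  uint64_t offset = linearize({1, 2}, {3, 4});
  assert(offset == 6);
  std::vector<uint64_t> dst = delinearize(offset, {2, 6});
  assert(dst == (std::vector<uint64_t>{1, 0}));
  return 0;
}

This mirrors the structure of the pattern: the product of the source sizes forms the single collapsed dimension, and the destination sizes drive the expansion.
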
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                 bool enableForeach,
                                                 bool enableConvert) {
   patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+      patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
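
As a rough sketch of how the newly registered TensorReshapeRewriter gets exercised: a pass populates the set and hands it to the greedy rewrite driver. This is a hedged illustration, not the in-tree pass; the header paths and the enableRT flag are assumptions, while populatePostSparsificationRewriting and the two flags shown above come from this diff.

// Illustrative sketch (assumed headers and enableRT flag), not the actual pass:
// populate the post-sparsification patterns, which now include
// TensorReshapeRewriter, and run them to a fixpoint with the greedy driver.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runPostSparsificationRewrites(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  mlir::populatePostSparsificationRewriting(patterns, /*enableRT=*/false,
                                            /*enableForeach=*/true,
                                            /*enableConvert=*/true);
  return mlir::applyPatternsAndFoldGreedily(module.getOperation(),
                                            std::move(patterns));
}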