@@ -587,14 +587,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
587
587
588
588
void SparseTensorEncodingAttr::print (AsmPrinter &printer) const {
589
589
auto map = static_cast <AffineMap>(getDimToLvl ());
590
- auto lvlTypes = getLvlTypes ();
591
590
// Empty affine map indicates identity map
592
591
if (!map) {
593
592
map = AffineMap::getMultiDimIdentityMap (getLvlTypes ().size (), getContext ());
594
593
}
595
- // Modified version of AsmPrinter::Impl::printAffineMap.
596
594
printer << " <{ map = " ;
597
- // Symbolic identifiers.
595
+ printSymbol (map, printer);
596
+ printer << ' (' ;
597
+ printDimension (map, printer, getDimSlices ());
598
+ printer << " ) -> (" ;
599
+ printLevel (map, printer, getLvlTypes ());
600
+ printer << ' )' ;
601
+ // Print remaining members only for non-default values.
602
+ if (getPosWidth ())
603
+ printer << " , posWidth = " << getPosWidth ();
604
+ if (getCrdWidth ())
605
+ printer << " , crdWidth = " << getCrdWidth ();
606
+ printer << " }>" ;
607
+ }
608
+
609
+ void SparseTensorEncodingAttr::printSymbol (AffineMap &map,
610
+ AsmPrinter &printer) const {
598
611
if (map.getNumSymbols () != 0 ) {
599
612
printer << ' [' ;
600
613
for (unsigned i = 0 ; i < map.getNumSymbols () - 1 ; ++i)
@@ -603,9 +616,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
603
616
printer << ' s' << map.getNumSymbols () - 1 ;
604
617
printer << ' ]' ;
605
618
}
606
- // Dimension identifiers.
607
- printer << ' (' ;
608
- auto dimSlices = getDimSlices ();
619
+ }
620
+
621
+ void SparseTensorEncodingAttr::printDimension (
622
+ AffineMap &map, AsmPrinter &printer,
623
+ ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
609
624
if (!dimSlices.empty ()) {
610
625
for (unsigned i = 0 ; i < map.getNumDims () - 1 ; ++i)
611
626
printer << ' d' << i << " : " << dimSlices[i] << " , " ;
@@ -618,9 +633,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
618
633
if (map.getNumDims () >= 1 )
619
634
printer << ' d' << map.getNumDims () - 1 ;
620
635
}
621
- printer << ' )' ;
622
- // Level format and properties.
623
- printer << " -> (" ;
636
+ }
637
+
638
+ void SparseTensorEncodingAttr::printLevel (
639
+ AffineMap &map, AsmPrinter &printer,
640
+ ArrayRef<DimLevelType> lvlTypes) const {
624
641
for (unsigned i = 0 ; i < map.getNumResults () - 1 ; ++i) {
625
642
map.getResult (i).print (printer.getStream ());
626
643
printer << " : " << toMLIRString (lvlTypes[i]) << " , " ;
@@ -630,13 +647,6 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
630
647
map.getResult (lastIndex).print (printer.getStream ());
631
648
printer << " : " << toMLIRString (lvlTypes[lastIndex]);
632
649
}
633
- printer << ' )' ;
634
- // Print remaining members only for non-default values.
635
- if (getPosWidth ())
636
- printer << " , posWidth = " << getPosWidth ();
637
- if (getCrdWidth ())
638
- printer << " , crdWidth = " << getCrdWidth ();
639
- printer << " }>" ;
640
650
}
641
651
642
652
LogicalResult
0 commit comments