Skip to content

Commit 5e30002

Browse files
address reivew comments
1 parent 2f6c890 commit 5e30002

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
422422
std::optional<uint64_t> getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const;
423423
std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
424424
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
425+
426+
//
427+
// Printing methods.
428+
//
429+
430+
void printSymbol(AffineMap &map, AsmPrinter &printer) const;
431+
void printDimension(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
432+
void printLevel(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::DimLevelType> lvlTypes) const;
425433
}];
426434

427435
let genVerifyDecl = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -587,14 +587,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
587587

588588
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
589589
auto map = static_cast<AffineMap>(getDimToLvl());
590-
auto lvlTypes = getLvlTypes();
591590
// Empty affine map indicates identity map
592591
if (!map) {
593592
map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
594593
}
595-
// Modified version of AsmPrinter::Impl::printAffineMap.
596594
printer << "<{ map = ";
597-
// Symbolic identifiers.
595+
printSymbol(map, printer);
596+
printer << '(';
597+
printDimension(map, printer, getDimSlices());
598+
printer << ") -> (";
599+
printLevel(map, printer, getLvlTypes());
600+
printer << ')';
601+
// Print remaining members only for non-default values.
602+
if (getPosWidth())
603+
printer << ", posWidth = " << getPosWidth();
604+
if (getCrdWidth())
605+
printer << ", crdWidth = " << getCrdWidth();
606+
printer << " }>";
607+
}
608+
609+
void SparseTensorEncodingAttr::printSymbol(AffineMap &map,
610+
AsmPrinter &printer) const {
598611
if (map.getNumSymbols() != 0) {
599612
printer << '[';
600613
for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
@@ -603,9 +616,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
603616
printer << 's' << map.getNumSymbols() - 1;
604617
printer << ']';
605618
}
606-
// Dimension identifiers.
607-
printer << '(';
608-
auto dimSlices = getDimSlices();
619+
}
620+
621+
void SparseTensorEncodingAttr::printDimension(
622+
AffineMap &map, AsmPrinter &printer,
623+
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
609624
if (!dimSlices.empty()) {
610625
for (unsigned i = 0; i < map.getNumDims() - 1; ++i)
611626
printer << 'd' << i << " : " << dimSlices[i] << ", ";
@@ -618,9 +633,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
618633
if (map.getNumDims() >= 1)
619634
printer << 'd' << map.getNumDims() - 1;
620635
}
621-
printer << ')';
622-
// Level format and properties.
623-
printer << " -> (";
636+
}
637+
638+
void SparseTensorEncodingAttr::printLevel(
639+
AffineMap &map, AsmPrinter &printer,
640+
ArrayRef<DimLevelType> lvlTypes) const {
624641
for (unsigned i = 0; i < map.getNumResults() - 1; ++i) {
625642
map.getResult(i).print(printer.getStream());
626643
printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
@@ -630,13 +647,6 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
630647
map.getResult(lastIndex).print(printer.getStream());
631648
printer << " : " << toMLIRString(lvlTypes[lastIndex]);
632649
}
633-
printer << ')';
634-
// Print remaining members only for non-default values.
635-
if (getPosWidth())
636-
printer << ", posWidth = " << getPosWidth();
637-
if (getCrdWidth())
638-
printer << ", crdWidth = " << getCrdWidth();
639-
printer << " }>";
640650
}
641651

642652
LogicalResult

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,13 @@ class SparseInsertGenerator
474474
const Level lvlRank = stt.getLvlRank();
475475
for (Level l = 0; l < lvlRank; l++) {
476476
std::string lvlType = toMLIRString(stt.getLvlType(l));
477-
replaceWithUnderscore(lvlType);
477+
// Replace/remove punctuations in level properties.
478+
std::replace_if(
479+
lvlType.begin(), lvlType.end(),
480+
[](char c) { return c == '(' || c == ','; }, '_');
481+
lvlType.erase(std::remove_if(lvlType.begin(), lvlType.end(),
482+
[](char c) { return c == ')' || c == ' '; }),
483+
lvlType.end());
478484
nameOstream << lvlType << "_";
479485
}
480486
// Static dim sizes are used in the generated code while dynamic sizes are
@@ -492,17 +498,6 @@ class SparseInsertGenerator
492498

493499
private:
494500
TensorType rtp;
495-
void replaceWithUnderscore(std::string &lvlType) {
496-
for (auto it = lvlType.begin(); it != lvlType.end();) {
497-
if (*it == '(' || *it == ',') {
498-
*it = '_';
499-
} else if (*it == ')' || *it == ' ') {
500-
it = lvlType.erase(it);
501-
continue;
502-
}
503-
it++;
504-
}
505-
}
506501
};
507502

508503
/// Generations insertion finalization code.

0 commit comments

Comments
 (0)