Skip to content

Commit 43e9638

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce sparse_tensor.extract_space operation that
extracts a sparse iteration space to iterate over.
1 parent 3652b2a commit 43e9638

File tree

5 files changed

+355
-0
lines changed

5 files changed

+355
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,57 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
14301430
let hasVerifier = 1;
14311431
}
14321432

1433+
//===----------------------------------------------------------------------===//
1434+
// Sparse Tensor Iteration Operations.
1435+
//===----------------------------------------------------------------------===//
1436+
1437+
def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
1438+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1439+
1440+
let arguments = (ins AnySparseTensor:$tensor,
1441+
Optional<AnySparseIterator>:$parentIter,
1442+
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1443+
1444+
let results = (outs AnySparseIterSpace:$resultSpace);
1445+
1446+
let summary = "Extract an iteration space from a sparse tensor between certain levels";
1447+
let description = [{
1448+
Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1449+
certian (consecutive) levels.
1450+
1451+
`tensor`: the input sparse tensor that defines the iteration space.
1452+
`parentIter`: the iterator for the previous level, at which the iteration space
1453+
at the current levels will be extracted.
1454+
`loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1455+
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1456+
iteration space.
1457+
1458+
Example:
1459+
```mlir
1460+
// Extracts a 1-D iteration space from a COO tensor at level 1.
1461+
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1462+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1463+
```
1464+
}];
1465+
1466+
1467+
let extraClassDeclaration = [{
1468+
std::pair<Level, Level> getLvlRange() {
1469+
return std::make_pair(getLoLvl(), getHiLvl());
1470+
}
1471+
unsigned getSpaceDim() {
1472+
return getHiLvl() - getLoLvl();
1473+
}
1474+
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1475+
return getResultSpace().getType().getLvlTypes();
1476+
}
1477+
}];
1478+
1479+
let hasVerifier = 1;
1480+
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1481+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1482+
}
1483+
14331484
//===----------------------------------------------------------------------===//
14341485
// Sparse Tensor Debugging and Test-Only Operations.
14351486
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
7272
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
7373
"::mlir::sparse_tensor::StorageSpecifierType">;
7474

75+
//===----------------------------------------------------------------------===//
76+
// Sparse Tensor Iteration Types.
77+
//===----------------------------------------------------------------------===//
78+
79+
def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
80+
let mnemonic = "iter_space";
81+
82+
let description = [{
83+
A sparse iteration space that represents an abstract N-D (sparse) iteration space
84+
extracted from a sparse tensor.
85+
86+
Examples:
87+
88+
```mlir
89+
// An iteration space extracted from a CSR tensor between levels [0, 2).
90+
!iter_space<#CSR, lvls = 0 to 2>
91+
```
92+
}];
93+
94+
let parameters = (ins
95+
SparseTensorEncodingAttr : $encoding,
96+
"Level" : $loLvl,
97+
"Level" : $hiLvl
98+
);
99+
100+
let extraClassDeclaration = [{
101+
/// The the dimension of the iteration space.
102+
unsigned getSpaceDim() const {
103+
return getHiLvl() - getLoLvl();
104+
}
105+
106+
/// Get the level types for the iteration space.
107+
ArrayRef<LevelType> getLvlTypes() const {
108+
return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
109+
}
110+
111+
/// Whether the iteration space is unique (i.e., no duplicated coordinate).
112+
bool isUnique() {
113+
return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
114+
}
115+
116+
/// Get the corresponding iterator type.
117+
::mlir::sparse_tensor::IteratorType getIteratorType() const;
118+
}];
119+
120+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
121+
}
122+
123+
def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
124+
let mnemonic = "iterator";
125+
126+
let description = [{
127+
An iterator that points to the current element in the corresponding iteration space.
128+
129+
Examples:
130+
131+
```mlir
132+
// An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
133+
!iterator<#CSR, lvls = 0 to 2>
134+
```
135+
}];
136+
137+
let parameters = (ins
138+
SparseTensorEncodingAttr : $encoding,
139+
"Level" : $loLvl,
140+
"Level" : $hiLvl
141+
);
142+
143+
let extraClassDeclaration = [{
144+
/// Get the corresponding iteration space type.
145+
::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
146+
147+
unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
148+
ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
149+
bool isUnique() { return getIterSpaceType().isUnique(); }
150+
}];
151+
152+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
153+
}
154+
155+
def IsSparseSparseIterSpaceTypePred
156+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
157+
158+
def IsSparseSparseIteratorTypePred
159+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
160+
161+
def AnySparseIterSpace
162+
: Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
163+
"::mlir::sparse_tensor::IterSpaceType">;
164+
165+
def AnySparseIterator
166+
: Type<IsSparseSparseIteratorTypePred, "sparse iterator",
167+
"::mlir::sparse_tensor::IteratorType">;
168+
169+
75170
#endif // SPARSETENSOR_TYPES

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
3131
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
3232

33+
// Forward declarations, following custom print/parsing methods are referenced
34+
// by the generated code for SparseTensorTypes.td.
35+
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
36+
mlir::sparse_tensor::Level &,
37+
mlir::sparse_tensor::Level &);
38+
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
39+
mlir::sparse_tensor::Level);
40+
3341
#define GET_TYPEDEF_CLASSES
3442
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
3543

@@ -1953,6 +1961,100 @@ LogicalResult SortOp::verify() {
19531961
return success();
19541962
}
19551963

1964+
//===----------------------------------------------------------------------===//
1965+
// Sparse Tensor Iteration Operations.
1966+
//===----------------------------------------------------------------------===//
1967+
1968+
IterSpaceType IteratorType::getIterSpaceType() const {
1969+
return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
1970+
getHiLvl());
1971+
}
1972+
1973+
IteratorType IterSpaceType::getIteratorType() const {
1974+
return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
1975+
}
1976+
1977+
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
1978+
Level &lvlHi) {
1979+
if (parser.parseInteger(lvlLo))
1980+
return failure();
1981+
1982+
if (succeeded(parser.parseOptionalKeyword("to"))) {
1983+
if (parser.parseInteger(lvlHi))
1984+
return failure();
1985+
} else {
1986+
lvlHi = lvlLo + 1;
1987+
}
1988+
1989+
if (lvlHi <= lvlLo)
1990+
parser.emitError(parser.getNameLoc(),
1991+
"expect larger level upper bound than lower bound");
1992+
1993+
return success();
1994+
}
1995+
1996+
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
1997+
IntegerAttr &lvlHiAttr) {
1998+
Level lvlLo, lvlHi;
1999+
if (parseLevelRange(parser, lvlLo, lvlHi))
2000+
return failure();
2001+
2002+
lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2003+
lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2004+
return success();
2005+
}
2006+
2007+
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2008+
2009+
if (lo + 1 == hi)
2010+
p << lo;
2011+
else
2012+
p << lo << " to " << hi;
2013+
}
2014+
2015+
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2016+
IntegerAttr lvlHi) {
2017+
unsigned lo = lvlLo.getValue().getZExtValue();
2018+
unsigned hi = lvlHi.getValue().getZExtValue();
2019+
printLevelRange(p, lo, hi);
2020+
}
2021+
2022+
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2023+
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2024+
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2025+
SmallVectorImpl<mlir::Type> &ret) {
2026+
2027+
ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2028+
SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2029+
ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2030+
adaptor.getHiLvl()));
2031+
return success();
2032+
}
2033+
2034+
LogicalResult ExtractIterSpaceOp::verify() {
2035+
if (getLoLvl() >= getHiLvl())
2036+
return emitOpError("expected smaller level low than level high");
2037+
2038+
TypedValue<IteratorType> pIter = getParentIter();
2039+
if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2040+
return emitOpError(
2041+
"parent iterator should be specified iff level lower bound equals 0");
2042+
}
2043+
2044+
if (pIter) {
2045+
IterSpaceType spaceTp = getResultSpace().getType();
2046+
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2047+
return emitOpError(
2048+
"mismatch in parent iterator encoding and iteration space encoding.");
2049+
2050+
if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2051+
return emitOpError("parent iterator should be used to extract an "
2052+
"iteration space from a consecutive level.");
2053+
}
2054+
2055+
return success();
2056+
}
2057+
19562058
/// Materialize a single constant operation from a given attribute value with
19572059
/// the desired resultant type.
19582060
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
10121012
sparse_tensor.print %arg0 : tensor<10x10xf64>
10131013
return
10141014
}
1015+
1016+
// -----
1017+
1018+
#COO = #sparse_tensor.encoding<{
1019+
map = (i, j) -> (
1020+
i : compressed(nonunique),
1021+
j : singleton(soa)
1022+
)
1023+
}>
1024+
1025+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
1026+
// expected-error@+1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
1027+
%l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
1028+
return
1029+
}
1030+
1031+
// -----
1032+
1033+
#COO = #sparse_tensor.encoding<{
1034+
map = (i, j) -> (
1035+
i : compressed(nonunique),
1036+
j : singleton(soa)
1037+
)
1038+
}>
1039+
1040+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
1041+
// expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
1042+
%l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1043+
return
1044+
}
1045+
1046+
// -----
1047+
1048+
#COO = #sparse_tensor.encoding<{
1049+
map = (i, j) -> (
1050+
i : compressed(nonunique),
1051+
j : singleton(soa)
1052+
)
1053+
}>
1054+
1055+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
1056+
// expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
1057+
%l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
1058+
return
1059+
}
1060+
1061+
// -----
1062+
1063+
#COO = #sparse_tensor.encoding<{
1064+
map = (i, j) -> (
1065+
i : compressed(nonunique),
1066+
j : singleton(soa)
1067+
)
1068+
}>
1069+
1070+
#CSR = #sparse_tensor.encoding<{
1071+
map = (i, j) -> (
1072+
i : dense,
1073+
j : compressed
1074+
)
1075+
}>
1076+
1077+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
1078+
// expected-error@+1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
1079+
%l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
1080+
return
1081+
}
1082+
1083+
// -----
1084+
1085+
#COO = #sparse_tensor.encoding<{
1086+
map = (i, j) -> (
1087+
i : compressed(nonunique),
1088+
j : singleton(soa)
1089+
)
1090+
}>
1091+
1092+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
1093+
// expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
1094+
%l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1095+
return
1096+
}

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,3 +738,28 @@ func.func @sparse_has_runtime() -> i1 {
738738
%has_runtime = sparse_tensor.has_runtime_library
739739
return %has_runtime : i1
740740
}
741+
742+
// -----
743+
744+
#COO = #sparse_tensor.encoding<{
745+
map = (i, j) -> (
746+
i : compressed(nonunique),
747+
j : singleton(soa)
748+
)
749+
}>
750+
751+
// CHECK-LABEL: func.func @sparse_extract_iter_space(
752+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
753+
// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
754+
// CHECK: %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
755+
// CHECK: %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
756+
// CHECK: return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
757+
// CHECK: }
758+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
759+
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
760+
// Extracting the iteration space for the first level needs no parent iterator.
761+
%l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
762+
// Extracting the iteration space for the second level needs a parent iterator.
763+
%l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
764+
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
765+
}

0 commit comments

Comments
 (0)