Skip to content

Commit 2b8e679

Browse files
author
Peiming Liu
committed
update example code
1 parent 1945de6 commit 2b8e679

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14781478
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
14791479
iteration space.
14801480

1481-
The type of returned the value is automatically inferred to
1481+
The type of returned the value is must be
14821482
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
14831483
The returned iteration space can then be iterated over by
14841484
`sparse_tensor.iterate` operations to visit every stored element
@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14891489
// Extracts a 1-D iteration space from a COO tensor at level 1.
14901490
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
14911491
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492+
->!sparse_tensor.iter_space<#COO, lvls = 1>
14921493
```
14931494
}];
14941495

@@ -1501,16 +1502,17 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15011502
return getHiLvl() - getLoLvl();
15021503
}
15031504
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1504-
return getResultSpace().getType().getLvlTypes();
1505+
return getExtractedSpace().getType().getLvlTypes();
15051506
}
15061507
}];
15071508

15081509
let arguments = (ins AnySparseTensor:$tensor,
15091510
Optional<AnySparseIterator>:$parentIter,
15101511
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1511-
let results = (outs AnySparseIterSpace:$resultSpace);
1512+
let results = (outs AnySparseIterSpace:$extractedSpace);
15121513
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1513-
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
1514+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515+
"`->` qualified(type($extractedSpace))";
15141516

15151517
let hasVerifier = 1;
15161518
}
@@ -1567,12 +1569,14 @@ def IterateOp : SparseTensor_Op<"iterate",
15671569
```mlir
15681570
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
15691571
// Iterates over the first level of %sp
1570-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1572+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573+
: tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
15711574
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
15721575
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
15731576
// Iterates over the second level of %sp
15741577
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
15751578
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579+
-> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
15761580
%r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
15771581
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
15781582
vector.print %coord0 : index

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2253,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() {
22532253
}
22542254

22552255
if (pIter) {
2256-
IterSpaceType spaceTp = getResultSpace().getType();
2256+
IterSpaceType spaceTp = getExtractedSpace().getType();
22572257
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
22582258
return emitOpError(
22592259
"mismatch in parent iterator encoding and iteration space encoding.");

0 commit comments

Comments
 (0)