Skip to content

Commit 12189f8

Browse files
author
Peiming Liu
authored
[mlir][sparse] introduce sparse_tensor.extract_value operation. (#101219)
1 parent 99fb40d commit 12189f8

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,31 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15311531
let hasVerifier = 1;
15321532
}
15331533

1534+
def ExtractValOp : SparseTensor_Op<"extract_value", [
1535+
Pure,
1536+
TypesMatchWith<"result type matches element type of tensor",
1537+
"tensor", "result",
1538+
"::llvm::cast<TensorType>($_self).getElementType()">]> {
1539+
let summary = "Extracts a value from a sparse tensor using an iterator.";
1540+
let description = [{
1541+
The `sparse_tensor.extract_value` operation extracts the value
1542+
pointed to by a sparse iterator from a sparse tensor.
1543+
1544+
Example:
1545+
1546+
```mlir
1547+
%val = sparse_tensor.extract_value %sp at %it
1548+
: tensor<?x?xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvl = 1>
1549+
```
1550+
}];
1551+
1552+
let arguments = (ins AnySparseTensor:$tensor, AnySparseIterator:$iterator);
1553+
let results = (outs AnyType:$result);
1554+
1555+
let assemblyFormat = "$tensor `at` $iterator attr-dict `:` type($tensor)`,` qualified(type($iterator))";
1556+
let hasVerifier = 1;
1557+
}
1558+
15341559
def IterateOp : SparseTensor_Op<"iterate",
15351560
[RecursiveMemoryEffects, RecursivelySpeculatable,
15361561
DeclareOpInterfaceMethods<LoopLikeOpInterface,

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,6 +2267,19 @@ LogicalResult ExtractIterSpaceOp::verify() {
22672267
return success();
22682268
}
22692269

2270+
LogicalResult ExtractValOp::verify() {
2271+
auto stt = getSparseTensorType(getTensor());
2272+
auto itTp = getIterator().getType();
2273+
2274+
if (stt.getEncoding() != itTp.getEncoding())
2275+
return emitOpError("mismatch in tensor encoding and iterator encoding.");
2276+
2277+
if (stt.getLvlRank() != itTp.getHiLvl())
2278+
return emitOpError("must use last-level iterator to extract values. ");
2279+
2280+
return success();
2281+
}
2282+
22702283
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
22712284
using OpRewritePattern::OpRewritePattern;
22722285

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,42 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
10991099
return
11001100
}
11011101

1102+
// -----
1103+
1104+
#COO = #sparse_tensor.encoding<{
1105+
map = (i, j) -> (
1106+
i : compressed(nonunique),
1107+
j : singleton(soa)
1108+
)
1109+
}>
1110+
1111+
#CSR = #sparse_tensor.encoding<{
1112+
map = (i, j) -> (
1113+
i : dense,
1114+
j : compressed
1115+
)
1116+
}>
1117+
1118+
func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 1>) -> f32 {
1119+
// expected-error@+1 {{'sparse_tensor.extract_value' op mismatch in tensor encoding and iterator encoding.}}
1120+
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 1>
1121+
return %f : f32
1122+
}
1123+
1124+
// -----
1125+
1126+
#COO = #sparse_tensor.encoding<{
1127+
map = (i, j) -> (
1128+
i : compressed(nonunique),
1129+
j : singleton(soa)
1130+
)
1131+
}>
1132+
1133+
func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> f32 {
1134+
// expected-error@+1 {{'sparse_tensor.extract_value' op must use last-level iterator to extract values.}}
1135+
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1136+
return %f : f32
1137+
}
11021138

11031139
// -----
11041140

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,27 @@ func.func @sparse_has_runtime() -> i1 {
739739
return %has_runtime : i1
740740
}
741741

742+
// -----
743+
744+
#COO = #sparse_tensor.encoding<{
745+
map = (i, j) -> (
746+
i : compressed(nonunique),
747+
j : singleton(soa)
748+
)
749+
}>
750+
751+
// CHECK-LABEL: func.func @sparse_extract_value(
752+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
753+
// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse, lvls = 1>) -> f32 {
754+
// CHECK: %[[VAL_2:.*]] = sparse_tensor.extract_value %[[VAL_0]] at %[[VAL_1]] : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<#sparse, lvls = 1>
755+
// CHECK: return %[[VAL_2]] : f32
756+
// CHECK: }
757+
func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 1>) -> f32 {
758+
%f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 1>
759+
return %f : f32
760+
}
761+
762+
742763
// -----
743764

744765
#COO = #sparse_tensor.encoding<{

0 commit comments

Comments
 (0)