
[mlir][sparse] introduce sparse_tensor.extract_iteration_space operation. #88554

Merged: 3 commits, Apr 16, 2024
60 changes: 60 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1430,6 +1430,66 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {

let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);

let results = (outs AnySparseIterSpace:$resultSpace);

let summary = "Extracts an iteration space from a sparse tensor between certain levels";
let description = [{
Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
certain (consecutive) levels. For sparse levels, this is usually done by
loading a position range from the underlying sparse tensor storage.
E.g., for a compressed level, the iteration space is the position range
[pos[i], pos[i+1]), assuming the parent iterator points at `i`.

`tensor`: the input sparse tensor that defines the iteration space.
`parentIter`: the iterator over the previous level, from which the iteration
space for the current levels is extracted.
`loLvl`, `hiLvl`: the level range [loLvl, hiLvl) in the input tensor that the
returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
iteration space.

The type of the returned value is automatically inferred as
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
The returned iteration space can then be iterated over by
`sparse_tensor.iterate` operations to visit every stored element
(usually nonzeros) in the input sparse tensor.

Example:
```mlir
// Extracts a 1-D iteration space from a COO tensor at level 1.
%space = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
```
}];


let extraClassDeclaration = [{
std::pair<Level, Level> getLvlRange() {
return std::make_pair(getLoLvl(), getHiLvl());
}
unsigned getSpaceDim() {
return getHiLvl() - getLoLvl();
}
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
return getResultSpace().getType().getLvlTypes();
}
}];

let hasVerifier = 1;
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
}
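
As a hedged sketch (not part of this diff), modeled on the invalid.mlir tests below, the op composes across levels as follows: the root space takes no parent iterator, while each deeper level is extracted at the iterator of the preceding level. The `#COO` encoding, the function name, and the SSA value names are illustrative assumptions.

```mlir
#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : compressed(nonunique),
    j : singleton(soa)
  )
}>

// Hypothetical wrapper; the level-0 iterator %it1 arrives as an argument.
func.func @extract_spaces(%sp : tensor<4x8xf32, #COO>,
                          %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
  // Root level: no parent iterator; the result type is inferred as
  // !sparse_tensor.iter_space<#COO, lvls = 0>.
  %space0 = sparse_tensor.extract_iteration_space %sp lvls = 0
      : tensor<4x8xf32, #COO>
  // Level 1: positioned by the parent iterator over level 0; the result type
  // is inferred as !sparse_tensor.iter_space<#COO, lvls = 1>.
  %space1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
      : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
  return
}
```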

//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
97 changes: 97 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,101 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Types.
//===----------------------------------------------------------------------===//

def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
let mnemonic = "iter_space";

let description = [{
An abstract N-D (sparse) iteration space extracted from a sparse tensor, i.e.,
the set of coordinate tuples (crd_0, crd_1, ..., crd_{N-1}) for every stored
element (usually nonzeros) of the tensor between the specified
[$loLvl, $hiLvl) levels.

Examples:

```mlir
// An iteration space extracted from a CSR tensor between levels [0, 2).
!iter_space<#CSR, lvls = 0 to 2>
```
}];

let parameters = (ins
SparseTensorEncodingAttr : $encoding,
"Level" : $loLvl,
"Level" : $hiLvl
);

let extraClassDeclaration = [{
/// The dimension of the iteration space.
unsigned getSpaceDim() const {
return getHiLvl() - getLoLvl();
}

/// Get the level types for the iteration space.
ArrayRef<LevelType> getLvlTypes() const {
return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
}

/// Whether the iteration space is unique (i.e., no duplicated coordinate).
bool isUnique() {
return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
}

/// Get the corresponding iterator type.
::mlir::sparse_tensor::IteratorType getIteratorType() const;
}];

let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}

def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
let mnemonic = "iterator";

let description = [{
An iterator that points to the current element in the corresponding iteration space.

Examples:

```mlir
// An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
!iterator<#CSR, lvls = 0 to 2>
```
}];

let parameters = (ins
SparseTensorEncodingAttr : $encoding,
"Level" : $loLvl,
"Level" : $hiLvl
);

let extraClassDeclaration = [{
/// Get the corresponding iteration space type.
::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;

unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
bool isUnique() { return getIterSpaceType().isUnique(); }
}];

let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}
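
A hedged illustration of the two types side by side (the `#CSR` and `#COO` aliases are assumed to be defined elsewhere): a space and its iterator share the encoding and level range, and a single integer is shorthand for a one-level range.

```mlir
// A 2-D iteration space over levels [0, 2) of a #CSR tensor, and the
// iterator type that walks it (the two are in one-to-one correspondence).
!sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
!sparse_tensor.iterator<#CSR, lvls = 0 to 2>

// Shorthand: a single level denotes a one-level range, i.e., lvls = 1 is [1, 2).
!sparse_tensor.iterator<#COO, lvls = 1>
```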

def IsSparseSparseIterSpaceTypePred
: CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;

def IsSparseSparseIteratorTypePred
: CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;

def AnySparseIterSpace
: Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
"::mlir::sparse_tensor::IterSpaceType">;

def AnySparseIterator
: Type<IsSparseSparseIteratorTypePred, "sparse iterator",
"::mlir::sparse_tensor::IteratorType">;


#endif // SPARSETENSOR_TYPES
110 changes: 110 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations; the following custom print/parse methods are referenced
// by the code generated from SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
mlir::sparse_tensor::Level &,
mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

@@ -1953,6 +1961,108 @@ LogicalResult SortOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {
return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
Level &lvlHi) {
if (parser.parseInteger(lvlLo))
return failure();

if (succeeded(parser.parseOptionalKeyword("to"))) {
if (parser.parseInteger(lvlHi))
return failure();
} else {
lvlHi = lvlLo + 1;
}

if (lvlHi <= lvlLo)
parser.emitError(parser.getNameLoc(),
"expect larger level upper bound than lower bound");

return success();
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
IntegerAttr &lvlHiAttr) {
Level lvlLo, lvlHi;
if (parseLevelRange(parser, lvlLo, lvlHi))
return failure();

lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
return success();
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {

if (lo + 1 == hi)
p << lo;
else
p << lo << " to " << hi;
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
IntegerAttr lvlHi) {
unsigned lo = lvlLo.getValue().getZExtValue();
unsigned hi = lvlHi.getValue().getZExtValue();
printLevelRange(p, lo, hi);
}
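
These overloads back both the type and op syntax; as a hedged sketch of the two accepted spellings in the op's `lvls` clause (the operands, types, and the `#COO` alias are assumptions):

```mlir
// Explicit range form: extract levels [0, 2) in one shot.
%s01 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
    : tensor<4x8xf32, #COO>
// Shorthand form: a single integer denotes the one-level range [1, 2).
%s1 = sparse_tensor.extract_iteration_space %sp at %it lvls = 1
    : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
```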

LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {

ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
adaptor.getHiLvl()));
return success();
}

LogicalResult ExtractIterSpaceOp::verify() {
if (getLoLvl() >= getHiLvl())
return emitOpError("expected smaller level low than level high");

TypedValue<IteratorType> pIter = getParentIter();
if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
return emitOpError(
"parent iterator should be specified iff level lower bound equals 0");
}

if (pIter) {
IterSpaceType spaceTp = getResultSpace().getType();
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
return emitOpError(
"mismatch in parent iterator encoding and iteration space encoding.");

if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
return emitOpError("parent iterator should be used to extract an "
"iteration space from a consecutive level.");
}

return success();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
82 changes: 82 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
sparse_tensor.print %arg0 : tensor<10x10xf64>
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

#CSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i : dense,
j : compressed
)
}>

func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
return
}

// -----

#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>

func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
return
}