Skip to content

[mlir][sparse] introduce sparse_tensor.extract_iteration_space operation. #88554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2024

Conversation

PeimingLiu
Copy link
Member

A sparse_tensor.extract_space %tensor at %iterator extracts a sparse iteration space defined %tensor, the operation to traverse the iteration space will be introduced in following PRs.

@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

A sparse_tensor.extract_space %tensor at %iterator extracts a sparse iteration space defined %tensor, the operation to traverse the iteration space will be introduced in following PRs.


Full diff: https://github.com/llvm/llvm-project/pull/88554.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+51)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td (+95)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+102)
  • (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+82)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0cfc64f9988a0a..b8e4edc8c8537b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1430,6 +1430,57 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let summary = "Extract an iteration space from a sparse tensor between certain levels";
+  let description = [{
+      Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+      certian (consecutive) levels.
+
+      `tensor`: the input sparse tensor that defines the iteration space.
+      `parentIter`: the iterator for the previous level, at which the iteration space
+      at the current levels will be extracted.
+      `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
+      the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
+      iteration space.
+
+      Example:
+      ```mlir
+      // Extracts a 1-D iteration space from a COO tensor at level 1.
+      %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+      ```
+  }];
+
+
+  let extraClassDeclaration = [{
+    std::pair<Level, Level> getLvlRange() {
+      return std::make_pair(getLoLvl(), getHiLvl());
+    }
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+    ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+      return getResultSpace().getType().getLvlTypes();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..264a0a5b3bee6c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+  let mnemonic = "iter_space";
+
+  let description = [{
+    A sparse iteration space that represents an abstract N-D (sparse) iteration space
+    extracted from a sparse tensor.
+
+    Examples:
+
+    ```mlir
+    // An iteration space extracted from a CSR tensor between levels [0, 2).
+    !iter_space<#CSR, lvls = 0 to 2>
+    ```
+  }];
+
+  let parameters = (ins
+     SparseTensorEncodingAttr : $encoding,
+     "Level" : $loLvl,
+     "Level" : $hiLvl
+  );
+
+  let extraClassDeclaration = [{
+     /// The the dimension of the iteration space.
+     unsigned getSpaceDim() const {
+       return getHiLvl() - getLoLvl();
+     }
+
+     /// Get the level types for the iteration space.
+     ArrayRef<LevelType> getLvlTypes() const {
+       return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
+     }
+
+     /// Whether the iteration space is unique (i.e., no duplicated coordinate).
+     bool isUnique() {
+       return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+     }
+
+     /// Get the corresponding iterator type.
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const;
+  }];
+
+  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+  let mnemonic = "iterator";
+
+  let description = [{
+    An iterator that points to the current element in the corresponding iteration space.
+
+    Examples:
+
+    ```mlir
+    // An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
+    !iterator<#CSR, lvls = 0 to 2>
+    ```
+  }];
+
+  let parameters = (ins
+     SparseTensorEncodingAttr : $encoding,
+     "Level" : $loLvl,
+     "Level" : $hiLvl
+  );
+
+  let extraClassDeclaration = [{
+     /// Get the corresponding iteration space type.
+     ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
+
+     unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
+     ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
+     bool isUnique() { return getIterSpaceType().isUnique(); }
+  }];
+
+  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e9058394d33da5..a9d2d2b8826f37 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
 
+// Forward declarations, following custom print/parsing methods are referenced
+// by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+                                         mlir::sparse_tensor::Level &,
+                                         mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+                            mlir::sparse_tensor::Level);
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
 
@@ -1953,6 +1961,100 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+                            getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+                                   Level &lvlHi) {
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+
+  if (lvlHi <= lvlLo)
+    parser.emitError(parser.getNameLoc(),
+                     "expect larger level upper bound than lower bound");
+
+  return success();
+}
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+  if (parseLevelRange(parser, lvlLo, lvlHi))
+    return failure();
+
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+  if (lo + 1 == hi)
+    p << lo;
+  else
+    p << lo << " to " << hi;
+}
+
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+                            IntegerAttr lvlHi) {
+  unsigned lo = lvlLo.getValue().getZExtValue();
+  unsigned hi = lvlHi.getValue().getZExtValue();
+  printLevelRange(p, lo, hi);
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
+                                   adaptor.getHiLvl()));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator should be specified iff level lower bound equals 0");
+  }
+
+  if (pIter) {
+    IterSpaceType spaceTp = getResultSpace().getType();
+    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
+      return emitOpError(
+          "mismatch in parent iterator encoding and iteration space encoding.");
+
+    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
+      return emitOpError("parent iterator should be used to extract an "
+                         "iteration space from a consecutive level.");
+  }
+
+  return success();
+}
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f5c05190fc9a2..579625e22f067f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
   sparse_tensor.print %arg0 : tensor<10x10xf64>
   return
 }
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
+  // expected-error@+1 {{'sparse_tensor.iteration.extract_space' expect larger level upper bound than lower bound}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+  // expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
+  // expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be specified iff level lower bound equals 0}}
+  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
+  // expected-error@+1 {{'sparse_tensor.iteration.extract_space' op mismatch in parent iterator encoding and iteration space encoding.}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+  return
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
+  // expected-error@+1 {{'sparse_tensor.iteration.extract_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
+  %l1 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 12f69c1d37b9cd..6c0887ef4d826d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -738,3 +738,28 @@ func.func @sparse_has_runtime() -> i1 {
   %has_runtime = sparse_tensor.has_runtime_library
   return %has_runtime : i1
 }
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_extract_iter_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>)
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] lvls = 0
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.iteration.extract_space %[[VAL_0]] at %[[VAL_1]] lvls = 1
+// CHECK:           return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1>
+// CHECK:         }
+func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
+  -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
+  // Extracting the iteration space for the first level needs no parent iterator.
+  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // Extracting the iteration space for the second level needs a parent iterator.
+  %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
+}

@PeimingLiu PeimingLiu changed the title [mlir][sparse] introduce sparse_tensor.extract_space operation. [mlir][sparse] introduce sparse_tensor.extract_iteration_space operation. Apr 15, 2024
@PeimingLiu PeimingLiu force-pushed the cherry-picks branch 2 times, most recently from 8037c45 to 8f5f891 Compare April 15, 2024 17:04
PeimingLiu pushed a commit that referenced this pull request Apr 16, 2024
A `sparse_tensor.iterate` iterates over a sparse iteration space
extracted from `sparse_tensor.extract_iteration_space` operation
introduced in #88554.

*DO NOT MERGE* before #88554
@PeimingLiu PeimingLiu merged commit 481bd5d into llvm:main Apr 16, 2024
3 of 4 checks passed
@PeimingLiu PeimingLiu deleted the cherry-picks branch April 16, 2024 18:33
PeimingLiu pushed a commit that referenced this pull request Jun 10, 2024
A `sparse_tensor.iterate` iterates over a sparse iteration space
extracted from `sparse_tensor.extract_iteration_space` operation
introduced in #88554.
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Jun 12, 2024
A `sparse_tensor.iterate` iterates over a sparse iteration space
extracted from `sparse_tensor.extract_iteration_space` operation
introduced in llvm#88554.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants