Skip to content

Commit fdbe931

Browse files
committed
[mlir][sparse] Adding getters/setters to DimLvlMap
Depends On D156768 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D156770
1 parent 8a9c51c commit fdbe931

File tree

4 files changed

+52
-18
lines changed

4 files changed

+52
-18
lines changed

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
262262
// needs to happen before the code for setting every `LvlSpec::elideVar`,
263263
// since if the LvlVar is only used in elided DimExpr, then the
264264
// LvlVar should also be elided.
265-
// NOTE: Whenever we set a new DimExpr, we must make sure to validate it
266-
// against our ranks, to restore the invariant established by `isWF` above.
267-
// TODO(wrengr): We might should adjust the `DimLvlExpr` ctor to take a
268-
// `Ranks` argument and perform the validation then.
265+
// NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
266+
// to ensure that we maintain the invariant established by `isWF` above.
269267

270268
// Third, we set every `LvlSpec::elideVar` according to whether that
271269
// LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
@@ -300,6 +298,22 @@ bool DimLvlMap::isWF() const {
300298
return true;
301299
}
302300

301+
AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
302+
SmallVector<AffineExpr> lvlAffines;
303+
lvlAffines.reserve(getLvlRank());
304+
for (const auto &lvlSpec : lvlSpecs)
305+
lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
306+
return AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
307+
}
308+
309+
AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
310+
SmallVector<AffineExpr> dimAffines;
311+
dimAffines.reserve(getDimRank());
312+
for (const auto &dimSpec : dimSpecs)
313+
dimAffines.push_back(dimSpec.getExpr().getAffineExpr());
314+
return AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
315+
}
316+
303317
void DimLvlMap::dump() const {
304318
print(llvm::errs(), /*wantElision=*/false);
305319
llvm::errs() << "\n";

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,6 @@ static_assert(IsZeroCostAbstraction<LvlSpec>);
290290

291291
//===----------------------------------------------------------------------===//
292292
class DimLvlMap final {
293-
// TODO(wrengr): Need to define getters
294-
unsigned symRank;
295-
SmallVector<DimSpec> dimSpecs;
296-
SmallVector<LvlSpec> lvlSpecs;
297-
bool mustPrintLvlVars;
298-
299-
// Checks for integrity of variable-binding structure.
300-
// This is already called by the ctor.
301-
[[nodiscard]] bool isWF() const;
302-
303293
public:
304294
DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
305295
ArrayRef<LvlSpec> lvlSpecs);
@@ -310,11 +300,41 @@ class DimLvlMap final {
310300
unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
311301
Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
312302

313-
DimLevelType getDimLevelType(unsigned i) { return lvlSpecs[i].getType(); }
303+
ArrayRef<DimSpec> getDims() const { return dimSpecs; }
304+
const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
305+
SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
306+
return getDim(dim).getSlice();
307+
}
308+
309+
ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
310+
const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
311+
DimLevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
312+
313+
AffineMap getDimToLvlMap(MLIRContext *context) const;
314+
AffineMap getLvlToDimMap(MLIRContext *context) const;
314315

315316
void print(llvm::raw_ostream &os, bool wantElision = true) const;
316317
void print(AsmPrinter &printer, bool wantElision = true) const;
317318
void dump() const;
319+
320+
private:
321+
/// Checks for integrity of variable-binding structure.
322+
/// This is already called by the ctor.
323+
[[nodiscard]] bool isWF() const;
324+
325+
/// Helper function to call `DimSpec::setExpr` while asserting that
326+
/// the invariant established by `DimLvlMap:isWF` is maintained.
327+
/// This is used by the ctor.
328+
void setDimExpr(Dimension dim, DimExpr expr) {
329+
assert(expr && getRanks().isValid(expr));
330+
dimSpecs[dim].setExpr(expr);
331+
}
332+
333+
// All these fields are const-after-ctor.
334+
unsigned symRank;
335+
SmallVector<DimSpec> dimSpecs;
336+
SmallVector<LvlSpec> lvlSpecs;
337+
bool mustPrintLvlVars;
318338
};
319339

320340
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ bool VarSet::occursIn(DimLvlExpr expr) const {
115115
}
116116

117117
void VarSet::add(Var var) {
118-
// NOTE: `SmallBitVactor::operator[]` will raise assertion errors for OOB.
118+
// NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
119119
impl[var.getKind()][var.getNum()] = true;
120120
}
121121

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
530530
RETURN_ON_FAIL(res);
531531
// Proof of concept result.
532532
// TODO: use DimLvlMap directly as storage representation
533-
for (unsigned i = 0, e = res->getLvlRank(); i < e; i++)
534-
lvlTypes.push_back(res->getDimLevelType(i));
533+
for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
534+
lvlTypes.push_back(res->getLvlType(lvl));
535535
}
536536

537537
// Only the last item can omit the comma

0 commit comments

Comments
 (0)