Skip to content

Commit 8466eb7

Browse files
[mlir][sparse] Add more error messages and avoid crashing in new parser (#67034)
Updates: (1) added more invalid encodings to test the robustness of the new syntax; (2) changed the asserts that caused crashing into functions returning booleans; (3) clarified some error messages and handled failures when parsing quoted strings as keywords for level formats and properties.
1 parent 62a3d84 commit 8466eb7

File tree

4 files changed

+188
-44
lines changed

4 files changed: +188 / -44 lines changed

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ using namespace mlir::sparse_tensor::ir_detail;
4949

5050
FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
5151
StringRef base;
52-
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
53-
uint8_t properties = 0;
5452
const auto loc = parser.getCurrentLocation();
53+
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
54+
"expected valid level format (e.g. dense, compressed or singleton)")
55+
uint8_t properties = 0;
5556

5657
ParseResult res = parser.parseCommaSeparatedList(
5758
mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -73,19 +74,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
7374
} else if (base.compare("singleton") == 0) {
7475
properties |= static_cast<uint8_t>(LevelFormat::Singleton);
7576
} else {
76-
parser.emitError(loc, "unknown level format");
77+
parser.emitError(loc, "unknown level format: ") << base;
7778
return failure();
7879
}
7980

8081
ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
81-
"invalid level type");
82+
"invalid level type: level format doesn't support the properties");
8283
return properties;
8384
}
8485

8586
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
8687
uint8_t *properties) const {
8788
StringRef strVal;
88-
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
89+
auto loc = parser.getCurrentLocation();
90+
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
91+
"expected valid level property (e.g. nonordered, nonunique or high)")
8992
if (strVal.compare("nonunique") == 0) {
9093
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
9194
} else if (strVal.compare("nonordered") == 0) {
@@ -95,7 +98,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
9598
} else if (strVal.compare("block2_4") == 0) {
9699
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
97100
} else {
98-
parser.emitError(parser.getCurrentLocation(), "unknown level property");
101+
parser.emitError(loc, "unknown level property: ") << strVal;
99102
return failure();
100103
}
101104
return success();

mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -196,34 +196,25 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
196196
return pair1 <= pair2 ? sm1 : sm2;
197197
}
198198

199-
LLVM_ATTRIBUTE_UNUSED static void
200-
assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
201-
#ifndef NDEBUG
199+
bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
202200
const auto &var = env.access(id);
203-
assert(var.getName() == name && "found inconsistent name");
204-
assert(var.getID() == id && "found inconsistent VarInfo::ID");
205-
#endif // NDEBUG
201+
return (var.getName() == name && var.getID() == id);
206202
}
207203

208204
// NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
209205
// (or find some other way to convert SMLoc to FileLineColLoc), then this
210206
// would no longer be `const VarEnv` (and couldn't be a free-function either).
211-
LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
212-
VarInfo::ID id,
213-
llvm::SMLoc loc,
214-
VarKind vk) {
215-
#ifndef NDEBUG
207+
bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
208+
VarKind vk) {
216209
const auto &var = env.access(id);
217-
assert(var.getKind() == vk &&
218-
"a variable of that name already exists with a different VarKind");
219210
// Since the same variable can occur at several locations,
220211
// it would not be appropriate to do `assert(var.getLoc() == loc)`.
221212
/* TODO(wrengr):
222213
const auto minLoc = minSMLoc(_, var.getLoc(), loc);
223214
assert(minLoc && "Location mismatch/incompatibility");
224215
var.loc = minLoc;
225216
// */
226-
#endif // NDEBUG
217+
return var.getKind() == vk;
227218
}
228219

229220
std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
@@ -236,24 +227,23 @@ std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
236227
if (iter == ids.end())
237228
return std::nullopt;
238229
const auto id = iter->second;
239-
#ifndef NDEBUG
240-
assertInternalConsistency(*this, id, name);
241-
#endif // NDEBUG
230+
if (!isInternalConsistent(*this, id, name))
231+
return std::nullopt;
242232
return id;
243233
}
244234

245-
std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
246-
VarKind vk, bool verifyUsage) {
235+
std::optional<std::pair<VarInfo::ID, bool>>
236+
VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
247237
const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
248238
const auto id = iter->second;
249239
if (didInsert) {
250240
vars.emplace_back(id, name, loc, vk);
251241
} else {
252-
#ifndef NDEBUG
253-
assertInternalConsistency(*this, id, name);
254-
if (verifyUsage)
255-
assertUsageConsistency(*this, id, loc, vk);
256-
#endif // NDEBUG
242+
if (!isInternalConsistent(*this, id, name))
243+
return std::nullopt;
244+
if (verifyUsage)
245+
if (!isUsageConsistent(*this, id, loc, vk))
246+
return std::nullopt;
257247
}
258248
return std::make_pair(id, didInsert);
259249
}
@@ -265,20 +255,18 @@ VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
265255
case Policy::MustNot: {
266256
const auto oid = lookup(name);
267257
if (!oid)
268-
return std::nullopt; // Doesn't exist, but must not create.
269-
#ifndef NDEBUG
270-
assertUsageConsistency(*this, *oid, loc, vk);
271-
#endif // NDEBUG
258+
return std::nullopt; // Doesn't exist, but must not create.
259+
if (!isUsageConsistent(*this, *oid, loc, vk))
260+
return std::nullopt;
272261
return std::make_pair(*oid, false);
273262
}
274263
case Policy::May:
275264
return create(name, loc, vk, /*verifyUsage=*/true);
276265
case Policy::Must: {
277266
const auto res = create(name, loc, vk, /*verifyUsage=*/false);
278-
// const auto id = res.first;
279-
const auto didCreate = res.second;
267+
const auto didCreate = res->second;
280268
if (!didCreate)
281-
return std::nullopt; // Already exists, but must create.
269+
return std::nullopt; // Already exists, but must create.
282270
return res;
283271
}
284272
}

mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,8 @@ class VarEnv final {
453453
/// for the variable with the given name (i.e., either the newly created
454454
/// variable, or the pre-existing variable), and a bool indicating whether
455455
/// a new variable was created.
456-
std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
457-
VarKind vk, bool verifyUsage = false);
456+
std::optional<std::pair<VarInfo::ID, bool>>
457+
create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
458458

459459
/// Attempts to lookup or create a variable according to the given
460460
/// `Policy`. Returns nullopt in one of two circumstances:

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 158 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,49 @@
11
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
22

3-
// expected-error@+1 {{expected a non-empty array for lvlTypes}}
4-
#a = #sparse_tensor.encoding<{lvlTypes = []}>
3+
// expected-error@+1 {{expected '(' in dimension-specifier list}}
4+
#a = #sparse_tensor.encoding<{map = []}>
5+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
6+
7+
// -----
8+
9+
// expected-error@+1 {{expected '->'}}
10+
#a = #sparse_tensor.encoding<{map = ()}>
11+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
12+
13+
// -----
14+
15+
// expected-error@+1 {{expected ')' in dimension-specifier list}}
16+
#a = #sparse_tensor.encoding<{map = (d0 -> d0)}>
17+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
18+
19+
// -----
20+
21+
// expected-error@+1 {{expected '(' in dimension-specifier list}}
22+
#a = #sparse_tensor.encoding<{map = d0 -> d0}>
23+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
24+
25+
// -----
26+
27+
// expected-error@+1 {{expected '(' in level-specifier list}}
28+
#a = #sparse_tensor.encoding<{map = (d0) -> d0}>
29+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
30+
31+
// -----
32+
33+
// expected-error@+1 {{expected ':'}}
34+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0)}>
35+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
36+
37+
// -----
38+
39+
// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
40+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
41+
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
42+
43+
// -----
44+
45+
// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
46+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
547
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
648

749
// -----
@@ -18,17 +60,61 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
1860

1961
// -----
2062

21-
#a = #sparse_tensor.encoding<{lvlTypes = [1]}> // expected-error {{expected a string value in lvlTypes}}
63+
// expected-error@+1 {{unexpected dimToLvl mapping from 2 to 1}}
64+
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
65+
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
66+
67+
// -----
68+
69+
// expected-error@+1 {{expected bare identifier}}
70+
#a = #sparse_tensor.encoding<{map = (1)}>
71+
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
72+
73+
// -----
74+
75+
// expected-error@+1 {{unexpected key: nap}}
76+
#a = #sparse_tensor.encoding<{nap = (d0) -> (d0 : dense)}>
77+
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
78+
79+
// -----
80+
81+
// expected-error@+1 {{expected '(' in dimension-specifier list}}
82+
#a = #sparse_tensor.encoding<{map = -> (d0 : dense)}>
2283
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
2384

2485
// -----
2586

26-
#a = #sparse_tensor.encoding<{lvlTypes = ["strange"]}> // expected-error {{unexpected level-type: strange}}
87+
// expected-error@+1 {{unknown level format: strange}}
88+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : strange)}>
2789
func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()
2890

2991
// -----
3092

31-
#a = #sparse_tensor.encoding<{dimToLvl = "wrong"}> // expected-error {{expected an affine map for dimToLvl}}
93+
// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
94+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
95+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
96+
97+
// -----
98+
99+
// expected-error@+1 {{expected valid level property (e.g. nonordered, nonunique or high)}}
100+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
101+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
102+
103+
// -----
104+
// expected-error@+1 {{expected ')' in level-specifier list}}
105+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed[high])}>
106+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
107+
108+
// -----
109+
110+
// expected-error@+1 {{unknown level property: wrong}}
111+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed(wrong))}>
112+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
113+
114+
// -----
115+
116+
// expected-error@+1 {{use of undeclared identifier}}
117+
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed, dense)}>
32118
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
33119

34120
// -----
@@ -39,6 +125,73 @@ func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()
39125

40126
// -----
41127

128+
// expected-error@+1 {{unexpected character}}
129+
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed; d1 : dense)}>
130+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
131+
132+
// -----
133+
134+
// expected-error@+1 {{expected attribute value}}
135+
#a = #sparse_tensor.encoding<{map = (d0: d1) -> (d0 : compressed, d1 : dense)}>
136+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
137+
138+
// -----
139+
140+
// expected-error@+1 {{expected ':'}}
141+
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 = compressed, d1 = dense)}>
142+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
143+
144+
// -----
145+
146+
// expected-error@+1 {{expected attribute value}}
147+
#a = #sparse_tensor.encoding<{map = (d0 : compressed, d1 : compressed)}>
148+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
149+
150+
// -----
151+
152+
// expected-error@+1 {{use of undeclared identifier}}
153+
#a = #sparse_tensor.encoding<{map = (d0 = compressed, d1 = compressed)}>
154+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
155+
156+
// -----
157+
158+
// expected-error@+1 {{use of undeclared identifier}}
159+
#a = #sparse_tensor.encoding<{map = (d0 = l0, d1 = l1) {l0, l1} -> (l0 = d0 : dense, l1 = d1 : compressed)}>
160+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
161+
162+
// -----
163+
164+
// expected-error@+1 {{expected '='}}
165+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : d0 = dense, l1 : d1 = compressed)}>
166+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
167+
168+
// -----
169+
// expected-error@+1 {{use of undeclared identifier 'd0'}}
170+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : l0 = dense, d1 : l1 = compressed)}>
171+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
172+
173+
// -----
174+
// expected-error@+1 {{use of undeclared identifier 'd0'}}
175+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : dense, d1 : compressed)}>
176+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
177+
178+
// -----
179+
// expected-error@+1 {{expected '='}}
180+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : dense, l1 : compressed)}>
181+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
182+
183+
// -----
184+
// expected-error@+1 {{use of undeclared identifier}}
185+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = dense, l1 = compressed)}>
186+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
187+
188+
// -----
189+
// expected-error@+1 {{use of undeclared identifier 'd0'}}
190+
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 = l0 : dense, d1 = l1 : compressed)}>
191+
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
192+
193+
// -----
194+
42195
#a = #sparse_tensor.encoding<{posWidth = "x"}> // expected-error {{expected an integral position bitwidth}}
43196
func.func private @tensor_no_int_ptr(%arg0: tensor<16x32xf32, #a>) -> ()
44197

0 commit comments

Comments (0)