
Commit ebe8522

Improve ET_CHECK_OR_RETURN_FALSE error messages in kernels/portable (#9751)
We have a lot of `ET_CHECK_OR_RETURN_FALSE` calls that log the failing condition, but not the values of the variables in that condition. This is an attempt to improve the debuggability of these errors. cc @larryliu0820 @manuelcandales
1 parent ce74f8e commit ebe8522

14 files changed: +157 −50 lines
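
For context, `ET_CHECK_OR_RETURN_FALSE` is a printf-style check macro: it evaluates a condition and, on failure, logs a formatted message and returns `false` from the enclosing function. The sketch below is a minimal standalone analogue of that pattern, not the real macro (which lives in the ExecuTorch runtime and logs through `ET_LOG`); it shows why interpolating the offending values makes these failures debuggable.

    #include <cstdio>
    #include <sys/types.h> // ssize_t (POSIX)

    // Illustrative stand-in for ET_CHECK_OR_RETURN_FALSE: on failure, log the
    // stringified condition plus a printf-style message, then return false.
    // (##__VA_ARGS__ is a widely supported GNU extension.)
    #define CHECK_OR_RETURN_FALSE(cond, fmt, ...)                            \
      do {                                                                   \
        if (!(cond)) {                                                       \
          std::fprintf(                                                      \
              stderr, "Check failed (%s): " fmt "\n", #cond, ##__VA_ARGS__); \
          return false;                                                      \
        }                                                                    \
      } while (0)

    // Before this commit a failure reported only the condition text; after it,
    // the offending value appears in the message, e.g.:
    //   Check failed (dim == 4): Only 2D Convolution Backward supported for
    //   now; weight.dim() = 3
    bool check_weight_dim(ssize_t dim) {
      CHECK_OR_RETURN_FALSE(
          dim == 4,
          "Only 2D Convolution Backward supported for now; weight.dim() = %zd",
          dim);
      return true;
    }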

kernels/portable/cpu/op_convolution_backward.cpp

Lines changed: 6 additions & 2 deletions
@@ -41,7 +41,9 @@ bool check_convolution_backward_args(
   ET_CHECK_OR_RETURN_FALSE(
       transposed == false, "Transposed Convolution Backward not supported yet");
   ET_CHECK_OR_RETURN_FALSE(
-      weight.dim() == 4, "Only 2D Convolution Backward supported for now");
+      weight.dim() == 4,
+      "Only 2D Convolution Backward supported for now; weight.dim() = %zd",
+      weight.dim());

   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
@@ -91,7 +93,9 @@ bool check_convolution_backward_args(

   ET_CHECK_OR_RETURN_FALSE(
       grad_output.dim() == input.dim(),
-      "grad_output should have same number of dimensions as input");
+      "grad_output should have same number of dimensions as input; grad_output.dim() = %zd, input.dim() = %zd",
+      grad_output.dim(),
+      input.dim());

   ET_LOG_AND_RETURN_IF_FALSE(
       tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));
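
A note on the size format specifiers used throughout this commit: in these kernels `Tensor::dim()` and `Tensor::size()` return a signed size type, printed with `%zd`, while `std::size_t` quantities such as `ArrayRef::size()` take `%zu`; mismatching signedness in printf-style varargs is undefined behavior. A tiny sketch with stand-in variables, not repository code:

    #include <cstdio>
    #include <sys/types.h> // ssize_t (POSIX)

    int main() {
      ssize_t dim = 3;   // stand-in for a signed quantity like weight.dim()
      size_t count = 2;  // stand-in for an unsigned one like pad.size()
      // %zd pairs with the signed size type, %zu with the unsigned one.
      std::printf("weight.dim() = %zd, pad.size() = %zu\n", dim, count);
      return 0;
    }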

kernels/portable/cpu/op_linear_scratch_example.cpp

Lines changed: 12 additions & 3 deletions
@@ -41,13 +41,22 @@ bool check_linear_scratch_example_args(
     Tensor& out,
     Tensor& scratch) {
   ET_CHECK_OR_RETURN_FALSE(
-      input.size(1) == weight.size(1), "Unexpected weight size 1");
+      input.size(1) == weight.size(1),
+      "Unexpected weight size 1; input.size(1) = %zd, weight.size(1) = %zd",
+      input.size(1),
+      weight.size(1));

   ET_CHECK_OR_RETURN_FALSE(
-      scratch.size(0) == input.size(0), "Unexpected scratch size 0");
+      scratch.size(0) == input.size(0),
+      "Unexpected scratch size 0; scratch.size(0) = %zd, input.size(0) = %zd",
+      scratch.size(0),
+      input.size(0));

   ET_CHECK_OR_RETURN_FALSE(
-      scratch.size(1) == weight.size(0), "Unexpected scratch size 1");
+      scratch.size(1) == weight.size(0),
+      "Unexpected scratch size 1; scratch.size(1) = %zd, weight.size(0) = %zd",
+      scratch.size(1),
+      weight.size(0));

   return true;
 }

kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ bool check_max_pool2d_backward_args(

   ET_CHECK_OR_RETURN_FALSE(
       grad_output.dim() == input.dim(),
-      "grad_output should have same number of dimensions as input");
+      "grad_output should have same number of dimensions as input; grad_output.dim() = %zd, input.dim() = %zd",
+      grad_output.dim(),
+      input.dim());

   ET_LOG_AND_RETURN_IF_FALSE(
       tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

kernels/portable/cpu/op_repeat_interleave.cpp

Lines changed: 18 additions & 5 deletions
@@ -22,24 +22,37 @@ bool check_repeat_interleave_args(
   ET_CHECK_OR_RETURN_FALSE(
       repeats.scalar_type() == ScalarType::Int ||
           repeats.scalar_type() == ScalarType::Long,
-      "repeats must be int or long");
-  ET_CHECK_OR_RETURN_FALSE(repeats.dim() == 1, "repeats must be 1D");
+      "repeats must be int or long; repeats.scalar_type() = %d",
+      static_cast<int>(repeats.scalar_type()));
+  ET_CHECK_OR_RETURN_FALSE(
+      repeats.dim() == 1,
+      "repeats must be 1-D; repeats.dim() = %zd",
+      repeats.dim());
   ET_CHECK_OR_RETURN_FALSE(
       output_size_value == repeats_sum,
-      "output_size, if provided, must be equal to repeats.sum()");
+      "output_size, if provided, must be equal to repeats.sum(); output_size_value = %" PRId64
+      ", repeats_sum = %" PRId64,
+      output_size_value,
+      repeats_sum);
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(repeats, out));

   if (repeats.scalar_type() == ScalarType::Long) {
     const int64_t* const repeats_data = repeats.const_data_ptr<int64_t>();
     for (const auto i : c10::irange(repeats.numel())) {
       ET_CHECK_OR_RETURN_FALSE(
-          repeats_data[i] >= 0, "repeats cannot be negative");
+          repeats_data[i] >= 0,
+          "repeats cannot be negative; repeats_data[%" PRId64 "] = %" PRId64,
+          static_cast<int64_t>(i),
+          repeats_data[i]);
     }
   } else {
     const int32_t* const repeats_data = repeats.const_data_ptr<int32_t>();
     for (const auto i : c10::irange(repeats.numel())) {
       ET_CHECK_OR_RETURN_FALSE(
-          repeats_data[i] >= 0, "repeats cannot be negative");
+          repeats_data[i] >= 0,
+          "repeats cannot be negative; repeats_data[%" PRId64 "] = %d",
+          static_cast<int64_t>(i),
+          repeats_data[i]);
     }
   }
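
The `%" PRId64 "` pieces above come from `<cinttypes>`: the printf length modifier for `int64_t` differs across platforms (`%ld` on LP64, `%lld` on LLP64), so the standard macros keep the format strings portable, with adjacent string literals splicing the macro into the message. A minimal sketch with illustrative values:

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t output_size_value = 40; // illustrative values, not from the diff
      int64_t repeats_sum = 42;
      // PRId64 expands to the right length modifier ("ld" or "lld"); adjacent
      // string literals are concatenated into a single format string.
      std::printf(
          "output_size_value = %" PRId64 ", repeats_sum = %" PRId64 "\n",
          output_size_value,
          repeats_sum);
      return 0;
    }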

kernels/portable/cpu/op_topk.cpp

Lines changed: 7 additions & 1 deletion
@@ -30,7 +30,13 @@ bool check_topk_args(
     dim += nonzero_dim(in);
   }
   ET_CHECK_OR_RETURN_FALSE(
-      k >= 0 && k <= nonempty_size(in, dim), "selected index k out of range");
+      k >= 0 && k <= nonempty_size(in, dim),
+      "selected index k out of range; k = %" PRId64 ", dim = %" PRId64
+      ", in.dim() = %zd, nonempty_size(in, dim) = %zd",
+      k,
+      dim,
+      in.dim(),
+      nonempty_size(in, dim));
   return true;
 }

kernels/portable/cpu/util/activation_ops_util.cpp

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ bool check_glu_args(const Tensor& in, int64_t dim, Tensor& out) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(in, out));
   ET_CHECK_OR_RETURN_FALSE(
       out.size(non_negative_dim) == dim_size / 2,
-      "output tensor must have half the size of the input tensor along the specified dimension.");
+      "output tensor must have half the size of the input tensor along the specified dimension; out.size(%zu) = %zd, dim_size = %zd",
+      non_negative_dim,
+      out.size(non_negative_dim),
+      dim_size);

   for (const auto i : c10::irange(in.dim())) {
     if (static_cast<size_t>(i) != non_negative_dim) {

kernels/portable/cpu/util/advanced_index_util.cpp

Lines changed: 11 additions & 3 deletions
@@ -28,7 +28,8 @@ bool check_indices_dtypes(TensorOptList indices) {
       ET_CHECK_OR_RETURN_FALSE(
           ix_type == ScalarType::Long || ix_type == ScalarType::Int ||
               ix_type == ScalarType::Byte || ix_type == ScalarType::Bool,
-          "Index tensors should be Long, Int, Byte or Bool");
+          "Index tensors should be Long, Int, Byte or Bool; got %d",
+          static_cast<int>(ix_type));
     }
   }
   return true;
@@ -295,11 +296,18 @@ bool get_index_out_target_size(

   ET_CHECK_OR_RETURN_FALSE(
       static_cast<ssize_t>(num_null_indices + num_indexed_dims) <= in.dim(),
-      "Indexing too many dimensions");
+      "Indexing too many dimensions; num_null_indices = %zu, num_indexed_dims = %zu, in.dim() = %zd",
+      num_null_indices,
+      num_indexed_dims,
+      in.dim());

   ET_CHECK_OR_RETURN_FALSE(
       in.dim() + broadcast_ndim - num_indexed_dims <= kTensorDimensionLimit,
-      "Out tensor would exceed number of allowed dimensions");
+      "Out tensor would exceed number of allowed dimensions; in.dim() = %zd, broadcast_ndim = %zu, num_indexed_dims = %zu, kTensorDimensionLimit = %zu",
+      in.dim(),
+      broadcast_ndim,
+      num_indexed_dims,
+      kTensorDimensionLimit);

   (*out_ndim) = in.dim() + broadcast_ndim - num_indexed_dims;
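
The `static_cast<int>(ix_type)` above is needed because `ScalarType` is a scoped enum: scoped enums undergo no integral promotion, so passing one straight through printf-style varargs would not match `%d` (undefined behavior; compilers warn). A sketch with a stand-in enum, since the real `ScalarType` lives in the ExecuTorch/c10 headers:

    #include <cstdint>
    #include <cstdio>

    // Stand-in for the real ScalarType scoped enum; values are illustrative.
    enum class ScalarType : int8_t { Byte = 0, Int = 3, Long = 4, Bool = 11 };

    int main() {
      ScalarType ix_type = ScalarType::Bool;
      // std::printf("%d\n", ix_type);  // mismatched vararg: no promotion to int
      std::printf("got %d\n", static_cast<int>(ix_type)); // prints: got 11
      return 0;
    }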

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 18 additions & 5 deletions
@@ -46,7 +46,10 @@ bool check_as_strided_copy_args(
     Tensor& out) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_CHECK_OR_RETURN_FALSE(
-      size.size() == stride.size(), "mismatch in length of strides and shape");
+      size.size() == stride.size(),
+      "mismatch in length of strides and shape; size.size() = %zu, stride.size() = %zu",
+      size.size(),
+      stride.size());
   for (const auto& val : stride) {
     ET_CHECK_OR_RETURN_FALSE(
         val >= 0,
@@ -242,7 +245,9 @@ bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {

     // Check that the dimension hasn't been seen previously.
     ET_CHECK_OR_RETURN_FALSE(
-        dim_exist[dim] == false, "duplicate dims are not allowed.");
+        dim_exist[dim] == false,
+        "duplicate dims are not allowed; dim = %zu",
+        dim);

     dim_exist[dim] = true;
   }
@@ -424,19 +429,27 @@ bool check_split_with_sizes_copy_args(

   ET_CHECK_OR_RETURN_FALSE(
       split_sizes.size() == out.size(),
-      "Number of split sizes must match the number of output tensors");
+      "Number of split sizes must match the number of output tensors; split_sizes.size() = %zu, out.size() = %zu",
+      split_sizes.size(),
+      out.size());

   int64_t sum = 0;
   for (const auto i : c10::irange(split_sizes.size())) {
     ET_CHECK_OR_RETURN_FALSE(
-        split_sizes[i] >= 0, "All split sizes must be non negative.");
+        split_sizes[i] >= 0,
+        "All split sizes must be non negative; split_sizes[%zu] = %" PRId64,
+        i,
+        split_sizes[i]);
     sum += split_sizes[i];
   }

   const ssize_t dim_size = in.size(dim);
   ET_CHECK_OR_RETURN_FALSE(
       sum == dim_size,
-      "Sum of split sizes does not match input size at given dim");
+      "Sum of split sizes does not match input size at given dim; sum = %" PRId64
+      ", dim_size = %zd",
+      sum,
+      dim_size);

   return true;
 }

kernels/portable/cpu/util/distance_util.cpp

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,8 @@ namespace executor {
 bool check_pdist_args(const Tensor& in, double p, const Tensor& out) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
-  ET_CHECK_OR_RETURN_FALSE(p >= 0, "pdist only supports non-negative p values");
+  ET_CHECK_OR_RETURN_FALSE(
+      p >= 0, "pdist only supports non-negative p values; p = %.6f", p);
   return true;
 }

@@ -39,7 +40,8 @@ bool check_cdist_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(x2, 2));
   ET_LOG_AND_RETURN_IF_FALSE(
       tensors_have_same_size_at_dims(x1, x1.dim() - 1, x2, x2.dim() - 1));
-  ET_CHECK_OR_RETURN_FALSE(p >= 0, "cdist only supports non-negative p values");
+  ET_CHECK_OR_RETURN_FALSE(
+      p >= 0, "cdist only supports non-negative p values; p = %.6f", p);
   if (compute_mode.has_value()) {
     int64_t mode = compute_mode.value();
     ET_CHECK_OR_RETURN_FALSE(

kernels/portable/cpu/util/index_util.cpp

Lines changed: 15 additions & 6 deletions
@@ -23,12 +23,15 @@ bool check_gather_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
   ET_CHECK_OR_RETURN_FALSE(
       index.scalar_type() == ScalarType::Long,
-      "Expected dypte int64 for index");
+      "Expected dypte int64 for index; index.scalar_type() = %d",
+      static_cast<int>(index.scalar_type()));
   if (index.numel() != 0) {
     ET_CHECK_OR_RETURN_FALSE(
         nonzero_dim(in) == nonzero_dim(index),
         "self and index should have the same dimensionality when index is not empty "
-        "except for the case when one has dimension 0 and the other has dimension 1");
+        "except for the case when one has dimension 0 and the other has dimension 1; nonzero_dim(in) = %zd, nonzero_dim(index) = %zd",
+        nonzero_dim(in),
+        nonzero_dim(index));
   }

   // Normalize dim to non-negative value
@@ -67,7 +70,8 @@ bool check_index_select_args(
   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
   ET_CHECK_OR_RETURN_FALSE(
       nonempty_size(in, dim) > 0,
-      "index_select: Indexing axis dim should be positive");
+      "index_select: Indexing axis dim should be positive; nonempty_size(in, dim) = %zd",
+      nonempty_size(in, dim));

   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_CHECK_OR_RETURN_FALSE(
@@ -80,7 +84,8 @@ bool check_index_select_args(
   if (index.dim() > 0 && in.dim() == 0) {
     ET_CHECK_OR_RETURN_FALSE(
         index.numel() == 1,
-        "index_select: Index to scalar must have exactly 1 value");
+        "index_select: Index to scalar must have exactly 1 value; index.numel() = %zd",
+        index.numel());
   }

   if (index.scalar_type() == ScalarType::Long) {
@@ -150,7 +155,8 @@ bool check_scatter_add_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, src));
   ET_CHECK_OR_RETURN_FALSE(
       index.scalar_type() == ScalarType::Long,
-      "Expected dypte int64 for index");
+      "Expected dypte int64 for index; index.scalar_type() = %d",
+      static_cast<int>(index.scalar_type()));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(self, dim));

   if (index.numel() == 0) {
@@ -160,7 +166,10 @@ bool check_scatter_add_args(
     ET_CHECK_OR_RETURN_FALSE(
         nonzero_dim(self) == nonzero_dim(src) &&
             nonzero_dim(self) == nonzero_dim(index),
-        "self, index and src should have same number of dimensions.");
+        "self, index and src should have same number of dimensions; nonzero_dim(self) = %zd, nonzero_dim(src) = %zd, nonzero_dim(index) = %zd",
+        nonzero_dim(self),
+        nonzero_dim(src),
+        nonzero_dim(index));

     // Normalize dim to non-negative value
     if (dim < 0) {

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 20 additions & 8 deletions
@@ -254,7 +254,10 @@ bool check_arange_args(double start, double end, double step, Tensor& out) {

   ET_CHECK_OR_RETURN_FALSE(
       (step > 0 && (end >= start)) || (step < 0 && (end <= start)),
-      "upper bound and larger bound inconsistent with step sign");
+      "upper bound and larger bound inconsistent with step sign; step = %.6f, start = %.6f, end = %.6f",
+      step,
+      start,
+      end);

   return true;
 }
@@ -276,7 +279,8 @@ bool check_avg_pool2d_args(
   ET_CHECK_OR_RETURN_FALSE(
       (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
           (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
-      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input; in.dim() = %zd",
+      in.dim());

   ET_LOG_AND_RETURN_IF_FALSE(
       kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
@@ -347,8 +351,9 @@ bool check_convolution_args(
     ET_CHECK_OR_RETURN_FALSE(
         bias.value().size(0) == transposed ? groups * weight.size(1)
                                            : weight.size(0),
-        "bias length must equal number of output channels, but got %zd",
-        bias.value().size(0));
+        "bias length must equal number of output channels, but got %zd; expected %" PRId64,
+        bias.value().size(0),
+        transposed ? groups * weight.size(1) : weight.size(0));
   }

   int64_t kernel_size[2];
@@ -398,7 +403,9 @@ bool check_convolution_args(
   } else {
     ET_CHECK_OR_RETURN_FALSE(
         in.size(1) == weight.size(0),
-        "input channels must match weight.size(0) in transposed convolution");
+        "input channels must match weight.size(0) in transposed convolution; in.size(1) = %zd, weight.size(0) = %zd",
+        in.size(1),
+        weight.size(0));
   }

   return true;
@@ -484,7 +491,8 @@ bool check_max_pool2d_with_indices_args(
   ET_CHECK_OR_RETURN_FALSE(
       (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
           (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
-      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input; in.dim() = %zd",
+      in.dim());

   ET_LOG_AND_RETURN_IF_FALSE(
       kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
@@ -545,11 +553,15 @@ bool check_constant_pad_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(in, out));

   ET_CHECK_OR_RETURN_FALSE(
-      pad.size() % 2 == 0, "Padding array must be a multiple of 2");
+      pad.size() % 2 == 0,
+      "Padding array must be a multiple of 2; pad.size() = %zu",
+      pad.size());

   ET_CHECK_OR_RETURN_FALSE(
       static_cast<ssize_t>(pad.size() / 2) <= in.dim(),
-      "Padding array contains too many elements");
+      "Padding array contains too many elements; pad.size()/2 = %zu, in.dim() = %zd",
+      pad.size() / 2,
+      in.dim());

   return true;
 }
