Skip to content

feat (//cpp): Using atol and rtol based tolerance threshold for torchtrtc #1052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions cpp/bin/torchtrtc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ torchtrtc [input_file_path] [output_file_path]
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
-t[threshold],
--threshold=[threshold] Maximum acceptable numerical deviation
from standard torchscript output
(default 2e-5)
--atol=[atol] Absolute tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-8)
--rtol=[rtol] Relative tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-5)
--no-threshold-check Skip checking threshold compliance
--truncate-long-double,
--truncate, --truncate-64bit Truncate weights that are provided in
Expand Down
21 changes: 18 additions & 3 deletions cpp/bin/torchtrtc/accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,24 @@ bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, fl
return diff.abs().max().item<float>() <= threshold * maxValue;
}

bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold) {
return check_rtol(a - b, {a, b}, threshold);
/// Checks whether a computed tensor matches a ground-truth tensor within a
/// combined absolute/relative tolerance.
///
/// Both tensors are converted to float. The largest absolute element-wise
/// difference is then compared against a single scalar threshold:
///   max(|computed - gt|) <= atol + rtol * max(|gt|)
/// Note that the relative term is scaled by the largest ground-truth
/// magnitude, not evaluated per element.
///
/// @param computed_tensor Tensor under test.
/// @param gt_tensor Ground-truth tensor to compare against.
/// @param atol Absolute tolerance added to the threshold.
/// @param rtol Relative tolerance, scaled by the largest |gt| value.
/// @return true if the maximum deviation is within the threshold, or if there
///         are no elements to compare.
bool almost_equal(
    const at::Tensor& computed_tensor,
    const at::Tensor& gt_tensor, // gt_tensor : Ground Truth Tensor
    float atol,
    float rtol) {
  const auto computed_tensor_float = computed_tensor.toType(at::kFloat);
  const auto gt_tensor_float = gt_tensor.toType(at::kFloat);

  const auto diff = computed_tensor_float - gt_tensor_float;

  // max() without a dim is a full reduction and throws on a tensor with zero
  // elements; with nothing to compare, the outputs trivially agree.
  if (diff.numel() == 0) {
    return true;
  }

  const auto result = diff.abs().max().item<float>();
  // Use the float copy of the ground truth so the threshold is computed in the
  // same dtype as the difference (avoids integer abs() overflow and keeps the
  // two sides of the comparison consistent).
  const auto threshold = atol + (rtol * gt_tensor_float.abs().max().item<float>());

  torchtrt::logging::log(torchtrt::logging::Level::kDEBUG, std::string("Max Difference: ") + std::to_string(result));
  torchtrt::logging::log(
      torchtrt::logging::Level::kDEBUG, std::string("Acceptable Threshold: ") + std::to_string(threshold));

  return result <= threshold;
}

} // namespace accuracy
} // namespace torchtrtc
} // namespace torchtrtc
4 changes: 2 additions & 2 deletions cpp/bin/torchtrtc/accuracy.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace torchtrtc {
namespace accuracy {

bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold);
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold);
/// Compares `computed_tensor` against `gt_tensor` (the ground-truth tensor).
/// Per the definition in accuracy.cpp, the check passes when the largest
/// absolute element-wise difference is no greater than
/// `atol + rtol * max(|gt_tensor|)`.
/// `atol` is the absolute tolerance (default 1e-8) and `rtol` the relative
/// tolerance (default 1e-5).
bool almost_equal(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5);

} // namespace accuracy
} // namespace torchtrtc
} // namespace torchtrtc
37 changes: 25 additions & 12 deletions cpp/bin/torchtrtc/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,16 @@ int main(int argc, char** argv) {
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
args::ValueFlag<uint64_t> workspace_size(
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
args::ValueFlag<double> threshold(
args::ValueFlag<double> atol(
parser,
"threshold",
"Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)",
{'t', "threshold"});
"atol",
"Absolute tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-8)",
{"atol"});
args::ValueFlag<double> rtol(
parser,
"rtol",
"Relative tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-5)",
{"rtol"});

args::Flag no_threshold_check(
parser, "no-threshold-check", "Skip checking threshold compliance", {"no-threshold-check", "no-threshold-check"});
Expand Down Expand Up @@ -392,9 +397,13 @@ int main(int argc, char** argv) {
(compile_settings.enabled_precisions.size() == 1 &&
compile_settings.enabled_precisions.find(torchtrt::DataType::kFloat) !=
compile_settings.enabled_precisions.end())) {
double threshold_val = 2e-5;
if (threshold) {
threshold_val = args::get(threshold);
double atol_val = 1e-8;
double rtol_val = 1e-5;
if (atol) {
atol_val = args::get(atol);
}
if (rtol) {
rtol_val = args::get(rtol);
}

std::vector<torch::jit::IValue> jit_inputs_ivalues;
Expand Down Expand Up @@ -431,14 +440,18 @@ int main(int argc, char** argv) {
}

for (size_t i = 0; i < trt_results.size(); i++) {
std::ostringstream threshold_ss;
threshold_ss << "atol: " << atol_val << " rtol: " << rtol_val;
if (!torchtrtc::accuracy::almost_equal(
jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold_val)) {
std::ostringstream threshold_ss;
threshold_ss << threshold_val;
jit_results[i], trt_results[i].reshape_as(jit_results[i]), atol_val, rtol_val)) {
torchtrt::logging::log(
torchtrt::logging::Level::kWARNING,
std::string("Maximum numerical deviation for output exceeds set threshold (") + threshold_ss.str() +
std::string(")"));
std::string("Maximum numerical deviation for output exceeds tolerance thresholds (") +
threshold_ss.str() + std::string(")"));
} else {
torchtrt::logging::log(
torchtrt::logging::Level::kDEBUG,
std::string("Maximum numerical deviation within threshold limits ") + threshold_ss.str());
}
}
} else {
Expand Down
10 changes: 6 additions & 4 deletions docsrc/tutorials/torchtrtc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
-t[threshold],
--threshold=[threshold] Maximum acceptable numerical deviation
from standard torchscript output
(default 2e-5)
--atol=[atol] Absolute tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-8)
--rtol=[rtol] Relative tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-5)
--no-threshold-check Skip checking threshold compliance
--truncate-long-double,
--truncate, --truncate-64bit Truncate weights that are provided in
Expand Down