Skip to content

Commit 3cbe2e2

Browse files
committed
chore!: Updated the tensor names as per review comments
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 2d51217 commit 3cbe2e2

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

cpp/bin/torchtrtc/accuracy.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@ bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, fl
1919
return diff.abs().max().item<float>() <= threshold * maxValue;
2020
}
2121

22-
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float atol, float rtol) {
23-
auto a_float = a.toType(at::kFloat);
24-
auto b_float = b.toType(at::kFloat);
25-
26-
auto diff = a_float - b_float;
22+
bool almost_equal(
23+
const at::Tensor& computed_tensor,
24+
const at::Tensor& gt_tensor, // gt_tensor : Ground Truth Tensor
25+
float atol,
26+
float rtol) {
27+
auto computed_tensor_float = computed_tensor.toType(at::kFloat);
28+
auto gt_tensor_float = gt_tensor.toType(at::kFloat);
29+
30+
auto diff = computed_tensor_float - gt_tensor_float;
2731
auto result = diff.abs().max().item<float>();
28-
auto threshold = atol + (rtol * b.abs().max().item<float>());
32+
auto threshold = atol + (rtol * gt_tensor.abs().max().item<float>());
2933

3034
torchtrt::logging::log(torchtrt::logging::Level::kDEBUG, std::string("Max Difference: ") + std::to_string(result));
3135
torchtrt::logging::log(

cpp/bin/torchtrtc/accuracy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace torchtrtc {
1212
namespace accuracy {
1313

1414
bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold);
15-
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float atol = 1e-8, float rtol = 1e-5);
15+
bool almost_equal(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5);
1616

1717
} // namespace accuracy
1818
} // namespace torchtrtc

cpp/bin/torchtrtc/main.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,18 @@ int main(int argc, char** argv) {
440440
}
441441

442442
for (size_t i = 0; i < trt_results.size(); i++) {
443+
std::ostringstream threshold_ss;
444+
threshold_ss << "atol: " << atol_val << " rtol: " << rtol_val;
443445
if (!torchtrtc::accuracy::almost_equal(
444446
jit_results[i], trt_results[i].reshape_as(jit_results[i]), atol_val, rtol_val)) {
445-
std::ostringstream threshold_ss;
446-
threshold_ss << "atol: " << atol_val << " rtol: " << rtol_val;
447447
torchtrt::logging::log(
448448
torchtrt::logging::Level::kWARNING,
449449
std::string("Maximum numerical deviation for output exceeds tolerance thresholds (") +
450450
threshold_ss.str() + std::string(")"));
451+
} else {
452+
torchtrt::logging::log(
453+
torchtrt::logging::Level::kDEBUG,
454+
std::string("Maximum numerical deviation within threshold limits ") + threshold_ss.str());
451455
}
452456
}
453457
} else {

0 commit comments

Comments
 (0)