Skip to content

Commit 83a9890

Browse files
authored
Merge pull request #1285 from pytorch/anuragd/test_pybind
fix: Fixing pybind error on nightly
2 parents fd1c2cd + 0705a54 commit 83a9890

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ class pyCalibratorTrampoline : public Derived {
2323
using Derived::Derived; // Inherit constructors
2424

2525
int getBatchSize() const noexcept override {
26-
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
26+
try {
27+
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
28+
} catch (std::exception const& e) {
29+
LOG_ERROR("Exception caught in get_batch_size" + std::string(e.what()));
30+
} catch (...) {
31+
LOG_ERROR("Exception caught in get_batch_size");
32+
}
33+
return -1;
2734
}
2835

2936
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
@@ -71,8 +78,15 @@ class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrato
7178
using Derived::Derived;
7279

7380
nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
74-
PYBIND11_OVERLOAD_PURE_NAME(
75-
nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
81+
try {
82+
PYBIND11_OVERLOAD_PURE_NAME(
83+
nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
84+
} catch (std::exception const& e) {
85+
LOG_ERROR("Exception caught in get_algorithm: " + std::string(e.what()));
86+
} catch (...) {
87+
LOG_ERROR("Exception caught in get_algorithm");
88+
}
89+
return {};
7690
}
7791
};
7892

@@ -82,21 +96,49 @@ class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Leg
8296
using Derived::Derived;
8397

8498
double getQuantile() const noexcept override {
85-
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
99+
try {
100+
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
101+
} catch (std::exception const& e) {
102+
LOG_ERROR("Exception caught in get_quantile: " + std::string(e.what()));
103+
} catch (...) {
104+
LOG_ERROR("Exception caught in get_quantile");
105+
}
106+
return -1.0;
86107
}
87108

88109
double getRegressionCutoff() const noexcept override {
89-
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
110+
try {
111+
PYBIND11_OVERLOAD_PURE_NAME(
112+
double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
113+
} catch (std::exception const& e) {
114+
LOG_ERROR("Exception caught in get_regression_cutoff: " + std::string(e.what()));
115+
} catch (...) {
116+
LOG_ERROR("Exception caught in get_regression_cutoff");
117+
}
118+
return -1.0;
90119
}
91120

92121
const void* readHistogramCache(std::size_t& length) noexcept override {
93-
PYBIND11_OVERLOAD_PURE_NAME(
94-
const void*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
122+
try {
123+
PYBIND11_OVERLOAD_PURE_NAME(
124+
const char*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
125+
} catch (std::exception const& e) {
126+
LOG_ERROR("Exception caught in read_histogram_cache" + std::string(e.what()));
127+
} catch (...) {
128+
LOG_ERROR("Exception caught in read_histogram_cache");
129+
}
130+
return {};
95131
}
96132

97133
void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
98-
PYBIND11_OVERLOAD_PURE_NAME(
99-
void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
134+
try {
135+
PYBIND11_OVERLOAD_PURE_NAME(
136+
void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
137+
} catch (std::exception const& e) {
138+
LOG_ERROR("Exception caught in write_histogram_cache" + std::string(e.what()));
139+
} catch (...) {
140+
LOG_ERROR("Exception caught in write_histogram_cache");
141+
}
100142
}
101143
};
102144

0 commit comments

Comments
 (0)