Commit 1d78f43

feat: Add ts converter support for aten::all.dim (#1840)
1 parent 6f7627f · commit 1d78f43

File tree

2 files changed: +105 −24 lines changed

core/conversion/converters/impl/reduce.cpp

+54-22
@@ -9,6 +9,36 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* anyDimImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in_tensor,
+    int dim,
+    bool keepdim) {
+  auto in_dims = in_tensor->getDimensions();
+  LOG_DEBUG("Dim to reduce (original): " << dim);
+  dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+  LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+  uint32_t axis_mask = 1 << dim;
+  LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+  LOG_DEBUG("Keep dims: " << keepdim);
+
+  // Reduce does not work on bool inputs
+  if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+    in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+  }
+  auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+  sum_layer->setName(util::node_info(n).c_str());
+  auto out_tensor =
+      castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  return out_tensor;
+}
+
 auto reduce_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
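
The helper factored out above (anyDimImplementation) lowers "any over a dimension" as a kSUM reduce plus casts, since TensorRT's reduce layer does not operate on bool tensors: the input is cast to INT32, summed along the selected axis, and the sum is cast back to bool, so any nonzero count becomes true. The same trick in plain C++, stripped of all TensorRT machinery (a minimal illustrative sketch; the name any_last_dim is made up for the example):

// Sketch of the sum-then-cast trick behind anyDimImplementation, applied to a
// plain rows-by-cols bool matrix instead of TensorRT layers.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<bool> any_last_dim(const std::vector<std::vector<bool>>& m) {
  std::vector<bool> out;
  for (const auto& row : m) {
    int32_t sum = 0;              // stands in for the kSUM reduce on kINT32 data
    for (bool v : row) sum += v;  // bools contribute 0 or 1, like the INT32 cast
    out.push_back(sum != 0);      // cast back to bool: nonzero sum means "any" is true
  }
  return out;
}

int main() {
  std::vector<std::vector<bool>> m = {{false, true}, {false, false}};
  for (bool v : any_last_dim(m)) std::cout << v << ' ';  // prints: 1 0
  std::cout << '\n';
}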
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
             {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in_tensor = args[0].ITensorOrFreeze(ctx);
-               auto in_dims = in_tensor->getDimensions();
                auto dim = args[1].unwrapToInt();
-               LOG_DEBUG("Dim to reduce (original): " << dim);
-               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
-               LOG_DEBUG("Dim to reduce (converted): " << dim);
-
-               uint32_t axis_mask = 1 << dim;
-               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
-
                auto keepdim = args[2].unwrapToBool();
-               LOG_DEBUG("Keep dims: " << keepdim);
-
-               // Reduce does not work on bool inputs
-               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
-                 in_tensor =
-                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
-               }
-               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
-
-               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-
-               sum_layer->setName(util::node_info(n).c_str());
-               auto out_tensor = castITensor(
-                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // use Not(Any(Not(input))) to calculate all without a direct all reduction
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto keepdim = args[2].unwrapToBool();
+               if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
+                 // unary not layer only supports bool inputs
+                 in_tensor = castITensor(
+                     ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
+               }
+               auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
+               not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
+               auto not_in = not_input_layer->getOutput(0);
+               auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
+               auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
              }});
 } // namespace
 } // namespace impl
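
The new aten::all.dim converter avoids needing a dedicated "all" reduction by applying De Morgan's law: all(x) along a dimension equals not(any(not(x))) along that dimension, which is exactly the NOT -> anyDimImplementation -> NOT chain registered above. The identity is easy to sanity-check with plain libtorch, with no TensorRT involved (a standalone sketch, not part of this change; CPU tensors are used purely for illustration):

// Verifies the Not(Any(Not(x))) identity the aten::all.dim converter relies on.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(0, 2, {64, 2}).to(torch::kBool);

  // Direct reduction, i.e. what aten::all.dim computes.
  auto all_direct = torch::all(x, /*dim=*/-1, /*keepdim=*/false);

  // Composition used by the converter: not(any(not(x), dim)).
  auto all_composed = torch::any(x.logical_not(), /*dim=*/-1, /*keepdim=*/false).logical_not();

  std::cout << std::boolalpha << torch::equal(all_direct, all_composed) << '\n';  // true
}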

tests/core/conversion/converters/test_reduce.cpp

+51-2
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
       return (%5))IR";
 }
 
-void test_body(const std::string& graph, at::Tensor& in) {
+void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
 
   in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  std::vector<at::Tensor> trt_results;
+  if (dynamic) {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
+  } else {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  }
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 } // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAllDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::ones({2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in, true);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
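
For reference, %5 : Tensor = aten::all(%0, %1, %3) in the IR graphs above is the graph form of torch.all(input, dim, keepdim), so the reference outputs these tests compare against are easy to reproduce in eager mode. A small libtorch sketch mirroring the first two new tests (illustrative only; it uses CPU tensors where the tests use CUDA ones):

// Eager-mode counterparts of the IR graphs in ATenAllDimConvertsCorrectly
// and ATenAllDimKeepDimConvertsCorrectly.
#include <torch/torch.h>
#include <iostream>

int main() {
  // dim=-1, keepdim=false on a {64, 2} 0/1 input: reduces away the last dim.
  auto a = torch::randint(0, 2, {64, 2});
  std::cout << torch::all(a, /*dim=*/-1, /*keepdim=*/false).sizes() << '\n';  // [64]

  // dim=0, keepdim=true on a {2, 32} bool input: the reduced dim is kept as size 1.
  auto b = torch::randint(-2, 2, {2, 32}).to(torch::kBool);
  std::cout << torch::all(b, /*dim=*/0, /*keepdim=*/true).sizes() << '\n';  // [1, 32]
}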
