
Commit 4afa72a

Support int inputs to aten::max/min and aten::argmax/argmin by casting to float (#94)
* Support int inputs to aten::max/min and aten::argmax/argmin by casting to float
* correct layer name
* address nit, remove local variable
1 parent 5fa6374 commit 4afa72a

File tree

2 files changed: +58 -1 lines changed


core/conversion/converters/impl/max.cpp

Lines changed: 13 additions & 1 deletion
@@ -22,6 +22,11 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
   if (dim < 0) {
     dim = selfDim.size() + dim;
   }
+  bool int_input = self->getType() == nvinfer1::DataType::kINT32;
+  if (int_input) {
+    LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float");
+    self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input");
+  }
   uint32_t reduce_axes_mask = 1 << dim;
   auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
   TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
@@ -44,7 +49,10 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
     out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
     out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
   }
-
+  if (int_input) {
+    LOG_DEBUG("Adding cast of topK layer output back to int32");
+    out0 = castITensor(ctx, out0, nvinfer1::DataType::kINT32, util::node_info(n) + "_output");
+  }
   LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
   LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
 
@@ -59,6 +67,10 @@ bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
   if (dim < 0) {
     dim = selfDim.size() + dim;
   }
+  if (self->getType() == nvinfer1::DataType::kINT32) {
+    LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float");
+    self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input");
+  }
   uint32_t reduce_axes_mask = 1 << dim;
   auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
   TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
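For reference, below is a minimal sketch of the pattern this change applies, written against the raw TensorRT C++ API rather than Torch-TensorRT's converter context. The helper names cast_tensor and topk_with_int_cast are illustrative only (not part of this commit), and the identity-layer cast is one common way to achieve what castITensor does here: ITopKLayer does not accept kINT32 inputs, so the input is cast to float, TopK runs with k=1 over the reduction axis, and the values output is cast back to int32 (the indices output is already int32).

#include <utility>

#include "NvInfer.h"

// Illustrative cast helper (hypothetical, not from this commit): insert an
// identity layer and request a different output type to perform the conversion.
static nvinfer1::ITensor* cast_tensor(nvinfer1::INetworkDefinition* net,
                                      nvinfer1::ITensor* in,
                                      nvinfer1::DataType dtype) {
  auto* id = net->addIdentity(*in);
  id->setOutputType(0, dtype);
  return id->getOutput(0);
}

// Sketch of the cast-around-TopK pattern used by min_max_dim/arg_min_max above:
// cast int32 inputs to float, run TopK, then cast the values output back to int32.
static std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*> topk_with_int_cast(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* self,
    nvinfer1::TopKOperation op,
    uint32_t reduce_axes_mask) {
  bool int_input = self->getType() == nvinfer1::DataType::kINT32;
  if (int_input) {
    self = cast_tensor(net, self, nvinfer1::DataType::kFLOAT);
  }
  auto* topk = net->addTopK(*self, op, /*k=*/1, reduce_axes_mask);
  auto* values = topk->getOutput(0);
  auto* indices = topk->getOutput(1);  // TopK already emits int32 indices
  if (int_input) {
    values = cast_tensor(net, values, nvinfer1::DataType::kINT32);
  }
  return {values, indices};
}

Depending on the TensorRT version, an identity-layer cast like this may also require precision constraints to be honored by the builder; Torch-TensorRT's castITensor takes care of such details and names the inserted layer, which is why the commit passes util::node_info(n) + "_input" / "_output".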

tests/core/conversion/converters/test_max.cpp

Lines changed: 45 additions & 0 deletions
@@ -29,6 +29,29 @@ TEST(Converters, ATenMaxDimConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
 }
 
+TEST(Converters, ATenMaxDimIntInputConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
+      return (%4, %5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(-5, 5, {5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
+}
+
 TEST(Converters, ATenMinDimConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):
@@ -77,6 +100,28 @@ TEST(Converters, ATenArgMaxConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenArgMaxIntInputConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : bool = prim::Constant[value=0]()
+      %4 : Tensor = aten::argmax(%x.1, %2, %3)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(-5, 5, {5, 5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):
