@@ -29,6 +29,29 @@ TEST(Converters, ATenMaxDimConvertsCorrectly) {
29
29
torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
30
30
}
31
31
32
+ TEST (Converters, ATenMaxDimIntInputConvertsCorrectly) {
33
+ const auto graph = R"IR(
34
+ graph(%x.1 : Tensor):
35
+ %2 : int = prim::Constant[value=0]()
36
+ %3 : bool = prim::Constant[value=0]()
37
+ %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
38
+ return (%4, %5))IR" ;
39
+
40
+ auto g = std::make_shared<torch::jit::Graph>();
41
+ torch::jit::parseIR (graph, g.get ());
42
+
43
+ auto in = at::randint (-5 , 5 , {5 , 5 }, {at::kCUDA });
44
+
45
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
46
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
47
+
48
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
49
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
50
+
51
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
52
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ], 2e-6 ));
53
+ }
54
+
32
55
TEST (Converters, ATenMinDimConvertsCorrectly) {
33
56
const auto graph = R"IR(
34
57
graph(%x.1 : Tensor):
@@ -77,6 +100,28 @@ TEST(Converters, ATenArgMaxConvertsCorrectly) {
77
100
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
78
101
}
79
102
103
+ TEST (Converters, ATenArgMaxIntInputConvertsCorrectly) {
104
+ const auto graph = R"IR(
105
+ graph(%x.1 : Tensor):
106
+ %2 : int = prim::Constant[value=0]()
107
+ %3 : bool = prim::Constant[value=0]()
108
+ %4 : Tensor = aten::argmax(%x.1, %2, %3)
109
+ return (%4))IR" ;
110
+
111
+ auto g = std::make_shared<torch::jit::Graph>();
112
+ torch::jit::parseIR (graph, g.get ());
113
+
114
+ auto in = at::randint (-5 , 5 , {5 , 5 }, {at::kCUDA });
115
+
116
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
117
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
118
+
119
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
120
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
121
+
122
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
123
+ }
124
+
80
125
TEST (Converters, ATenArgMaxKeepdimConvertsCorrectly) {
81
126
const auto graph = R"IR(
82
127
graph(%x.1 : Tensor):
0 commit comments