@@ -83,6 +83,71 @@ TEST(Evaluators, FullEvaluatesCorrectly) {
83
83
ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
84
84
}
85
85
86
+ TEST (Evaluators, FullLikeEvaluatesCorrectly) {
87
+ const auto graph = R"IR(
88
+ graph(%x.1 : Tensor):
89
+ %9 : None = prim::Constant()
90
+ %13 : float = prim::Constant[value=1.3]()
91
+ %14 : int = prim::Constant[value=4]()
92
+ %35 : Device = prim::Constant[value="cuda:0"]()
93
+ %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
94
+ return (%19))IR" ;
95
+
96
+ auto in = at::randint (1 , 10 , {1 , 2 , 3 , 5 }, {at::kCUDA });
97
+
98
+ auto g = std::make_shared<torch::jit::Graph>();
99
+ torch::jit::parseIR (graph, g.get ());
100
+
101
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
102
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
103
+
104
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
105
+ ASSERT_TRUE (jit_results[0 ].toTensor ().dtype () == trt_results[0 ].toTensor ().dtype ());
106
+ }
107
+
108
+ TEST (Evaluators, FullLikeNewDtypeEvaluatesCorrectly) {
109
+ const auto graph = R"IR(
110
+ graph(%x.1 : Tensor):
111
+ %9 : None = prim::Constant()
112
+ %13 : Scalar = prim::Constant[value=1]()
113
+ %14 : int = prim::Constant[value=11]()
114
+ %35 : Device = prim::Constant[value="cuda:0"]()
115
+ %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
116
+ return (%19))IR" ;
117
+
118
+ auto in = at::randint (1 , 10 , {1 , 2 , 3 , 5 }, {at::kCUDA });
119
+
120
+ auto g = std::make_shared<torch::jit::Graph>();
121
+ torch::jit::parseIR (graph, g.get ());
122
+
123
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
124
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
125
+
126
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
127
+ ASSERT_TRUE (jit_results[0 ].toTensor ().dtype () == trt_results[0 ].toTensor ().dtype ());
128
+ }
129
+
130
+ TEST (Evaluators, FullLikeOldDtypeEvaluatesCorrectly) {
131
+ const auto graph = R"IR(
132
+ graph(%x.1 : Tensor):
133
+ %9 : None = prim::Constant()
134
+ %13 : Scalar = prim::Constant[value=1.5]()
135
+ %35 : Device = prim::Constant[value="cuda:0"]()
136
+ %19 : Tensor = aten::full_like(%x.1, %13, %9, %9, %35, %9, %9)
137
+ return (%19))IR" ;
138
+
139
+ auto in = at::randint (1 , 10 , {1 , 2 , 3 , 5 }, {at::kCUDA }).to (torch::kInt32 );
140
+
141
+ auto g = std::make_shared<torch::jit::Graph>();
142
+ torch::jit::parseIR (graph, g.get ());
143
+
144
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
145
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
146
+
147
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
148
+ ASSERT_TRUE (jit_results[0 ].toTensor ().dtype () == trt_results[0 ].toTensor ().dtype ());
149
+ }
150
+
86
151
TEST (Evaluators, OnesDataTypeEvaluatesCorrectly) {
87
152
const auto graph = R"IR(
88
153
graph(%x.1 : Tensor):
0 commit comments