@@ -9,6 +9,36 @@ namespace converters {
9
9
namespace impl {
10
10
namespace {
11
11
12
+ nvinfer1::ITensor* anyDimImplementation (
13
+ ConversionCtx* ctx,
14
+ const torch::jit::Node* n,
15
+ nvinfer1::ITensor* in_tensor,
16
+ int dim,
17
+ bool keepdim) {
18
+ auto in_dims = in_tensor->getDimensions ();
19
+ LOG_DEBUG (" Dim to reduce (original): " << dim);
20
+ dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
21
+ LOG_DEBUG (" Dim to reduce (converted): " << dim);
22
+
23
+ uint32_t axis_mask = 1 << dim;
24
+ LOG_DEBUG (" Axis Mask: " << std::bitset<32 >(axis_mask));
25
+ LOG_DEBUG (" Keep dims: " << keepdim);
26
+
27
+ // Reduce does not work on bool inputs
28
+ if (in_tensor->getType () == nvinfer1::DataType::kBOOL ) {
29
+ in_tensor = castITensor (ctx, in_tensor, nvinfer1::DataType::kINT32 , (util::node_info (n) + " _in" ).c_str ());
30
+ }
31
+ auto sum_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kSUM , axis_mask, keepdim);
32
+
33
+ TORCHTRT_CHECK (sum_layer, " Unable to create sum layer from node: " << *n);
34
+
35
+ sum_layer->setName (util::node_info (n).c_str ());
36
+ auto out_tensor =
37
+ castITensor (ctx, sum_layer->getOutput (0 ), nvinfer1::DataType::kBOOL , (util::node_info (n) + " _out" ).c_str ());
38
+ out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out_tensor);
39
+ return out_tensor;
40
+ }
41
+
12
42
auto reduce_registrations TORCHTRT_UNUSED =
13
43
RegisterNodeConversionPatterns ()
14
44
.pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
224
254
{"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
   // Unpack the schema arguments, then delegate the actual reduction to
   // the shared anyDimImplementation helper.
   auto self = args[0].ITensorOrFreeze(ctx);
   auto reduce_dim = args[1].unwrapToInt();
   auto keep_dims = args[2].unwrapToBool();
   auto reduced = anyDimImplementation(ctx, n, self, reduce_dim, keep_dims);
   // Bind the reduced kBOOL tensor to this node's single output.
   reduced = ctx->AssociateValueAndTensor(n->outputs()[0], reduced);
   LOG_DEBUG("Output shape: " << reduced->getDimensions());
   return true;
 }})
264
+ .pattern(
265
+ {" aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor" ,
266
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
267
+ // use Not(Any(Not(input))) to calculate all without a direct all reduction
268
+ auto in_tensor = args[0 ].ITensorOrFreeze (ctx);
269
+ auto dim = args[1 ].unwrapToInt ();
270
+ auto keepdim = args[2 ].unwrapToBool ();
271
+ if (in_tensor->getType () != nvinfer1::DataType::kBOOL ) {
272
+ // unary not layer only supports bool inputs
273
+ in_tensor = castITensor (
274
+ ctx, in_tensor, nvinfer1::DataType::kBOOL , (util::node_info (n) + " _in_to_bool" ).c_str ());
275
+ }
276
+ auto not_input_layer = ctx->net ->addUnary (*in_tensor, nvinfer1::UnaryOperation::kNOT );
277
+ TORCHTRT_CHECK (not_input_layer, " Unable to create logical_not layer from node: " << *n);
278
+ not_input_layer->setName ((util::node_info (n) + " _not_in" ).c_str ());
279
+ auto not_in = not_input_layer->getOutput (0 );
280
+ auto any_out = anyDimImplementation (ctx, n, not_in, dim, keepdim);
281
+ auto not_output_layer = ctx->net ->addUnary (*any_out, nvinfer1::UnaryOperation::kNOT );
282
+ TORCHTRT_CHECK (not_output_layer, " Unable to create logical_not layer from node: " << *n);
283
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], not_output_layer->getOutput (0 ));
284
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
285
+ return true ;
254
286
}});
255
287
} // namespace
256
288
} // namespace impl
0 commit comments