@@ -67,28 +67,6 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
67
67
if (in->type ()->isSubtypeOf (c10::TensorType::get ()) && static_params.find (in) == static_params.end ()) {
68
68
input_tensors.push_back (in);
69
69
}
70
- // else if (in->type()->cast<c10::TupleType>() && static_params.find(in) == static_params.end()) {
71
- // // } else if (in->type()->isSubtypeOf(c10::TupleType::create()) && static_params.find(in) == static_params.end()) {
72
- // at::ArrayRef<torch::jit::Value*> unpack_tuple = torch::jit::createTupleUnpack(in);
73
- // LOG_DEBUG("Tuple size " << unpack_tuple.size());
74
- // for (auto item: unpack_tuple) {
75
- // input_tensors.push_back(in);
76
- // }
77
- // } else if (in->type()->isSubtypeOf(c10::ListType::ofTensors()) && static_params.find(in) == static_params.end()) {
78
-
79
- // LOG_DEBUG("List use size " << in->uses().size());
80
- // // for (auto use : in->uses()) {
81
- // // LOG_DEBUG(use.user->outputs()[0]->debugName());
82
- // // }
83
- // // TODO: set the correct list number according to the Input IValue
84
- // int n = 2;
85
- // auto unpack_node = g->createListUnpack(in, n);
86
- // g->block()->appendNode(unpack_node);
87
- // for (auto item: unpack_node->outputs()) {
88
- // input_tensors.push_back(item);
89
- // }
90
- // LOG_DEBUG("Unpack List of size " << n);
91
- // }
92
70
}
93
71
return input_tensors;
94
72
}
@@ -101,11 +79,6 @@ std::vector<const torch::jit::Value*> get_collection_inputs(
101
79
LOG_DEBUG (" get_collection_inputs, inputs size " << inputs.size ());
102
80
for (auto in : inputs) {
103
81
LOG_DEBUG (" input debug name: " << in->debugName ());
104
- // Disregarding inputs that are not tensors or are static
105
- //
106
- // Ex.
107
- // self.1:__torch__.alexnet -> ignored
108
- // input.1:Tensor -> used
109
82
if (in->type ()->isSubtypeOf (c10::TensorType::get ()) && static_params.find (in) == static_params.end ()) {
110
83
input_tensors.push_back (in);
111
84
} else if (in->type ()->kind () == torch::jit::TypeKind::TupleType && static_params.find (in) == static_params.end ()) {
@@ -242,21 +215,18 @@ CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block*
242
215
if (i->type () == c10::TensorType::get ()) {
243
216
torch::jit::Value* in = i;
244
217
types.insert ({in, {get_value_first_calc_dtype_opt (b, i)}});
218
+
245
219
} else if (i->type ()->kind () == torch::jit::TypeKind::TupleType) {
246
220
LOG_DEBUG (" get_block_first_calc_dtypes_opt_collection TupleType" );
247
221
// TODO: to evaluate the data type of tuple element
248
222
// make sure very time get the same ptr
249
223
c10::optional<at::ScalarType> tp = get_value_first_calc_dtype_opt (b, i);
250
224
at::ArrayRef<torch::jit::Value*> unpack_tuple = torch::jit::createTupleUnpack (i);
251
225
LOG_DEBUG (" get_block_first_calc_dtypes_opt_collection: tuple size " << unpack_tuple.size ());
252
- // Assume all tuple has the same datatype
226
+ // TODO: calculate the tuple element type
253
227
// std::vector<c10::optional<at::ScalarType>> dytpes(unpack_tuple.size(), tp);
254
228
std::vector<c10::optional<at::ScalarType>> dytpes (unpack_tuple.size ());
255
229
types.insert ({i, dytpes}); // insert an empty
256
- // for (auto item: unpack_tuple) {
257
- // torch::jit::Value* in = item;
258
- // types.insert({in, get_value_first_calc_dtype_opt(b, i)});
259
- // }
260
230
261
231
} else if (i->type ()->kind () == torch::jit::TypeKind::ListType) {
262
232
// TODO: to decide the size of list and type of list element
@@ -265,7 +235,6 @@ CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block*
265
235
// std::vector<c10::optional<at::ScalarType>> dytpes(i->uses().size());
266
236
std::vector<c10::optional<at::ScalarType>> dytpes (i->uses ().size (), tp);
267
237
types.insert ({i, dytpes}); // insert an empty
268
-
269
238
}
270
239
}
271
240
return types;
0 commit comments