File tree 1 file changed +25
-0
lines changed
1 file changed +25
-0
lines changed Original file line number Diff line number Diff line change @@ -146,6 +146,31 @@ bool Var::isITensor() const {
146
146
}
147
147
}
148
148
149
+ bool Var::isITensorList () {
150
+ // Unpack the Var as a List and check if each entry is a custom class since
151
+ // ITensors are stored in CustomClassHolder
152
+ auto ival_list = ptr_.ivalue ->toList ();
153
+ for (int i = 0 ; i < ival_list.size (); i++) {
154
+ if (!ival_list.get (i).isCustomClass ()) {
155
+ return false ;
156
+ }
157
+ }
158
+ return true ;
159
+ }
160
+
161
+ std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList () {
162
+ TORCHTRT_CHECK (
163
+ isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
164
+ TORCHTRT_CHECK (isITensorList (), " Expected IValue to be an ITensorList" );
165
+ auto ivalue_list = ptr_.ivalue ->toList ();
166
+ std::vector<nvinfer1::ITensor*> outputs;
167
+ for (int i = 0 ; i < ivalue_list.size (); i++) {
168
+ auto element = ivalue_list.get (i).toCustomClass <TensorContainer>()->tensor ();
169
+ outputs.push_back (std::move (element));
170
+ }
171
+ return outputs;
172
+ }
173
+
149
174
bool Var::isIValue () const {
150
175
if (type_ == Type::kIValue ) {
151
176
return true ;
You can’t perform that action at this time.
0 commit comments