import time
import argparse
import logging
+ from contextlib import suppress
+ from functools import partial
+
import numpy as np
+ import pandas as pd
import torch

- from timm.models import create_model, apply_test_time_pool
- from timm.data import ImageDataset, create_loader, resolve_data_config
- from timm.utils import AverageMeter, setup_default_logging
+ from timm.models import create_model, apply_test_time_pool, load_checkpoint
+ from timm.data import create_dataset, create_loader, resolve_data_config
+ from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
+
+
+ try:
+     from apex import amp
+     has_apex = True
+ except ImportError:
+     has_apex = False
+
+ has_native_amp = False
+ try:
+     if getattr(torch.cuda.amp, 'autocast') is not None:
+         has_native_amp = True
+ except AttributeError:
+     pass
+
+ try:
+     from functorch.compile import memory_efficient_fusion
+     has_functorch = True
+ except ImportError as e:
+     has_functorch = False
+
+ try:
+     import torch._dynamo
+     has_dynamo = True
+ except ImportError:
+     has_dynamo = False
+
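These guarded imports follow the usual timm pattern: probe each optional dependency once at module load, record a boolean, and assert on that flag near the code path that needs it rather than failing at import time. A minimal sketch of the same pattern, using a hypothetical optional dependency:

```python
# Sketch only; `onnxruntime` here stands in for any optional dependency.
try:
    import onnxruntime
    has_onnxruntime = True
except ImportError:
    has_onnxruntime = False

def run_onnx(path):
    # fail loudly at the point of use, not at import time
    assert has_onnxruntime, 'onnxruntime is required for this code path'
    return onnxruntime.InferenceSession(path)
```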

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('inference')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
- parser.add_argument('--output_dir', metavar='DIR', default='./',
-                     help='path to output files')
+ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                     help='dataset type (default: ImageFolder/ImageTar if empty)')
+ parser.add_argument('--split', metavar='NAME', default='validation',
+                     help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
-                     metavar='N', help='Input image dimension')
+                     metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
+ parser.add_argument('--use-train-size', action='store_true', default=False,
+                     help='force use of train input size, even when test size is specified in pretrained cfg')
+ parser.add_argument('--crop-pct', default=None, type=float,
+                     metavar='N', help='Input image center crop pct')
+ parser.add_argument('--crop-mode', default=None, type=str,
+                     metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
- parser.add_argument('--num-classes', type=int, default=1000,
+ parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
+ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                     help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUs to use')
- parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
-                     help='disable test time pool')
- parser.add_argument('--topk', default=5, type=int,
+ parser.add_argument('--test-pool', dest='test_pool', action='store_true',
+                     help='enable test time pool')
+ parser.add_argument('--channels-last', action='store_true', default=False,
+                     help='Use channels_last memory layout')
+ parser.add_argument('--device', default='cuda', type=str,
+                     help='Device (accelerator) to use.')
+ parser.add_argument('--amp', action='store_true', default=False,
+                     help='use Native AMP for mixed precision inference')
+ parser.add_argument('--amp-dtype', default='float16', type=str,
+                     help='lower precision AMP dtype (default: float16)')
+ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                     help='use ema version of weights if present')
+ parser.add_argument('--fuser', default='', type=str,
+                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+ parser.add_argument('--dynamo-backend', default=None, type=str,
+                     help='Select dynamo backend. Default: None')
+
+ scripting_group = parser.add_mutually_exclusive_group()
+ scripting_group.add_argument('--torchscript', default=False, action='store_true',
+                              help='torch.jit.script the full model')
+ scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                              help='Enable AOT Autograd support.')
+ scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                              help='Enable Dynamo optimization.')
+
+ parser.add_argument('--results-dir', type=str, default=None,
+                     help='folder for output results')
+ parser.add_argument('--results-file', type=str, default=None,
+                     help='results filename (relative to results-dir)')
+ parser.add_argument('--results-format', type=str, default='csv',
+                     help='results format (one of "csv", "json", "json-split", "parquet")')
+ parser.add_argument('--topk', default=1, type=int,
                    metavar='N', help='Top-k to output to CSV')
+ parser.add_argument('--fullname', action='store_true', default=False,
+                     help='use full sample name in output (not just basename).')
+ parser.add_argument('--indices-name', default='index',
+                     help='name for output indices column(s)')
+ parser.add_argument('--outputs-name', default=None,
+                     help='name for logit/probs output column(s)')
+ parser.add_argument('--outputs-type', default='prob',
+                     help='output type column ("prob" for probabilities, "logit" for raw logits)')
+ parser.add_argument('--separate-columns', action='store_true', default=False,
+                     help='separate output columns per result index.')
+ parser.add_argument('--exclude-outputs', action='store_true', default=False,
+                     help='exclude logits/probs from results, just indices. topk must be set != 0.')
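For reference, a typical invocation of the updated script can be simulated directly against the parser above (paths and model name here are illustrative, not from the commit):

```python
# Equivalent shell command:
#   python inference.py /data/imagenet --model resnet50 --amp --topk 5 --results-format parquet
args = parser.parse_args([
    '/data/imagenet',          # hypothetical dataset root
    '--model', 'resnet50',
    '--amp',                   # requires native AMP support at runtime
    '--topk', '5',
    '--results-format', 'parquet',
])
assert args.topk == 5 and args.amp
```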


def main():
@@ -63,48 +146,109 @@ def main():
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

+     if torch.cuda.is_available():
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.benchmark = True
+
+     device = torch.device(args.device)
+
+     # resolve AMP arguments based on PyTorch / Apex availability
+     use_amp = None
+     amp_autocast = suppress
+     if args.amp:
+         assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
+         assert args.amp_dtype in ('float16', 'bfloat16')
+         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
+         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+         _logger.info('Running inference in mixed precision with native PyTorch AMP.')
+     else:
+         _logger.info('Running inference in float32. AMP not enabled.')
+
+     if args.fuser:
+         set_jit_fuser(args.fuser)
+
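The `suppress`/`partial` pairing gives `amp_autocast` a uniform call signature: it is always a zero-argument context-manager factory, whether AMP is on or off, so the inference loop never branches. A standalone sketch of the idiom (CPU/bfloat16 chosen here just so it runs anywhere):

```python
from contextlib import suppress
from functools import partial

import torch

use_amp = False  # toggle to exercise both paths
if use_amp:
    # torch.autocast is a context manager class; partial pins its arguments
    amp_autocast = partial(torch.autocast, device_type='cpu', dtype=torch.bfloat16)
else:
    # suppress() with no exception types is a cheap no-op context manager
    amp_autocast = suppress

with amp_autocast():
    x = torch.ones(8, 8) @ torch.ones(8, 8)
```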
    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
-         checkpoint_path=args.checkpoint)
+         checkpoint_path=args.checkpoint,
+     )
+     if args.num_classes is None:
+         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+         args.num_classes = model.num_classes
+
+     if args.checkpoint:
+         load_checkpoint(model, args.checkpoint, args.use_ema)
+
+     _logger.info(
+         f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')

-     _logger.info('Model %s created, param count: %d' %
-                  (args.model, sum([m.numel() for m in model.parameters()])))
+     data_config = resolve_data_config(vars(args), model=model)
+     test_time_pool = False
+     if args.test_pool:
+         model, test_time_pool = apply_test_time_pool(model, data_config)

-     config = resolve_data_config(vars(args), model=model)
-     model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
+     model = model.to(device)
+     model.eval()
+     if args.channels_last:
+         model = model.to(memory_format=torch.channels_last)
+
+     if args.torchscript:
+         model = torch.jit.script(model)
+     elif args.aot_autograd:
+         assert has_functorch, "functorch is needed for --aot-autograd"
+         model = memory_efficient_fusion(model)
+     elif args.dynamo:
+         assert has_dynamo, "torch._dynamo is needed for --dynamo"
+         torch._dynamo.reset()
+         if args.dynamo_backend is not None:
+             model = torch._dynamo.optimize(args.dynamo_backend)(model)
+         else:
+             model = torch._dynamo.optimize()(model)
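Note that `torch._dynamo.optimize()` was the pre-2.0 private entry point; on PyTorch 2.x the same thing is normally spelled `torch.compile`. A rough equivalence, assuming a 2.x install:

```python
import torch

model = torch.nn.Linear(4, 2).eval()

# Pre-2.0 private API, as in the diff:
# model = torch._dynamo.optimize('inductor')(model)

# PyTorch 2.x public equivalent ('inductor' is the usual default backend):
compiled = torch.compile(model, backend='inductor')
with torch.no_grad():
    out = compiled(torch.randn(1, 4))
```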

    if args.num_gpu > 1:
-         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
-     else:
-         model = model.cuda()
+         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
+
+     dataset = create_dataset(
+         root=args.data,
+         name=args.dataset,
+         split=args.split,
+         class_map=args.class_map,
+     )
+
+     if test_time_pool:
+         data_config['crop_pct'] = 1.0

    loader = create_loader(
-         ImageDataset(args.data),
-         input_size=config['input_size'],
+         dataset,
        batch_size=args.batch_size,
        use_prefetcher=True,
-         interpolation=config['interpolation'],
-         mean=config['mean'],
-         std=config['std'],
        num_workers=args.workers,
-         crop_pct=1.0 if test_time_pool else config['crop_pct'])
+         **data_config,
+     )
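Passing `**data_config` works because `resolve_data_config` returns the preprocessing kwargs that `create_loader` accepts, replacing the hand-copied keys the old code passed one by one. Roughly, for a 224x224 ImageNet model (key set from typical timm behavior; exact values come from the model's pretrained cfg):

```python
# Approximate shape of data_config; illustrative values only.
data_config = {
    'input_size': (3, 224, 224),
    'interpolation': 'bicubic',
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225),
    'crop_pct': 0.875,
}
```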

-     model.eval()
-
-     k = min(args.topk, args.num_classes)
+     top_k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
-     topk_ids = []
+     all_indices = []
+     all_outputs = []
+     use_probs = args.outputs_type == 'prob'
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):
-             input = input.cuda()
-             labels = model(input)
-             topk = labels.topk(k)[1]
-             topk_ids.append(topk.cpu().numpy())
+
+             with amp_autocast():
+                 output = model(input)
+
+             if use_probs:
+                 output = output.softmax(-1)
+
+             if top_k:
+                 output, indices = output.topk(top_k)
+                 all_indices.append(indices.cpu().numpy())
+
+             all_outputs.append(output.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
@@ -114,13 +258,57 @@ def main():
                _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))
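Unlike the old code, which kept only `topk(k)[1]`, the new loop keeps both tensors returned by `topk`: the values (probabilities or logits) and the class indices. A small standalone check of that behavior:

```python
import torch

output = torch.tensor([[0.1, 0.7, 0.2]]).softmax(-1)
values, indices = output.topk(2)
# values:  top-2 probabilities, descending
# indices: the class ids those probabilities belong to
assert indices[0, 0].item() == 1  # class 1 has the largest score
```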

-     topk_ids = np.concatenate(topk_ids, axis=0)
+     all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
+     all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
+     filenames = loader.dataset.filenames(basename=not args.fullname)
+
+     outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
+     data_dict = {'filename': filenames}
+     if args.separate_columns and all_outputs.shape[-1] > 1:
+         if all_indices is not None:
+             for i in range(all_indices.shape[-1]):
+                 data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
+         for i in range(all_outputs.shape[-1]):
+             data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
+     else:
+         if all_indices is not None:
+             if all_indices.shape[-1] == 1:
+                 all_indices = all_indices.squeeze(-1)
+             data_dict[args.indices_name] = list(all_indices)
+         if all_outputs.shape[-1] == 1:
+             all_outputs = all_outputs.squeeze(-1)
+         data_dict[outputs_name] = list(all_outputs)
+
+     df = pd.DataFrame(data=data_dict)
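With the default `--topk 1`, the squeeze branch yields scalar columns; with `--separate-columns` and a larger k, each rank gets its own column. A hypothetical illustration of the two layouts (values invented for the example):

```python
import pandas as pd

# topk=1 (default): one scalar index/prob per sample
pd.DataFrame({'filename': ['a.jpg'], 'index': [207], 'prob': [0.91]})

# topk=3 with --separate-columns: one column per rank
pd.DataFrame({
    'filename': ['a.jpg'],
    'index_0': [207], 'index_1': [208], 'index_2': [250],
    'prob_0': [0.91], 'prob_1': [0.05], 'prob_2': [0.01],
})
```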
+
+     results_filename = args.results_file
+     needs_ext = False
+     if not results_filename:
+         # base default filename on model name + img-size
+         img_size = data_config["input_size"][1]
+         results_filename = f'{args.model}-{img_size}'
+         needs_ext = True

-     with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
-         filenames = loader.dataset.filenames(basename=True)
-         for filename, label in zip(filenames, topk_ids):
-             out_file.write('{0},{1}\n'.format(
-                 filename, ','.join([str(v) for v in label])))
+     if args.results_dir:
+         results_filename = os.path.join(args.results_dir, results_filename)
+
+     if args.results_format == 'parquet':
+         if needs_ext:
+             results_filename += '.parquet'
+         df = df.set_index('filename')
+         df.to_parquet(results_filename)
+     elif args.results_format == 'json':
+         if needs_ext:
+             results_filename += '.json'
+         df.to_json(results_filename, lines=True, orient='records')
+     elif args.results_format == 'json-split':
+         if needs_ext:
+             results_filename += '.json'
+         df.to_json(results_filename, indent=4, orient='split', index=False)
+     else:
+         if needs_ext:
+             results_filename += '.csv'
+         df.to_csv(results_filename, index=False)
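Whichever format is chosen, the results round-trip through pandas. For example, reloading the default-named outputs this code would produce for resnet50 at 224 (filenames illustrative):

```python
import pandas as pd

# default name is '<model>-<img_size>.<ext>'
df_csv = pd.read_csv('resnet50-224.csv')
df_parquet = pd.read_parquet('resnet50-224.parquet')       # indexed by filename
df_json = pd.read_json('resnet50-224.json', lines=True)    # orient='records'
```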


if __name__ == '__main__':