Skip to content

Commit eceeb94

Browse files
committed
Significant upgrade to inference.py, support for different formats, formatting, etc.
1 parent 4d5c395 commit eceeb94

File tree

1 file changed

+227
-39
lines changed

1 file changed

+227
-39
lines changed

inference.py

Lines changed: 227 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,44 @@
99
import time
1010
import argparse
1111
import logging
12+
from contextlib import suppress
13+
from functools import partial
14+
1215
import numpy as np
16+
import pandas as pd
1317
import torch
1418

15-
from timm.models import create_model, apply_test_time_pool
16-
from timm.data import ImageDataset, create_loader, resolve_data_config
17-
from timm.utils import AverageMeter, setup_default_logging
19+
from timm.models import create_model, apply_test_time_pool, load_checkpoint
20+
from timm.data import create_dataset, create_loader, resolve_data_config
21+
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
22+
23+
24+
25+
try:
26+
from apex import amp
27+
has_apex = True
28+
except ImportError:
29+
has_apex = False
30+
31+
has_native_amp = False
32+
try:
33+
if getattr(torch.cuda.amp, 'autocast') is not None:
34+
has_native_amp = True
35+
except AttributeError:
36+
pass
37+
38+
try:
39+
from functorch.compile import memory_efficient_fusion
40+
has_functorch = True
41+
except ImportError as e:
42+
has_functorch = False
43+
44+
try:
45+
import torch._dynamo
46+
has_dynamo = True
47+
except ImportError:
48+
has_dynamo = False
49+
1850

1951
torch.backends.cudnn.benchmark = True
2052
_logger = logging.getLogger('inference')
@@ -23,26 +55,36 @@
2355
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
2456
parser.add_argument('data', metavar='DIR',
2557
help='path to dataset')
26-
parser.add_argument('--output_dir', metavar='DIR', default='./',
27-
help='path to output files')
58+
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
59+
help='dataset type (default: ImageFolder/ImageTar if empty)')
60+
parser.add_argument('--split', metavar='NAME', default='validation',
61+
help='dataset split (default: validation)')
2862
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
2963
help='model architecture (default: dpn92)')
3064
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
3165
help='number of data loading workers (default: 2)')
3266
parser.add_argument('-b', '--batch-size', default=256, type=int,
3367
metavar='N', help='mini-batch size (default: 256)')
3468
parser.add_argument('--img-size', default=None, type=int,
35-
metavar='N', help='Input image dimension')
69+
metavar='N', help='Input image dimension, uses model default if empty')
3670
parser.add_argument('--input-size', default=None, nargs=3, type=int,
3771
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
72+
parser.add_argument('--use-train-size', action='store_true', default=False,
73+
help='force use of train input size, even when test size is specified in pretrained cfg')
74+
parser.add_argument('--crop-pct', default=None, type=float,
75+
metavar='N', help='Input image center crop pct')
76+
parser.add_argument('--crop-mode', default=None, type=str,
77+
metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
3878
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
3979
help='Override mean pixel value of dataset')
40-
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
80+
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
4181
help='Override std deviation of of dataset')
4282
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
4383
help='Image resize interpolation type (overrides model)')
44-
parser.add_argument('--num-classes', type=int, default=1000,
84+
parser.add_argument('--num-classes', type=int, default=None,
4585
help='Number classes in dataset')
86+
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
87+
help='path to class to idx mapping file (default: "")')
4688
parser.add_argument('--log-freq', default=10, type=int,
4789
metavar='N', help='batch logging frequency (default: 10)')
4890
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -51,10 +93,51 @@
5193
help='use pre-trained model')
5294
parser.add_argument('--num-gpu', type=int, default=1,
5395
help='Number of GPUS to use')
54-
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
55-
help='disable test time pool')
56-
parser.add_argument('--topk', default=5, type=int,
96+
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
97+
help='enable test time pool')
98+
parser.add_argument('--channels-last', action='store_true', default=False,
99+
help='Use channels_last memory layout')
100+
parser.add_argument('--device', default='cuda', type=str,
101+
help="Device (accelerator) to use.")
102+
parser.add_argument('--amp', action='store_true', default=False,
103+
help='use Native AMP for mixed precision training')
104+
parser.add_argument('--amp-dtype', default='float16', type=str,
105+
help='lower precision AMP dtype (default: float16)')
106+
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
107+
help='use ema version of weights if present')
108+
parser.add_argument('--fuser', default='', type=str,
109+
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
110+
parser.add_argument('--dynamo-backend', default=None, type=str,
111+
help="Select dynamo backend. Default: None")
112+
113+
scripting_group = parser.add_mutually_exclusive_group()
114+
scripting_group.add_argument('--torchscript', default=False, action='store_true',
115+
help='torch.jit.script the full model')
116+
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
117+
help="Enable AOT Autograd support.")
118+
scripting_group.add_argument('--dynamo', default=False, action='store_true',
119+
help="Enable Dynamo optimization.")
120+
121+
parser.add_argument('--results-dir',type=str, default=None,
122+
help='folder for output results')
123+
parser.add_argument('--results-file', type=str, default=None,
124+
help='results filename (relative to results-dir)')
125+
parser.add_argument('--results-format', type=str, default='csv',
126+
help='results format (one of "csv", "json", "json-split", "parquet")')
127+
parser.add_argument('--topk', default=1, type=int,
57128
metavar='N', help='Top-k to output to CSV')
129+
parser.add_argument('--fullname', action='store_true', default=False,
130+
help='use full sample name in output (not just basename).')
131+
parser.add_argument('--indices-name', default='index',
132+
help='name for output indices column(s)')
133+
parser.add_argument('--outputs-name', default=None,
134+
help='name for logit/probs output column(s)')
135+
parser.add_argument('--outputs-type', default='prob',
136+
help='output type colum ("prob" for probabilities, "logit" for raw logits)')
137+
parser.add_argument('--separate-columns', action='store_true', default=False,
138+
help='separate output columns per result index.')
139+
parser.add_argument('--exclude-outputs', action='store_true', default=False,
140+
help='exclude logits/probs from results, just indices. topk must be set !=0.')
58141

59142

60143
def main():
@@ -63,48 +146,109 @@ def main():
63146
# might as well try to do something useful...
64147
args.pretrained = args.pretrained or not args.checkpoint
65148

149+
if torch.cuda.is_available():
150+
torch.backends.cuda.matmul.allow_tf32 = True
151+
torch.backends.cudnn.benchmark = True
152+
153+
device = torch.device(args.device)
154+
155+
# resolve AMP arguments based on PyTorch / Apex availability
156+
use_amp = None
157+
amp_autocast = suppress
158+
if args.amp:
159+
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
160+
assert args.amp_dtype in ('float16', 'bfloat16')
161+
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
162+
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
163+
_logger.info('Running inference in mixed precision with native PyTorch AMP.')
164+
else:
165+
_logger.info('Running inference in float32. AMP not enabled.')
166+
167+
if args.fuser:
168+
set_jit_fuser(args.fuser)
169+
66170
# create model
67171
model = create_model(
68172
args.model,
69173
num_classes=args.num_classes,
70174
in_chans=3,
71175
pretrained=args.pretrained,
72-
checkpoint_path=args.checkpoint)
176+
checkpoint_path=args.checkpoint,
177+
)
178+
if args.num_classes is None:
179+
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
180+
args.num_classes = model.num_classes
181+
182+
if args.checkpoint:
183+
load_checkpoint(model, args.checkpoint, args.use_ema)
184+
185+
_logger.info(
186+
f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
73187

74-
_logger.info('Model %s created, param count: %d' %
75-
(args.model, sum([m.numel() for m in model.parameters()])))
188+
data_config = resolve_data_config(vars(args), model=model)
189+
test_time_pool = False
190+
if args.test_pool:
191+
model, test_time_pool = apply_test_time_pool(model, data_config)
76192

77-
config = resolve_data_config(vars(args), model=model)
78-
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
193+
model = model.to(device)
194+
model.eval()
195+
if args.channels_last:
196+
model = model.to(memory_format=torch.channels_last)
197+
198+
if args.torchscript:
199+
model = torch.jit.script(model)
200+
elif args.aot_autograd:
201+
assert has_functorch, "functorch is needed for --aot-autograd"
202+
model = memory_efficient_fusion(model)
203+
elif args.dynamo:
204+
assert has_dynamo, "torch._dynamo is needed for --dynamo"
205+
torch._dynamo.reset()
206+
if args.dynamo_backend is not None:
207+
model = torch._dynamo.optimize(args.dynamo_backend)(model)
208+
else:
209+
model = torch._dynamo.optimize()(model)
79210

80211
if args.num_gpu > 1:
81-
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
82-
else:
83-
model = model.cuda()
212+
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
213+
214+
dataset = create_dataset(
215+
root=args.data,
216+
name=args.dataset,
217+
split=args.split,
218+
class_map=args.class_map,
219+
)
220+
221+
if test_time_pool:
222+
data_config['crop_pct'] = 1.0
84223

85224
loader = create_loader(
86-
ImageDataset(args.data),
87-
input_size=config['input_size'],
225+
dataset,
88226
batch_size=args.batch_size,
89227
use_prefetcher=True,
90-
interpolation=config['interpolation'],
91-
mean=config['mean'],
92-
std=config['std'],
93228
num_workers=args.workers,
94-
crop_pct=1.0 if test_time_pool else config['crop_pct'])
229+
**data_config,
230+
)
95231

96-
model.eval()
97-
98-
k = min(args.topk, args.num_classes)
232+
top_k = min(args.topk, args.num_classes)
99233
batch_time = AverageMeter()
100234
end = time.time()
101-
topk_ids = []
235+
all_indices = []
236+
all_outputs = []
237+
use_probs = args.outputs_type == 'prob'
102238
with torch.no_grad():
103239
for batch_idx, (input, _) in enumerate(loader):
104-
input = input.cuda()
105-
labels = model(input)
106-
topk = labels.topk(k)[1]
107-
topk_ids.append(topk.cpu().numpy())
240+
241+
with amp_autocast():
242+
output = model(input)
243+
244+
if use_probs:
245+
output = output.softmax(-1)
246+
247+
if top_k:
248+
output, indices = output.topk(top_k)
249+
all_indices.append(indices.cpu().numpy())
250+
251+
all_outputs.append(output.cpu().numpy())
108252

109253
# measure elapsed time
110254
batch_time.update(time.time() - end)
@@ -114,13 +258,57 @@ def main():
114258
_logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
115259
batch_idx, len(loader), batch_time=batch_time))
116260

117-
topk_ids = np.concatenate(topk_ids, axis=0)
261+
all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
262+
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
263+
filenames = loader.dataset.filenames(basename=not args.fullname)
264+
265+
outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
266+
data_dict = {'filename': filenames}
267+
if args.separate_columns and all_outputs.shape[-1] > 1:
268+
if all_indices is not None:
269+
for i in range(all_indices.shape[-1]):
270+
data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
271+
for i in range(all_outputs.shape[-1]):
272+
data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
273+
else:
274+
if all_indices is not None:
275+
if all_indices.shape[-1] == 1:
276+
all_indices = all_indices.squeeze(-1)
277+
data_dict[args.indices_name] = list(all_indices)
278+
if all_outputs.shape[-1] == 1:
279+
all_outputs = all_outputs.squeeze(-1)
280+
data_dict[outputs_name] = list(all_outputs)
281+
282+
df = pd.DataFrame(data=data_dict)
283+
284+
results_filename = args.results_file
285+
needs_ext = False
286+
if not results_filename:
287+
# base default filename on model name + img-size
288+
img_size = data_config["input_size"][1]
289+
results_filename = f'{args.model}-{img_size}'
290+
needs_ext = True
118291

119-
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
120-
filenames = loader.dataset.filenames(basename=True)
121-
for filename, label in zip(filenames, topk_ids):
122-
out_file.write('{0},{1}\n'.format(
123-
filename, ','.join([ str(v) for v in label])))
292+
if args.results_dir:
293+
results_filename = os.path.join(args.results_dir, results_filename)
294+
295+
if args.results_format == 'parquet':
296+
if needs_ext:
297+
results_filename += '.parquet'
298+
df = df.set_index('filename')
299+
df.to_parquet(results_filename)
300+
elif args.results_format == 'json':
301+
if needs_ext:
302+
results_filename += '.json'
303+
df.to_json(results_filename, lines=True, orient='records')
304+
elif args.results_format == 'json-split':
305+
if needs_ext:
306+
results_filename += '.json'
307+
df.to_json(results_filename, indent=4, orient='split', index=False)
308+
else:
309+
if needs_ext:
310+
results_filename += '.csv'
311+
df.to_csv(results_filename, index=False)
124312

125313

126314
if __name__ == '__main__':

0 commit comments

Comments
 (0)