Skip to content

Commit 9d84f9e

Browse files
xuzhao9 authored and facebook-github-bot committed
Fix fvcore flops counting for torchvision models
Summary: For now, we only support fvcore flops counting for torchvision models.

Reviewed By: FindHao

Differential Revision: D47672523

fbshipit-source-id: dce911484ec63823272d49dea0b49be5e2f62398
1 parent 79bc754 commit 9d84f9e

File tree

4 files changed

+13
-21
lines changed

4 files changed

+13
-21
lines changed

run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ def printResultSummaryTime(result_summary, metrics_needed=[], model=None, flops_
101101
if flops_model_analyzer.metrics_backend_mapping['flops'] == 'dcgm':
102102
tflops_device_id, tflops = flops_model_analyzer.calculate_flops()
103103
else:
104-
flops, batch_size = model.get_flops()
105-
tflops = flops * batch_size / (cpu_walltime / 1.0e3) / 1.0e12
106-
print('{:<20} {:>20}'.format("GPU %d FLOPS:" % tflops_device_id, "%.4f TFLOPs per second" % tflops, sep=''))
104+
flops = model.get_flops()
105+
tflops = flops / (cpu_walltime / 1.0e3) / 1.0e12
106+
print('{:<20} {:>20}'.format("GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=''))
107107
if gpu_peak_mem is not None:
108108
print('{:<20} {:>20}'.format("GPU %d Peak Memory:" % mem_device_id, "%.4f GB" % gpu_peak_mem, sep=''))
109109
if cpu_peak_mem is not None:

torchbenchmark/util/backends/flops.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

torchbenchmark/util/extra_args.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import enum
33
from typing import List, Optional, Tuple
44
from torchbenchmark.util.backends import list_backends, BACKENDS
5-
6-
from torchbenchmark.util.backends.flops import enable_fvcore_flops
7-
from torchbenchmark.util.env_check import is_torchvision_model, is_staged_train_test
5+
from torchbenchmark.util.env_check import is_staged_train_test
86

97
TEST_STAGE = enum.Enum('TEST_STAGE', ['FORWARD', 'BACKWARD', 'OPTIMIZER', 'ALL'])
108
AVAILABLE_PRECISIONS = ["fp32", "tf32", "fp16", "amp", "fx_int8", "bf16","amp_fp16", "amp_bf16"]
@@ -127,7 +125,6 @@ def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dar
127125
def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: List[str]) -> argparse.Namespace:
128126
parser = argparse.ArgumentParser()
129127
parser.add_argument("--backend", choices=list_backends(), help="enable backends")
130-
parser.add_argument("--flops", choices=["fvcore", "dcgm"], help="Return the flops result")
131128
args, extra_args = parser.parse_known_args(opt_args)
132129
if model.jit:
133130
args.backend = "torchscript"
@@ -137,7 +134,5 @@ def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args:
137134
return args, extra_args
138135

139136
def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace):
140-
if args.flops == "fvcore":
141-
enable_fvcore_flops(model)
142137
if args.backend:
143138
model._enable_backend()

torchbenchmark/util/framework/vision/model_factory.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,15 @@ def __init__(self, model_name, test, device, jit=False, batch_size=None, weights
4848
self.real_output = ( torch.rand_like(self.example_outputs), )
4949

5050
def get_flops(self):
51-
return self.flops, self.batch_size
51+
# By default, FlopCountAnalysis count one fused-mult-add (FMA) as one flop.
52+
# However, in our context, we count 1 FMA as 2 flops instead of 1.
53+
# https://github.com/facebookresearch/fvcore/blob/7a0ef0c0839fa0f5e24d2ef7f5d48712f36e7cd7/fvcore/nn/flop_count.py
54+
assert self.test == "eval", "fvcore flops is only available on inference tests, as it doesn't measure backward pass."
55+
from fvcore.nn import FlopCountAnalysis
56+
FLOPS_FMA = 2.0
57+
self.flops = FlopCountAnalysis(self.model, tuple(self.example_inputs)).total()
58+
self.flops = self.flops * FLOPS_FMA
59+
return self.flops
5260

5361
def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
5462
def _gen_inputs():
@@ -96,4 +104,3 @@ def cudagraph_eval(self):
96104
self.g.replay()
97105
break
98106
return (self.example_outputs, )
99-

0 commit comments

Comments (0)