diff --git a/fvcore/nn/flop_count.py b/fvcore/nn/flop_count.py
index 6e5043e..9488f04 100644
--- a/fvcore/nn/flop_count.py
+++ b/fvcore/nn/flop_count.py
@@ -2,12 +2,13 @@
 # pyre-ignore-all-errors[2,33]
 
 from collections import defaultdict
+from copy import deepcopy
 from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch import Tensor
 
-from .jit_analysis import JitModelAnalysis
+from .jit_analysis import JitModelAnalysis, Statistics
 from .jit_handles import (
     Handle,
     addmm_flop_jit,
@@ -31,6 +32,9 @@
     "aten::matmul": matmul_flop_jit,
     "aten::mm": matmul_flop_jit,
     "aten::linear": linear_flop_jit,
+    # Flops for the following ops are only rough estimates: they are not well
+    # defined and do not correlate well with wall time. They should not account
+    # for a large portion of any model anyway.
     # You might want to ignore BN flops due to inference-time fusion.
     # Use `set_op_handle("aten::batch_norm", None)
     "aten::batch_norm": batchnorm_flop_jit,
@@ -38,25 +42,31 @@
     "aten::layer_norm": norm_flop_counter(2),
     "aten::instance_norm": norm_flop_counter(1),
     "aten::upsample_nearest2d": elementwise_flop_counter(0, 1),
-    "aten::upsample_bilinear2d": elementwise_flop_counter(0, 4),
-    "aten::adaptive_avg_pool2d": elementwise_flop_counter(1, 0),
-    "aten::grid_sampler": elementwise_flop_counter(0, 4),  # assume bilinear
+    "aten::upsample_bilinear2d": elementwise_flop_counter(0, 8),
+    "aten::adaptive_avg_pool2d": elementwise_flop_counter(2, 0),
+    "aten::grid_sampler": elementwise_flop_counter(0, 8),  # assume bilinear
 }
 
 
 class FlopCountAnalysis(JitModelAnalysis):
     """
-    Provides access to per-submodule model flop count obtained by
-    tracing a model with pytorch's jit tracing functionality. By default,
-    comes with standard flop counters for a few common operators.
-    Note that:
+    Provides access to per-submodule flop count obtained by tracing a model
+    with pytorch's jit tracing functionality. By default, comes with standard
+    flop counters for a few common operators.
 
-    1. Flop is not a well-defined concept. We just produce our best estimate.
-    2. We count one fused multiply-add as one flop.
+    A flop is a floating point operation. Another common metric is MAC
+    (multiply-add count), which counts one multiply plus one add as a single
+    operation. We count MACs by default, but this can be changed with
+    `set_use_mac(False)`. We simply assume one MAC equals two flops, which
+    holds for the most expensive operators we care about.
+
+    Note that flop/MAC is not a well-defined concept for many ops. We just
+    produce our best estimate.
 
     Handles for additional operators may be added, or the default ones
     overwritten, using the ``.set_op_handle(name, func)`` method.
     See the method documentation for details.
+    The handler for each op should always return flops rather than MACs.
 
     Flop counts can be obtained as:
 
@@ -112,6 +122,28 @@ def __init__(
     ) -> None:
         super().__init__(model=model, inputs=inputs)
         self.set_op_handle(**_DEFAULT_SUPPORTED_OPS)
+        self._use_mac = True  # NOTE: maybe we'll want to change the default to False
+
+    def set_use_mac(self, enabled: bool) -> "FlopCountAnalysis":
+        """
+        Decide whether to count MACs (multiply-add counts) rather than flops.
+        Defaults to True because this is the convention in many computer vision
+        papers, although the resulting numbers are often mislabeled as flops.
+
+        To count MACs, we simply assume one MAC equals two flops, even though
+        this is not true for all ops.
+ """ + self._use_mac = enabled + return self + + def _analyze(self) -> Statistics: + stats = super()._analyze() + if self._use_mac: + stats = deepcopy(stats) + for v in stats.counts.values(): + for k in list(v.keys()): + v[k] = v[k] // 2 + return stats __init__.__doc__ = JitModelAnalysis.__init__.__doc__ @@ -122,8 +154,10 @@ def flop_count( supported_ops: Optional[Dict[str, Handle]] = None, ) -> Tuple[DefaultDict[str, float], Counter[str]]: """ - Given a model and an input to the model, compute the per-operator Gflops - of the given model. + Given a model and an input to the model, compute the per-operator GMACs + (10^9 multiply-adds) of the given model. + + For more features and customized counting, please use :class:`FlopCountAnalysis`. Args: model (nn.Module): The model to compute flop counts. @@ -132,18 +166,21 @@ def flop_count( supported_ops (dict(str,Callable) or None) : provide additional handlers for extra ops, or overwrite the existing handlers for convolution and matmul and einsum. The key is operator name and the value - is a function that takes (inputs, outputs) of the op. We count - one Multiply-Add as one FLOP. + is a function that takes (inputs, outputs) of the op. Returns: tuple[defaultdict, Counter]: A dictionary that records the number of - gflops for each operation and a Counter that records the number of + GMACs for each operation and a Counter that records the number of unsupported operations. """ if supported_ops is None: supported_ops = {} - flop_counter = FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops) - giga_flops = defaultdict(float) - for op, flop in flop_counter.by_operator().items(): - giga_flops[op] = flop / 1e9 - return giga_flops, flop_counter.unsupported_ops() + mac_counter = ( + FlopCountAnalysis(model, inputs) # pyre-ignore + .set_op_handle(**supported_ops) + .set_use_mac(True) + ) + giga_macs = defaultdict(float) + for op, mac in mac_counter.by_operator().items(): + giga_macs[op] = mac / 1e9 + return giga_macs, mac_counter.unsupported_ops() diff --git a/fvcore/nn/jit_handles.py b/fvcore/nn/jit_handles.py index d03b128..6fef110 100644 --- a/fvcore/nn/jit_handles.py +++ b/fvcore/nn/jit_handles.py @@ -88,7 +88,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: batch_size, input_dim = input_shapes[0] output_dim = input_shapes[1][1] flops = batch_size * input_dim * output_dim - return flops + return flops * 2 def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: @@ -102,7 +102,7 @@ def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # input_shapes[1]: [output_feature_dim, input_feature_dim] assert input_shapes[0][-1] == input_shapes[1][-1] flops = prod(input_shapes[0]) * input_shapes[1][0] - return flops + return flops * 2 def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: @@ -116,7 +116,7 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: n, c, t = input_shapes[0] d = input_shapes[-1][-1] flop = n * c * t * d - return flop + return flop * 2 def conv_flop_count( @@ -137,7 +137,7 @@ def conv_flop_count( out_size = prod(out_shape[2:]) kernel_size = prod(w_shape[2:]) flop = batch_size * out_size * Cout_dim * Cin_dim * kernel_size - return flop + return flop * 2 def conv_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: @@ -181,20 +181,19 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: n, c, t = input_shapes[0] p = input_shapes[-1][-1] flop = n * c * t * p - return flop + return flop * 2 elif equation == 
"abc,adc->adb": n, t, g = input_shapes[0] c = input_shapes[-1][1] flop = n * t * g * c - return flop + return flop * 2 else: np_arrs = [np.zeros(s) for s in input_shapes] optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] for line in optim.split("\n"): if "optimized flop" in line.lower(): - # divided by 2 because we count MAC (multiply-add counted as one flop) - flop = float(np.floor(float(line.split(":")[-1]) / 2)) + flop = float(line.split(":")[-1].strip()) return flop raise NotImplementedError("Unsupported einsum operation.") @@ -209,7 +208,7 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: assert len(input_shapes) == 2, input_shapes assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes flop = prod(input_shapes[0]) * input_shapes[-1][-1] - return flop + return flop * 2 def norm_flop_counter(affine_arg_index: int) -> Handle: @@ -226,8 +225,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: input_shape = get_shape(inputs[0]) has_affine = get_shape(inputs[affine_arg_index]) is not None assert 2 <= len(input_shape) <= 5, input_shape - # 5 is just a rough estimate - flop = prod(input_shape) * (5 if has_affine else 4) + # 5 or 7 is just a rough estimate: + # 3 - compute E[x] and E[x^2] + # 2 - compute normalization + # 2 - compute affine + flop = prod(input_shape) * (7 if has_affine else 5) return flop return norm_flop_jit @@ -240,7 +242,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: return norm_flop_counter(1)(inputs, outputs) # pyre-ignore has_affine = get_shape(inputs[1]) is not None input_shape = prod(get_shape(inputs[0])) - return input_shape * (2 if has_affine else 1) + return input_shape * (4 if has_affine else 2) def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Handle: diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py index 8b5ed97..1f294a7 100644 --- a/tests/test_flop_count.py +++ b/tests/test_flop_count.py @@ -186,7 +186,7 @@ def dummy_sigmoid_flop_jit( custom_ops: Dict[str, Handle] = {"aten::sigmoid": dummy_sigmoid_flop_jit} x = torch.rand(batch_size, input_dim) flop_dict1, _ = flop_count(customNet, (x,), supported_ops=custom_ops) - flop_sigmoid = 10000 / 1e9 + flop_sigmoid = 10000 / 1e9 / 2 self.assertEqual( flop_dict1["sigmoid"], flop_sigmoid, @@ -211,7 +211,7 @@ def addmm_dummy_flop_jit( "aten::{}".format(self.lin_op): addmm_dummy_flop_jit } flop_dict2, _ = flop_count(customNet, (x,), supported_ops=custom_ops2) - flop = 400000 / 1e9 + flop = 400000 / 1e9 / 2 self.assertEqual( flop_dict2[self.lin_op], flop, @@ -632,7 +632,7 @@ def test_batchnorm(self) -> None: batch_2d = nn.BatchNorm2d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y) flop_dict, _ = flop_count(batch_2d, (x,)) - gt_flop = 4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y / 1e9 + gt_flop = 2.5 * batch_size * input_dim * spatial_dim_x * spatial_dim_y / 1e9 gt_dict = defaultdict(float) gt_dict["batch_norm"] = gt_flop self.assertDictEqual( @@ -651,7 +651,7 @@ def test_batchnorm(self) -> None: ) flop_dict, _ = flop_count(batch_3d, (x,)) gt_flop = ( - 4 + 2.5 * batch_size * input_dim * spatial_dim_x @@ -740,14 +740,14 @@ def test_batch_norm(self): nodes = self._count_function( F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec), op_name ) - self.assertEqual(counter(*nodes), 32) + self.assertEqual(counter(*nodes), 64) nodes = self._count_function( F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, None, None), 
             op_name,
         )
-        self.assertEqual(counter(*nodes), 16)
+        self.assertEqual(counter(*nodes), 32)
 
         nodes = self._count_function(
             # training=True
@@ -755,7 +755,7 @@ def test_batch_norm(self):
             (torch.rand(2, 2, 2, 2), vec, vec, vec, vec, True),
             op_name,
         )
-        self.assertEqual(counter(*nodes), 80)
+        self.assertEqual(counter(*nodes), 112)
 
     def test_group_norm(self):
         op_name = "aten::group_norm"
@@ -765,12 +765,12 @@ def test_group_norm(self):
         nodes = self._count_function(
             F.group_norm, (torch.rand(2, 2, 2, 2), 2, vec, vec), op_name
         )
-        self.assertEqual(counter(*nodes), 80)
+        self.assertEqual(counter(*nodes), 112)
 
         nodes = self._count_function(
             F.group_norm, (torch.rand(2, 2, 2, 2), 2, None, None), op_name
         )
-        self.assertEqual(counter(*nodes), 64)
+        self.assertEqual(counter(*nodes), 80)
 
     def test_upsample(self):
         op_name = "aten::upsample_bilinear2d"
@@ -779,7 +779,7 @@
         nodes = self._count_function(
             F.interpolate, (torch.rand(2, 2, 2, 2), None, 2, "bilinear", False), op_name
         )
-        self.assertEqual(counter(*nodes), 2 ** 4 * 4 * 4)
+        self.assertEqual(counter(*nodes), 2 ** 4 * 4 * 4 * 2)
 
     def test_complicated_einsum(self):
         op_name = "aten::einsum"
@@ -790,7 +790,7 @@
             ("nc,nchw->hw", torch.rand(3, 4), torch.rand(3, 4, 2, 3)),
             op_name,
         )
-        self.assertEqual(counter(*nodes), 72.0)
+        self.assertEqual(counter(*nodes), 145.0)
 
     def test_torch_mm(self):
         for op_name, func in zip(
@@ -803,4 +803,4 @@
             (torch.rand(3, 4), torch.rand(4, 5)),
             op_name,
         )
-        self.assertEqual(counter(*nodes), 60)
+        self.assertEqual(counter(*nodes), 120)
diff --git a/tests/test_jit_model_analysis.py b/tests/test_jit_model_analysis.py
index 2e336f5..510e5f2 100644
--- a/tests/test_jit_model_analysis.py
+++ b/tests/test_jit_model_analysis.py
@@ -620,7 +620,7 @@ def test_changing_handles(self) -> None:
             "aten::linear": linear_flop_jit,
         }  # type: Dict[str, Handle]
 
-        analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
+        analyzer = FlopCountAnalysis(model=model, inputs=inputs).set_op_handle(
             **op_handles
         )
         analyzer.unsupported_ops_warnings(enabled=False)
@@ -638,7 +638,7 @@ def make_dummy_op(name: str, output: int) -> Handle:
             def dummy_ops_handle(
                 inputs: List[Any], outputs: List[Any]
             ) -> typing.Counter[str]:
-                return Counter({name: output})
+                return Counter({name: output * 2})
 
             return dummy_ops_handle
 
@@ -725,10 +725,10 @@ def test_copy(self) -> None:
         non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops
 
         # Total is correct for new model and inputs
-        self.assertEqual(analyzer_new.total(), non_forward_flops * bs)
+        self.assertEqual(analyzer_new.total(), non_forward_flops * bs * 2)
 
         # Original is unaffected
-        self.assertEqual(analyzer.total(), repeated_net_flops)
+        self.assertEqual(analyzer.total(), repeated_net_flops * 2)
 
         # Settings match
         self.assertEqual(
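Usage sketch (reviewer aid, not part of the patch): a minimal example of the MAC/flop switch introduced above. The toy nn.Linear model and the numbers in the comments are illustrative; they follow from the updated handlers in this diff, where a 10x20 linear layer applied to one input row costs 200 MACs, i.e. 400 flops.

    import torch
    import torch.nn as nn

    from fvcore.nn import FlopCountAnalysis, flop_count

    # A toy model: one linear layer, 10 -> 20, traced on a single input row.
    model = nn.Linear(10, 20)
    inputs = (torch.randn(1, 10),)

    analysis = FlopCountAnalysis(model, inputs)

    # Default behavior counts MACs: 1 * 10 * 20 = 200 multiply-adds.
    print(analysis.total())

    # After set_use_mac(False), true flops are reported, roughly 2x the MACs: 400.
    print(analysis.set_use_mac(False).total())

    # The flop_count() wrapper keeps returning GMACs, so existing callers
    # (including the tests above) see unchanged numbers.
    gmacs, unsupported = flop_count(model, inputs)
    print(gmacs)

Keeping set_use_mac(True) as the default preserves the numbers that existing flop_count() users rely on, while op handlers now uniformly return flops; that is why every handler and test expectation in this diff is scaled by a factor of 2.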