Skip to content

Commit

Permalink
Add params and MACs units specifiers (#188)
Browse files Browse the repository at this point in the history
* Add params and MACs units specifiers

* Use enums

* Add test case

Co-authored-by: Tyler Yep <tyler.yep@robinhood.com>
  • Loading branch information
richardtml and TylerYep authored Nov 7, 2022
1 parent 0c1ccff commit 8b3ae72
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 24 deletions.
2 changes: 1 addition & 1 deletion tests/test_output/dict_parameters_1.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ DictParameter [10, 1] --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/dict_parameters_2.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ DictParameter [10, 1] --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/dict_parameters_3.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ DictParameter [10, 1] --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Expand Down
40 changes: 40 additions & 0 deletions tests/test_output/formatting_options.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
SingleInputNet [16, 10] --
├─Conv2d: 1-1 [16, 10, 24, 24] 260
├─Conv2d: 1-2 [16, 20, 8, 8] 5,020
├─Dropout2d: 1-3 [16, 20, 8, 8] --
├─Linear: 1-4 [16, 50] 16,050
├─Linear: 1-5 [16, 10] 510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds: 7,801,600
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
SingleInputNet [16, 10] --
├─Conv2d: 1-1 [16, 10, 24, 24] 260
├─Conv2d: 1-2 [16, 20, 8, 8] 5,020
├─Dropout2d: 1-3 [16, 20, 8, 8] --
├─Linear: 1-4 [16, 50] 16,050
├─Linear: 1-5 [16, 10] 510
==========================================================================================
Total params (T): 0.00
Trainable params (T): 0.00
Non-trainable params (T): 0
Total mult-adds (T): 0.00
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
2 changes: 1 addition & 1 deletion tests/test_output/jit.out
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ LinearModel -- --
Total params: 33,153
Trainable params: 33,153
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.03
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/namedtuple.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ NamedTuple [2, 1, 28, 28] --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/parameter_list.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ParameterListModel -- [100, 100]
Total params: 30,000
Trainable params: 30,000
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
================================================================================================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/parameters_with_other_layers.out
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ParameterFCNet [3, 64] 8,256
Total params: 8,256
Trainable params: 8,256
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/partial_jit.out
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ PartialJITModel -- --
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.00
Expand Down
2 changes: 1 addition & 1 deletion tests/test_output/uninitialized_tensor.out
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ UninitializedParameterModel [2, 2] 128
Total params: 128
Trainable params: 128
Non-trainable params: 0
Total mult-adds (M): 0.00
Total mult-adds (M): 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Expand Down
14 changes: 13 additions & 1 deletion tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
UninitializedParameterModel,
)
from torchinfo import ColumnSettings, summary
from torchinfo.enums import Verbosity
from torchinfo.enums import Units, Verbosity


def test_basic_summary() -> None:
Expand Down Expand Up @@ -175,6 +175,18 @@ def test_row_settings() -> None:
summary(model, input_size=(16, 1, 28, 28), row_settings=("var_names",))


def test_formatting_options() -> None:
model = SingleInputNet()

results = summary(model, input_size=(16, 1, 28, 28), verbose=0)
results.formatting.macs_units = Units.NONE
print(results)

results.formatting.params_units = Units.TERABYTES
results.formatting.macs_units = Units.TERABYTES
print(results)


def test_jit() -> None:
model = LinearModel()
model_jit = torch.jit.script(model)
Expand Down
3 changes: 2 additions & 1 deletion torchinfo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .enums import ColumnSettings, Mode, RowSettings, Verbosity
from .enums import ColumnSettings, Mode, RowSettings, Units, Verbosity
from .model_statistics import ModelStatistics
from .torchinfo import summary

Expand All @@ -8,6 +8,7 @@
"Mode",
"ModelStatistics",
"RowSettings",
"Units",
"Verbosity",
)
__version__ = "1.7.1"
11 changes: 11 additions & 0 deletions torchinfo/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class ColumnSettings(str, Enum):
TRAINABLE = "trainable"


@unique
class Units(str, Enum):
"""Enum containing all available bytes units."""

AUTO = "auto"
MEGABYTES = "M"
GIGABYTES = "G"
TERABYTES = "T"
NONE = ""


@unique
class Verbosity(IntEnum):
"""Contains verbosity levels."""
Expand Down
10 changes: 9 additions & 1 deletion torchinfo/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
from typing import Any

from .enums import ColumnSettings, RowSettings, Verbosity
from .enums import ColumnSettings, RowSettings, Units, Verbosity
from .layer_info import LayerInfo

HEADER_TITLES = {
Expand All @@ -14,6 +14,12 @@
ColumnSettings.MULT_ADDS: "Mult-Adds",
ColumnSettings.TRAINABLE: "Trainable",
}
CONVERSION_FACTORS = {
Units.TERABYTES: 1e12,
Units.GIGABYTES: 1e9,
Units.MEGABYTES: 1e6,
Units.NONE: 1,
}


class FormattingOptions:
Expand All @@ -32,6 +38,8 @@ def __init__(
self.col_names = col_names
self.col_width = col_width
self.row_settings = row_settings
self.params_units = Units.NONE
self.macs_units = Units.AUTO

self.layer_name_width = 40
self.ascii_only = RowSettings.ASCII_ONLY in self.row_settings
Expand Down
47 changes: 35 additions & 12 deletions torchinfo/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from typing import Any

from .formatting import FormattingOptions
from .enums import Units
from .formatting import CONVERSION_FACTORS, FormattingOptions
from .layer_info import LayerInfo


Expand Down Expand Up @@ -46,24 +47,35 @@ def __init__(
def __repr__(self) -> str:
"""Print results of the summary."""
divider = "=" * self.formatting.get_total_width()
total_params = ModelStatistics.format_output_num(
self.total_params, self.formatting.params_units
)
trainable_params = ModelStatistics.format_output_num(
self.trainable_params, self.formatting.params_units
)
non_trainable_params = ModelStatistics.format_output_num(
self.total_params - self.trainable_params, self.formatting.params_units
)
summary_str = (
f"{divider}\n"
f"{self.formatting.header_row()}{divider}\n"
f"{self.formatting.layers_to_str(self.summary_list)}{divider}\n"
f"Total params: {self.total_params:,}\n"
f"Trainable params: {self.trainable_params:,}\n"
f"Non-trainable params: {self.total_params - self.trainable_params:,}\n"
f"Total params{total_params}\n"
f"Trainable params{trainable_params}\n"
f"Non-trainable params{non_trainable_params}\n"
)
if self.input_size:
unit, macs = self.to_readable(self.total_mult_adds)
macs = ModelStatistics.format_output_num(
self.total_mult_adds, self.formatting.macs_units
)
input_size = self.to_megabytes(self.total_input)
output_bytes = self.to_megabytes(self.total_output_bytes)
param_bytes = self.to_megabytes(self.total_param_bytes)
total_bytes = self.to_megabytes(
self.total_input + self.total_output_bytes + self.total_param_bytes
)
summary_str += (
f"Total mult-adds ({unit}): {macs:0.2f}\n{divider}\n"
f"Total mult-adds{macs}\n{divider}\n"
f"Input size (MB): {input_size:0.2f}\n"
f"Forward/backward pass size (MB): {output_bytes:0.2f}\n"
f"Params size (MB): {param_bytes:0.2f}\n"
Expand All @@ -83,10 +95,21 @@ def to_megabytes(num: int) -> float:
return num / 1e6

@staticmethod
def to_readable(num: int) -> tuple[str, float]:
def to_readable(num: int, units: Units = Units.AUTO) -> tuple[Units, float]:
"""Converts a number to millions, billions, or trillions."""
if num >= 1e12:
return "T", num / 1e12
if num >= 1e9:
return "G", num / 1e9
return "M", num / 1e6
if units == Units.AUTO:
if num >= 1e12:
return Units.TERABYTES, num / 1e12
if num >= 1e9:
return Units.GIGABYTES, num / 1e9
return Units.MEGABYTES, num / 1e6
return units, num / CONVERSION_FACTORS[units]

@staticmethod
def format_output_num(num: int, units: Units) -> str:
units_used, converted_num = ModelStatistics.to_readable(num, units)
if converted_num.is_integer():
converted_num = int(converted_num)
units_display = "" if units_used == Units.NONE else f" ({units_used})"
fmt = "d" if isinstance(converted_num, int) else ".2f"
return f"{units_display}: {converted_num:,{fmt}}"

0 comments on commit 8b3ae72

Please sign in to comment.