
One complex parameter should count as two params #193

Open

wants to merge 3 commits into base: main
Conversation

@scaomath commented Nov 9, 2022

All models' parameter counting traces back to here:

```python
def get_param_count(
    module: nn.Module, name: str, param: torch.Tensor
) -> tuple[int, str]:
    """
    Get count of number of params, accounting for mask.

    Masked models save parameters with the suffix "_orig" added.
    They have a buffer ending with "_mask" which has only 0s and 1s.
    If a mask exists, the sum of 1s in mask is number of params.
    """
    if name.endswith("_orig"):
        without_suffix = name[:-5]
        pruned_weights = rgetattr(module, f"{without_suffix}_mask")
        if pruned_weights is not None:
            parameter_count = int(torch.sum(pruned_weights))
            return parameter_count, without_suffix
    return param.nelement(), name
```

There is no check on whether the parameter tensor is complex or real. If a parameter is complex, say a + ib, it actually represents two real parameters (for the purpose of counting MACs/FLOPs).
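A minimal sketch of the kind of adjustment being proposed (the helper name `real_valued_count` is hypothetical, not the PR's actual diff):

```python
import torch


def real_valued_count(param: torch.Tensor) -> int:
    """Count real-valued scalars in a tensor.

    A complex entry a + ib carries two real degrees of freedom,
    so each element of a complex tensor is counted twice.
    """
    multiplier = 2 if torch.is_complex(param) else 1
    return param.nelement() * multiplier
```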

Of course, this PR might not conform to torchinfo's development plans; feel free to close it. I hope complex parameters will be considered in the next version.

@TylerYep (Owner) commented Nov 9, 2022

Thanks for looking into this. Could you provide a small example model + output to illustrate your point? That will help validate that this fix works.

@scaomath (Author) commented Nov 9, 2022

> Thanks for looking into this. Could you provide a small example model + output to illustrate your point? That will help validate that this fix works.

I added a test; a complex layer correctly returns double the parameter count. I encountered this issue when running torchinfo on networks like the FNO: neuraloperator/neuraloperator#52
I thought it would be nice to integrate this into one of my all-time favorite add-ons for torch.
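For illustration, here is a tiny model of the kind involved (a hypothetical sketch, not the PR's actual test). `nelement()` counts one per complex entry, so each layer's real-valued parameter count is underreported by half:

```python
import torch
from torch import nn
from torchinfo import summary


class ComplexLinear(nn.Module):
    """Toy layer holding a single complex weight matrix."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # A complex64 tensor stores a real and an imaginary float32 per entry.
        self.weight = nn.Parameter(torch.randn(dim, dim, dtype=torch.cfloat))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


model = ComplexLinear(4)
print(model.weight.nelement())  # 16 entries, but 32 real-valued parameters
summary(model, input_data=torch.randn(2, 4, dtype=torch.cfloat))
```

Without the proposed change, the summary reports 16 params for this layer; counting each complex entry as two gives 32.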

Thoughts? @TylerYep
