I'm extending the example given for `fvcore.nn.FlopCountAnalysis` to add the FLOP count of a custom op defined within my model class.
```python
import torch
from collections import Counter
from fvcore.nn import FlopCountAnalysis
from torch import nn


class TestModel(nn.Module):
    """Toy model."""

    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
        self.fc = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        _ = self.custom_op_flop_counter(inputs=x, outputs=None)
        return self.fc(self.act(self.conv(x)).flatten(1))

    @staticmethod
    # Has no access to anything else in the class.
    def custom_op_flop_counter(inputs, outputs) -> Counter:
        """Returns counter value to include in flops."""
        # The function should return a counter object with per-operator statistics.
        return Counter({'custom_op': 500})


model = TestModel()
inputs = (torch.randn((1, 3, 10, 10)),)
flops = FlopCountAnalysis(model, inputs).set_op_handle(
    "custom_op", model.custom_op_flop_counter
)
print(flops.by_module_and_operator())
```
The `"custom_op"` entry and its returned value of 500 do not appear in the printed output. It does print the expected values for the `linear` and `conv` operators: `{'': Counter({'linear': 10000, 'conv': 3000}), 'act': Counter(), 'conv': Counter({'conv': 3000}), 'fc': Counter({'linear': 10000})}`. What am I doing wrong that prevents my custom op count from being included?
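A minimal, hypothetical sketch of the likely cause, assuming (as I understand fvcore's `JitModelAnalysis`) that `FlopCountAnalysis` first traces the model with `torch.jit` and then calls a handle only for traced nodes whose operator name matches a registered key. Calling a plain Python method such as `custom_op_flop_counter` inside `forward()` runs during tracing but adds no node to the graph, so a handle registered under `"custom_op"` would never be looked up. `ToyTracer`, `toy_forward` and `count_flops` are made-up names that only model this lookup; they are not fvcore APIs. The `"prim::PythonOp.CustomOp"` key is also an assumption: it stands for a hypothetical `torch.autograd.Function` subclass named `CustomOp`, which should appear in a JIT trace as a `prim::PythonOp` node.

```python
from collections import Counter
from typing import Callable, Dict, List

# Hypothetical toy model of a trace-then-dispatch FLOP counter.
# This is NOT fvcore's implementation; it only mimics the idea that
# handles are looked up by the operator names recorded in a trace.
Handle = Callable[[list, list], Counter]


class ToyTracer:
    """Records the names of operators that run during a forward pass."""

    def __init__(self) -> None:
        self.graph: List[str] = []

    def record(self, op_name: str) -> None:
        self.graph.append(op_name)


def toy_forward(tracer: ToyTracer, traced_custom_op: bool) -> None:
    if traced_custom_op:
        # Assumed: a torch.autograd.Function named CustomOp, which a JIT
        # trace records as a prim::PythonOp node.
        tracer.record("prim::PythonOp.CustomOp")
    else:
        # Like calling custom_op_flop_counter() inside forward(): the Python
        # code runs, but nothing is recorded in the traced graph.
        _ = Counter({"custom_op": 500})
    tracer.record("aten::_convolution")
    tracer.record("aten::relu")
    tracer.record("aten::linear")


def count_flops(handles: Dict[str, Handle], traced_custom_op: bool) -> Counter:
    tracer = ToyTracer()
    toy_forward(tracer, traced_custom_op)
    total: Counter = Counter()
    for op_name in tracer.graph:
        handle = handles.get(op_name)
        if handle is None:
            continue  # no handle registered for this op: contributes nothing
        total += handle([], [])  # placeholder inputs/outputs
    return total


handles: Dict[str, Handle] = {
    "aten::_convolution": lambda inputs, outputs: Counter({"conv": 3000}),
    "aten::linear": lambda inputs, outputs: Counter({"linear": 10000}),
    # Registered, but no traced op is ever named "custom_op".
    "custom_op": lambda inputs, outputs: Counter({"custom_op": 500}),
    # Matches the traced node name, so it is called when that node exists.
    "prim::PythonOp.CustomOp": lambda inputs, outputs: Counter({"custom_op": 500}),
}

print(count_flops(handles, traced_custom_op=False))
# → Counter({'linear': 10000, 'conv': 3000})
print(count_flops(handles, traced_custom_op=True))
# → Counter({'linear': 10000, 'conv': 3000, 'custom_op': 500})
```

With `traced_custom_op=False` only the conv and linear handles fire, which mirrors the output reported in the issue; once the custom work is an actual traced op, the handle keyed by that op's name is called and `custom_op` shows up. Under this assumption, the key passed to `set_op_handle` would have to match a node name in the model's JIT graph rather than an arbitrary label.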
guynich changed the title from "Counting FLOPS for a custom ops with set_op_handle: a toy example that doesn't work." to "Counting FLOPS for a custom op with set_op_handle: a toy example that doesn't work." on Apr 9, 2024.