Skip to content

Commit

Permalink
Fix annotation for tracked_function (#3163)
Browse files Browse the repository at this point in the history
### Changes

- Fix suggestion in editor for function under tracked_function
- Add nncf/telemtry to mypy check
 
### Reason for changes

Broken annotation of function that decorated by tracked_function
  • Loading branch information
AlexanderDokuchaev authored Dec 19, 2024
1 parent 00f9f08 commit 3b11904
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 28 deletions.
14 changes: 8 additions & 6 deletions nncf/telemetry/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.
import functools
import inspect
from typing import Callable, List, Union
from typing import Any, Callable, List, Optional, TypeVar, Union

from nncf.telemetry.events import MODEL_BASED_CATEGORY
from nncf.telemetry.events import get_current_category
Expand All @@ -21,6 +21,8 @@
from nncf.telemetry.extractors import VerbatimTelemetryExtractor
from nncf.telemetry.wrapper import telemetry

TFunction = TypeVar("TFunction", bound=Callable[..., Any])


class tracked_function:
"""
Expand All @@ -29,7 +31,7 @@ class tracked_function:
function execution. The category of the session and events will be determined by parameters to the decorator.
"""

def __init__(self, category: str = None, extractors: List[Union[str, TelemetryExtractor]] = None):
def __init__(self, category: str = None, extractors: Optional[List[Union[str, TelemetryExtractor]]] = None) -> None:
"""
:param category: A category to be attributed to the events. If set to None, no events will be sent.
:param extractors: Add argument names in this list as string values to send an event with an "action" equal to
Expand All @@ -44,11 +46,11 @@ def __init__(self, category: str = None, extractors: List[Union[str, TelemetryEx
else:
self._collectors = []

def __call__(self, fn: Callable) -> Callable:
def __call__(self, fn: TFunction) -> TFunction:
fn_signature = inspect.signature(fn)

@functools.wraps(fn)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
bound_args = fn_signature.bind(*args, **kwargs)
bound_args.apply_defaults()

Expand All @@ -59,7 +61,7 @@ def wrapped(*args, **kwargs):
events: List[CollectedEvent] = []
for collector in self._collectors:
argname = collector.argname
argvalue = bound_args.arguments[argname] if argname is not None else None
argvalue = bound_args.arguments[argname] if argname else None
event = collector.extract(argvalue)
events.append(event)

Expand All @@ -82,4 +84,4 @@ def wrapped(*args, **kwargs):
telemetry.end_session(self._category)
return retval

return wrapped
return wrapped # type: ignore[return-value]
8 changes: 4 additions & 4 deletions nncf/telemetry/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from contextlib import contextmanager
from typing import Optional, TypeVar
from typing import Generator, Optional, TypeVar

from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
Expand All @@ -25,12 +25,12 @@
# Dynamic categories
MODEL_BASED_CATEGORY = "model_based"

CURRENT_CATEGORY = None
CURRENT_CATEGORY: Optional[str] = None

TModel = TypeVar("TModel")


def _set_current_category(category: str):
def _set_current_category(category: Optional[str]) -> None:
global CURRENT_CATEGORY
CURRENT_CATEGORY = category

Expand All @@ -56,7 +56,7 @@ def get_model_based_category(model: TModel) -> str:


@contextmanager
def telemetry_category(category: str) -> str:
def telemetry_category(category: Optional[str]) -> Generator[Optional[str], None, None]:
previous_category = get_current_category()
_set_current_category(category)
yield category
Expand Down
8 changes: 4 additions & 4 deletions nncf/telemetry/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum
from typing import Any, Optional, Union

SerializableData = Union[str, Enum]
SerializableData = Union[str, Enum, bool]


@dataclass
Expand All @@ -26,16 +26,16 @@ class CollectedEvent:
"""

name: str
data: SerializableData = None # GA limitations
int_data: int = None
data: Optional[SerializableData] = None # GA limitations
int_data: Optional[int] = None


class TelemetryExtractor(ABC):
"""
Interface for custom telemetry extractors, to be used with the `nncf.telemetry.tracked_function` decorator.
"""

def __init__(self, argname: Optional[str] = None):
def __init__(self, argname: str = ""):
self._argname = argname

@property
Expand Down
28 changes: 14 additions & 14 deletions nncf/telemetry/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
from abc import ABC
from abc import abstractmethod
from typing import Callable, Optional
from typing import Any, Callable, Optional
from unittest.mock import MagicMock

from nncf.common.logging import nncf_logger
Expand All @@ -29,7 +29,7 @@ class ITelemetry(ABC):
# https://support.google.com/analytics/answer/1033068

@abstractmethod
def start_session(self, category: str, **kwargs):
def start_session(self, category: str, **kwargs: Any) -> None:
"""
Sends a message about starting of a new session.
Expand All @@ -45,9 +45,9 @@ def send_event(
event_action: str,
event_label: str,
event_value: Optional[int] = None,
force_send=False,
**kwargs,
):
force_send: bool = False,
**kwargs: Any,
) -> None:
"""
Send single event.
Expand All @@ -61,7 +61,7 @@ def send_event(
"""

@abstractmethod
def end_session(self, category: str, **kwargs):
def end_session(self, category: str, **kwargs: Any) -> None:
"""
Sends a message about ending of the current session.
Expand All @@ -78,7 +78,7 @@ def skip_if_raised(func: Callable[..., None]) -> Callable[..., None]:
"""

@functools.wraps(func)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> None:
try:
func(*args, **kwargs)

Expand All @@ -91,7 +91,7 @@ def wrapped(*args, **kwargs):
class NNCFTelemetry(ITelemetry):
MEASUREMENT_ID = "G-W5E9RNLD4H"

def __init__(self):
def __init__(self) -> None:
self._app_name = "nncf"
self._app_version = __version__
try:
Expand All @@ -108,7 +108,7 @@ def __init__(self):
nncf_logger.debug(f"Failed to instantiate telemetry object: exception {e}")

@skip_if_raised
def start_session(self, category: str, **kwargs):
def start_session(self, category: str, **kwargs: Any) -> None:
self._impl.start_session(category, **kwargs)

@skip_if_raised
Expand All @@ -118,9 +118,9 @@ def send_event(
event_action: str,
event_label: str,
event_value: Optional[int] = None,
force_send=False,
**kwargs,
):
force_send: bool = False,
**kwargs: Any,
) -> None:
if event_value is None:
event_value = 1
self._impl.send_event(
Expand All @@ -135,12 +135,12 @@ def send_event(
)

@skip_if_raised
def end_session(self, category: str, **kwargs):
def end_session(self, category: str, **kwargs: Any) -> None:
self._impl.end_session(category, **kwargs)


try:
from openvino_telemetry import Telemetry
from openvino_telemetry import Telemetry # type: ignore

telemetry = NNCFTelemetry()
except ImportError:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ files = [
"nncf/common/utils/",
"nncf/common/tensor_statistics",
"nncf/experimental/torch2",
"nncf/telemetry/",
]

[tool.ruff]
Expand Down

0 comments on commit 3b11904

Please sign in to comment.