Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Improve typehints. #426

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/deploy-model-card-creator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: Deploy-Space-Creator

on:
- push
- pull_request

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand Down
4 changes: 2 additions & 2 deletions skops/card/_markup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Any, Sequence
from typing import Any

from skops.card._model_card import TableSection

Expand Down Expand Up @@ -258,7 +258,7 @@ def _table(self, item) -> str:
# pandoc < 2.5
columns, body = self._table_old(item)

table: Mapping[str, Sequence[Any]]
table: Mapping[str, list[Any]]
if not body:
table = {key: [] for key in columns}
else:
Expand Down
14 changes: 7 additions & 7 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from hashlib import sha256
from pathlib import Path
from reprlib import Repr
from typing import Any, Iterator, List, Literal, Optional, Sequence, Union
from typing import Any, Iterator, Literal

import joblib
from huggingface_hub import ModelCardData
Expand Down Expand Up @@ -60,7 +60,7 @@ def _clean_table(table: str) -> str:
return table


def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData:
def metadata_from_config(config_path: str | Path) -> ModelCardData:
"""Construct a ``ModelCardData`` object from a ``config.json`` file.

Most information needed for the metadata section of a ``README.md`` file on
Expand Down Expand Up @@ -281,7 +281,7 @@ def __repr__(self) -> str:
class TableSection(Section):
"""Adds a table to the model card"""

table: Mapping[str, Sequence[Any]] = field(default_factory=dict)
table: Mapping[str, list[Any]] = field(default_factory=dict)
folded: bool = False

def __post_init__(self) -> None:
Expand Down Expand Up @@ -488,7 +488,7 @@ def __init__(
model_diagram: bool | Literal["auto"] | str = "auto",
metadata: ModelCardData | None = None,
template: Literal["skops"] | dict[str, str] | None = "skops",
trusted: Optional[List[str]] = None,
trusted: list[str] | None = None,
) -> None:
self.model = model
self.metadata = metadata or ModelCardData()
Expand Down Expand Up @@ -619,7 +619,7 @@ def add(self, folded: bool = False, **kwargs: str) -> Self:
return self

def _select(
self, subsection_names: Sequence[str], create: bool = True
self, subsection_names: list[str], create: bool = True
) -> dict[str, Section]:
"""Select a single section from the data.

Expand Down Expand Up @@ -713,7 +713,7 @@ def select(self, key: str) -> Section:
parent_section = self._select(subsection_names, create=False)
return parent_section[leaf_node_name]

def delete(self, key: str | Sequence[str]) -> None:
def delete(self, key: str | list[str]) -> None:
"""Delete a section from the model card.

To delete a subsection of an existing section, use a ``"/"`` in the
Expand Down Expand Up @@ -1181,7 +1181,7 @@ def add_metrics(
def add_permutation_importances(
self,
permutation_importances,
columns: Sequence[str],
columns: list[str],
plot_file: str | Path = "permutation_importances.png",
plot_name: str = "Permutation Importances",
overwrite: bool = False,
Expand Down
3 changes: 1 addition & 2 deletions skops/cli/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import pathlib
import pickle
from typing import Optional

from skops.cli._utils import get_log_level
from skops.io import dumps, get_untrusted_types
Expand Down Expand Up @@ -59,7 +58,7 @@ def _convert_file(


def format_parser(
parser: Optional[argparse.ArgumentParser] = None,
parser: argparse.ArgumentParser | None = None,
) -> argparse.ArgumentParser:
"""Adds arguments and help to parent CLI parser for the convert method."""

Expand Down
34 changes: 16 additions & 18 deletions skops/hub_utils/_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import shutil
import warnings
from pathlib import Path
from typing import Any, List, Literal, MutableMapping, Optional, Sequence, Union
from typing import Any, Literal, MutableMapping

import numpy as np
from huggingface_hub import HfApi, InferenceClient, snapshot_download
Expand All @@ -25,7 +25,7 @@
]


def _validate_folder(path: Union[str, Path]) -> None:
def _validate_folder(path: str | Path) -> None:
"""Validate the contents of a folder.

This function checks if the contents of a folder make a valid repo for a
Expand Down Expand Up @@ -117,15 +117,15 @@ def _get_example_input_from_tabular_data(data):
)


def _get_example_input_from_text_data(data: Sequence[str]):
def _get_example_input_from_text_data(data: list[str]):
"""Returns the example input of a model for a text task.

The input is converted into a dictionary which is then stored in the config
file.

Parameters
----------
data: Sequence[str]
data: list of str
A sequence of strings. The first 3 elements are used as example input.

Returns
Expand Down Expand Up @@ -197,9 +197,9 @@ def _get_column_names(data):

def _create_config(
*,
model_path: Union[str, Path],
requirements: List[str],
dst: Union[str, Path],
model_path: str | Path,
requirements: list[str],
dst: str | Path,
task: Literal[
"tabular-classification",
"tabular-regression",
Expand Down Expand Up @@ -319,9 +319,9 @@ def _check_model_file(path: str | Path) -> Path:

def init(
*,
model: Union[str, Path],
requirements: List[str],
dst: Union[str, Path],
model: str | Path,
requirements: list[str],
dst: str | Path,
task: Literal[
"tabular-classification",
"tabular-regression",
Expand Down Expand Up @@ -468,9 +468,7 @@ def dump_json(path, content):
json.dump(content, f, sort_keys=True, indent=4)


def update_env(
*, path: Union[str, Path], requirements: Union[List[str], None] = None
) -> None:
def update_env(*, path: str | Path, requirements: list[str] | None = None) -> None:
"""Update the environment requirements of a repo.

This function takes the path to the repo, and updates the requirements of
Expand Down Expand Up @@ -498,7 +496,7 @@ def update_env(
def push(
*,
repo_id: str,
source: Union[str, Path],
source: str | Path,
token: str | None = None,
commit_message: str | None = None,
create_remote: bool = False,
Expand Down Expand Up @@ -579,7 +577,7 @@ def push(
)


def get_config(path: Union[str, Path]) -> dict[str, Any]:
def get_config(path: str | Path) -> dict[str, Any]:
"""Returns the configuration of a project.

Parameters
Expand All @@ -598,7 +596,7 @@ def get_config(path: Union[str, Path]) -> dict[str, Any]:
return config


def get_requirements(path: Union[str, Path]) -> List[str]:
def get_requirements(path: str | Path) -> list[str]:
"""Returns the requirements of a project.

Parameters
Expand All @@ -620,7 +618,7 @@ def get_requirements(path: Union[str, Path]) -> List[str]:
def download(
*,
repo_id: str,
dst: Union[str, Path],
dst: str | Path,
revision: str | None = None,
token: str | None = None,
keep_cache: bool = True,
Expand Down Expand Up @@ -685,7 +683,7 @@ def download(


# TODO(v0.10): remove this function
def get_model_output(repo_id: str, data: Any, token: Optional[str] = None) -> Any:
def get_model_output(repo_id: str, data: Any, token: str | None = None) -> Any:
"""Returns the output of the model using Hugging Face Hub's inference API.

See the :ref:`User Guide <hf_hub_inference>` for more details.
Expand Down
20 changes: 10 additions & 10 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

import io
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union
from typing import Any, Generator, Type, Union

from ._protocol import PROTOCOL
from ._utils import LoadContext, get_module, get_type_paths
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = {}
VALID_NODE_CHILD_TYPES = Optional[
Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO]
VALID_NODE_CHILD_TYPES = Union[
"Node", list["Node"], dict[str, "Node"], Type, str, io.BytesIO
]


def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool:
def check_type(module_name: str, type_name: str, trusted: list[str]) -> bool:
"""Check if a type is safe to load.

A type is safe to load only if it's present in the trusted list.
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: Optional[Sequence[str]] = None,
trusted: list[str | type[Any]] | None = None,
memoize: bool = True,
) -> None:
self.class_name, self.module_name = state["__class__"], state["__module__"]
Expand Down Expand Up @@ -172,8 +172,8 @@ def _construct(self):

@staticmethod
def _get_trusted(
trusted: Optional[Sequence[Union[str, Type]]],
default: Sequence[Union[str, Type]],
trusted: list[str | Type] | None,
default: list[str | Type],
) -> list[str]:
"""Return a trusted list, or True.

Expand Down Expand Up @@ -233,7 +233,7 @@ def get_unsafe_set(self) -> set[str]:
continue

# Get the safety set based on the type of the child. In most cases
# other than ListNode and DictNode, children are all of type Node.
# other than listNode and DictNode, children are all of type Node.
if isinstance(child, list):
# iterate through the list
for value in child:
Expand Down Expand Up @@ -278,7 +278,7 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: Optional[List[str]] = None,
trusted: list[str | type[Any]] | None = None,
):
# we pass memoize as False because we don't want to memoize the cached
# node.
Expand All @@ -302,7 +302,7 @@ def _construct(self):
def get_tree(
state: dict[str, Any],
load_context: LoadContext,
trusted: Optional[Sequence[str]],
trusted: list[str | type[Any]] | None,
) -> Node:
"""Get the tree of nodes.

Expand Down
Loading
Loading