Skip to content

Commit

Permalink
⚫ Updating black to the newest version.
Browse files Browse the repository at this point in the history
Different behaviors on `...`.
  • Loading branch information
rentruewang committed May 14, 2024
1 parent ffb74c0 commit aa808b3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 66 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"numpy>=1.26.3",
"scipy>=1.11.4",
"torch>=2.1.2",
"black>=24.4.2",
]
requires-python = "==3.10.*"
readme = "README.md"
Expand Down
36 changes: 12 additions & 24 deletions src/koila/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,21 @@
@runtime_checkable
class Runnable(Protocol[T]):
@abstractmethod
def run(self) -> T:
...
def run(self) -> T: ...


@runtime_checkable
class TensorMixin(Protocol):
@overload
@abstractmethod
def size(self) -> Tuple[int, ...]:
...
def size(self) -> Tuple[int, ...]: ...

@overload
@abstractmethod
def size(self, dim: int) -> int:
...
def size(self, dim: int) -> int: ...

@abstractmethod
def size(self, dim: int | None = None) -> int | Tuple[int, ...]:
...
def size(self, dim: int | None = None) -> int | Tuple[int, ...]: ...

def numel(self) -> int:
return functools.reduce(operator.mul, self.size(), 1)
Expand All @@ -57,12 +53,10 @@ def dim(self) -> int:
return len(self.size())

@abstractmethod
def dtype(self) -> DType:
...
def dtype(self) -> DType: ...

@abstractmethod
def device(self) -> str | Device:
...
def device(self) -> str | Device: ...


class BatchedPair(NamedTuple):
Expand All @@ -82,16 +76,13 @@ def map(self, func: Callable[[int], int]) -> BatchInfo:
@runtime_checkable
class RunnableTensor(Runnable[Tensor], TensorMixin, Protocol):
@abstractmethod
def batch(self) -> BatchInfo | None:
...
def batch(self) -> BatchInfo | None: ...

@abstractmethod
def run(self, partial: Tuple[int, int] | None = None) -> Tensor:
...
def run(self, partial: Tuple[int, int] | None = None) -> Tensor: ...

@abstractmethod
def visit(self, nodes: Dict[int, TensorLike]) -> None:
...
def visit(self, nodes: Dict[int, TensorLike]) -> None: ...

def buffer(self) -> Dict[int, TensorLike]:
nodes = {}
Expand Down Expand Up @@ -150,18 +141,15 @@ def bat(tensor: TensorLike) -> BatchInfo | None:


@overload
def run(val: RunnableTensor, partial: Tuple[int, int] | None = None) -> Tensor:
...
def run(val: RunnableTensor, partial: Tuple[int, int] | None = None) -> Tensor: ...


@overload
def run(val: Runnable[E], partial: Tuple[int, int] | None = None) -> E:
...
def run(val: Runnable[E], partial: Tuple[int, int] | None = None) -> E: ...


@overload
def run(val: E, partial: Tuple[int, int] | None = None) -> E:
...
def run(val: E, partial: Tuple[int, int] | None = None) -> E: ...


def run(val: Any, partial: Tuple[int, int] | None = None) -> Any:
Expand Down
50 changes: 18 additions & 32 deletions src/koila/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,10 @@ def visit(self, nodes: Dict[int, TensorLike]) -> None:
assert hash(self) in nodes.keys()

@overload
def size(self) -> Tuple[int, ...]:
...
def size(self) -> Tuple[int, ...]: ...

@overload
def size(self, dim: int) -> int:
...
def size(self, dim: int) -> int: ...

def size(self, dim: int | None = None) -> int | Tuple[int, ...]:
data = self._data
Expand Down Expand Up @@ -442,43 +440,37 @@ def backward(self) -> None:


@overload
def lazy(val: Tensor | LazyTensor, batch: int | None = None) -> LazyTensor:
...
def lazy(val: Tensor | LazyTensor, batch: int | None = None) -> LazyTensor: ...


@overload
def lazy(*val: Tensor | LazyTensor, batch: int | None = None) -> Tuple[LazyTensor, ...]:
...
def lazy(
*val: Tensor | LazyTensor, batch: int | None = None
) -> Tuple[LazyTensor, ...]: ...


@overload
def lazy(val: int) -> int:
...
def lazy(val: int) -> int: ...


@overload
def lazy(*val: int) -> Tuple[int, ...]:
...
def lazy(*val: int) -> Tuple[int, ...]: ...


@overload
def lazy(val: float) -> float:
...
def lazy(val: float) -> float: ...


@overload
def lazy(*val: float) -> Tuple[float, ...]:
...
def lazy(*val: float) -> Tuple[float, ...]: ...


@overload
def lazy(val: bool) -> bool:
...
def lazy(val: bool) -> bool: ...


@overload
def lazy(*val: bool) -> Tuple[bool, ...]:
...
def lazy(*val: bool) -> Tuple[bool, ...]: ...


def lazy(*values: Any, batch: int | None = None) -> Any:
Expand Down Expand Up @@ -521,18 +513,15 @@ class _ValIdx(NamedTuple):


@overload
def _min(input: TensorLike) -> TensorLike:
...
def _min(input: TensorLike) -> TensorLike: ...


@overload
def _min(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx:
...
def _min(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ...


@overload
def _min(input: TensorLike, other: TensorLike) -> TensorLike:
...
def _min(input: TensorLike, other: TensorLike) -> TensorLike: ...


@wraps(torch.min)
Expand All @@ -555,18 +544,15 @@ def _min(input: TensorLike, *args: Any, **kwargs: Any) -> TensorLike | _ValIdx:


@overload
def _max(input: TensorLike) -> TensorLike:
...
def _max(input: TensorLike) -> TensorLike: ...


@overload
def _max(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx:
...
def _max(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ...


@overload
def _max(input: TensorLike, other: TensorLike) -> TensorLike:
...
def _max(input: TensorLike, other: TensorLike) -> TensorLike: ...


@wraps(torch.max)
Expand Down
15 changes: 5 additions & 10 deletions src/koila/prepasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,12 @@

class CallBack(Protocol):
@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> Reducer:
...
def __call__(self, *args: Any, **kwargs: Any) -> Reducer: ...


class Reducer(Protocol):
@abstractmethod
def __call__(self, result: Tensor, /) -> Tensor:
...
def __call__(self, result: Tensor, /) -> Tensor: ...


@dataclass(frozen=True)
Expand All @@ -63,12 +61,10 @@ def __iter__(self):
return iter(self.shape)

@overload
def __getitem__(self, index: int) -> int:
...
def __getitem__(self, index: int) -> int: ...

@overload
def __getitem__(self, index: slice) -> Tuple[int, ...]:
...
def __getitem__(self, index: slice) -> Tuple[int, ...]: ...

def __getitem__(self, index: int | slice) -> int | Tuple[int, ...]:
return self.shape[index]
Expand Down Expand Up @@ -98,8 +94,7 @@ def reducer(self) -> CallBack | None:
@runtime_checkable
class PrePassFunc(Protocol):
@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> PrePass:
...
def __call__(self, *args: Any, **kwargs: Any) -> PrePass: ...


def mute_unused_args(*args: Any, **kwargs: Any) -> None:
Expand Down

0 comments on commit aa808b3

Please sign in to comment.