Skip to content

Commit

Permalink
turn HasDataType and HasShape into Protocol\s
Browse files Browse the repository at this point in the history
  • Loading branch information
markusschmaus authored and brandonwillard committed Sep 24, 2022
1 parent ec82b9f commit d51271f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 14 deletions.
22 changes: 15 additions & 7 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union

from typing_extensions import TypeAlias
from typing_extensions import Protocol, TypeAlias, runtime_checkable

from aesara.graph import utils
from aesara.graph.basic import Constant, Variable
Expand Down Expand Up @@ -262,14 +262,22 @@ def values_eq_approx(cls, a: D, b: D) -> bool:
return cls.values_eq(a, b)


class HasDataType:
"""A mixin for a type that has a :attr:`dtype` attribute."""
DataType = str

dtype: str

@runtime_checkable
class HasDataType(Protocol):
"""A protocol matching any class with :attr:`dtype` attribute."""

class HasShape:
"""A mixin for a type that has :attr:`shape` and :attr:`ndim` attributes."""
dtype: DataType


ShapeType = Tuple[Optional[int], ...]


@runtime_checkable
class HasShape(Protocol):
"""A protocol matching any class that has :attr:`shape` and :attr:`ndim` attributes."""

ndim: int
shape: Tuple[Optional[int], ...]
shape: ShapeType
2 changes: 1 addition & 1 deletion aesara/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2441,7 +2441,7 @@ def linking_patch(lib_dirs: List[str], libs: List[str]) -> List[str]:
if sys.platform != "win32":
return [f"-l{l}" for l in libs]

def sort_key(lib): # type: ignore
def sort_key(lib):
name, *numbers, extension = lib.split(".")
return (extension == "dll", tuple(map(int, numbers)))

Expand Down
5 changes: 3 additions & 2 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import MergeOptimizer
from aesara.graph.type import HasDataType, HasShape
from aesara.graph.type import DataType
from aesara.graph.utils import MetaObject, MethodNotDefined
from aesara.link.c.op import COp
from aesara.link.c.type import CType
Expand Down Expand Up @@ -268,7 +268,7 @@ def convert(x, dtype=None):
return x_


class ScalarType(CType, HasDataType, HasShape):
class ScalarType(CType):

"""
Internal class, should not be used by clients.
Expand All @@ -284,6 +284,7 @@ class ScalarType(CType, HasDataType, HasShape):
__props__ = ("dtype",)
ndim = 0
shape = ()
dtype: DataType

def __init__(self, dtype):
if isinstance(dtype, str) and dtype == "floatX":
Expand Down
3 changes: 1 addition & 2 deletions aesara/sparse/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import aesara
from aesara import scalar as aes
from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType
from aesara.tensor.type import DenseTensorType, TensorType


Expand All @@ -33,7 +32,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix)


class SparseTensorType(TensorType, HasDataType):
class SparseTensorType(TensorType):
"""A `Type` for sparse tensors.
Notes
Expand Down
8 changes: 6 additions & 2 deletions aesara/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.type import HasDataType, HasShape
from aesara.graph.type import DataType, ShapeType
from aesara.graph.utils import MetaType
from aesara.link.c.type import CType
from aesara.misc.safe_asarray import _asarray
Expand Down Expand Up @@ -48,11 +48,15 @@
}


class TensorType(CType[np.ndarray], HasDataType, HasShape):
class TensorType(CType[np.ndarray]):
r"""Symbolic `Type` representing `numpy.ndarray`\s."""

__props__: Tuple[str, ...] = ("dtype", "shape")

ndim: int
shape: ShapeType
dtype: DataType

dtype_specs_map = dtype_specs_map
context_name = "cpu"
filter_checks_isfinite = False
Expand Down

0 comments on commit d51271f

Please sign in to comment.