Skip to content

Commit

Permalink
First pass of type checking in Qualtran (#889)
Browse files Browse the repository at this point in the history
* First pass of type checking in Qualtran

- Fixes many mypy issues in Qualtran.

Some highlights:
- Disable override type checking since overrides or variable kwargs are
not appreciated by mypy
- Remove return type for BloqBuilder.add since many places assume it
  returns a tuple.
- Ignore NDArray[cirq.Qid] type issues
  • Loading branch information
dstrain115 authored Apr 23, 2024
1 parent b3a134f commit 8011d2d
Show file tree
Hide file tree
Showing 66 changed files with 225 additions and 143 deletions.
6 changes: 5 additions & 1 deletion dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
show_error_codes = true
plugins = duet.typing, numpy.typing.mypy_plugin
allow_redefinition = true
# Disabling function override checking
# Qualtran has many places where kwargs are used
# with the intention to override in subclasses in ways mypy does not like
disable_error_code = override

[mypy-__main__]
follow_imports = silent
Expand All @@ -15,7 +19,7 @@ follow_imports = silent
ignore_missing_imports = true

# Non-Google
[mypy-sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*]
[mypy-sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,plotly.*,dash.*,tensorflow_docs.*,fxpmath.*,ipywidgets.*,cachetools.*,pydot.*]
follow_imports = silent
ignore_missing_imports = true

Expand Down
4 changes: 2 additions & 2 deletions dev_tools/qualtran_dev_tools/clean_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import List
from typing import Any, List

import nbformat
from nbconvert.preprocessors import ClearMetadataPreprocessor, ClearOutputPreprocessor
Expand Down Expand Up @@ -46,7 +46,7 @@ def clean_notebook(nb_path: Path, do_clean: bool = True):

pp1 = ClearOutputPreprocessor()
pp2 = ClearMetadataPreprocessor(preserve_cell_metadata_mask={'cq.autogen'})
resources = {}
resources: dict[str, Any] = {}
nb, resources = pp1.preprocess(nb, resources=resources)
nb, resources = pp2.preprocess(nb, resources=resources)

Expand Down
2 changes: 2 additions & 0 deletions dev_tools/qualtran_dev_tools/notebook_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def execute_and_export_notebook(paths: _NBInOutPaths) -> Optional[Exception]:
with paths.html_out.open('w') as f:
f.write(html)

return None


class _NotebookRunClosure:
"""Used to run notebook execution logic in subprocesses."""
Expand Down
8 changes: 4 additions & 4 deletions dev_tools/qualtran_dev_tools/reference_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ class _CustomTemplateMixin:
return _CustomTemplateMixin


class MyModulePageBuilder(mixin_custom_template('module'), ModulePageBuilder):
class MyModulePageBuilder(mixin_custom_template('module'), ModulePageBuilder): # type: ignore[misc]
"""Use a custom template for module pages."""


class MyClassPageBuilder(mixin_custom_template('class'), ClassPageBuilder):
class MyClassPageBuilder(mixin_custom_template('class'), ClassPageBuilder): # type: ignore[misc]
"""Use a custom template for class pages.
Additionally, this will re-sort the class members (i.e. methods) to match
Expand All @@ -120,11 +120,11 @@ def __init__(self, page_info):
)


class MyFunctionPageBuilder(mixin_custom_template('function'), FunctionPageBuilder):
class MyFunctionPageBuilder(mixin_custom_template('function'), FunctionPageBuilder): # type: ignore[misc]
"""Use a custom template for function pages."""


class MyTypeAliasPageBuilder(mixin_custom_template('type_alias'), TypeAliasPageBuilder):
class MyTypeAliasPageBuilder(mixin_custom_template('type_alias'), TypeAliasPageBuilder): # type: ignore[misc]
"""Use a custom template for type alias pages."""


Expand Down
4 changes: 2 additions & 2 deletions dev_tools/requirements/deps/mypy.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# the mypy dependency file
mypy~=0.991.0
mypy-protobuf
mypy~=1.9
mypy-protobuf
2 changes: 1 addition & 1 deletion dev_tools/requirements/envs/dev.env.txt
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ more-itertools==10.2.0
# jaraco-functools
mpmath==1.3.0
# via sympy
mypy==0.991
mypy==1.9.0
# via -r deps/mypy.txt
mypy-extensions==1.0.0
# via
Expand Down
6 changes: 5 additions & 1 deletion qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, TYPE_CHECKING

import pytest
import sympy
Expand All @@ -26,6 +27,9 @@
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import LarrowTextBox, RarrowTextBox

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT


def test_serial_combo_adjoint():
# The normal decomposition is three `TestAtom` tagged atom{0,1,2}.
Expand Down Expand Up @@ -179,7 +183,7 @@ class DecomposesIntoTAcceptsAdjoint(Bloq):
def signature(self) -> Signature:
return Signature.build(q=1)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT'):
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
soqs = bb.add_d(TAcceptsAdjoint(), **soqs)
return soqs

Expand Down
6 changes: 3 additions & 3 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Contains the main interface for defining `Bloq`s."""

import abc
from typing import Any, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -297,7 +297,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def call_graph(
self,
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
keep: Optional[Sequence['Bloq']] = None,
keep: Callable[['Bloq'], bool] = None,
max_depth: Optional[int] = None,
) -> Tuple['nx.DiGraph', Dict['Bloq', Union[int, 'sympy.Expr']]]:
"""Get the bloq call graph and call totals.
Expand Down Expand Up @@ -480,7 +480,7 @@ def on(self, *qubits: 'cirq.Qid') -> 'cirq.Operation':
return cirq.Gate.on(BloqAsCirqGate(bloq=self), *qubits)

def on_registers(
self, **qubit_regs: Union['cirq.Qid', Sequence['cirq.Qid'], 'NDArray[cirq.Qid]']
self, **qubit_regs: Union['cirq.Qid', Sequence['cirq.Qid'], 'NDArray[cirq.Qid]'] # type: ignore[type-var]
) -> 'cirq.Operation':
"""A `cirq.Operation` of this bloq operating on the given qubit registers.
Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def add_d(self, bloq: Bloq, **in_soqs: SoquetInT) -> Dict[str, SoquetT]:
binst = BloqInstance(bloq, i=self._new_binst_i())
return dict(self._add_binst(binst, in_soqs=in_soqs))

def add(self, bloq: Bloq, **in_soqs: SoquetInT) -> Union[None, SoquetT, Tuple[SoquetT, ...]]:
def add(self, bloq: Bloq, **in_soqs: SoquetInT):
"""Add a new bloq instance to the compute graph.
This is the primary method for building a composite bloq. Each call to `add` adds a
Expand Down
4 changes: 2 additions & 2 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

def _cvs_convert(
cvs: Union[int, Sequence[int], Sequence[Sequence[int]]]
) -> Tuple[NDArray[int], ...]:
) -> Tuple[NDArray[np.integer], ...]:
if isinstance(cvs, (int, np.integer)):
return (np.array(cvs),)
if isinstance(cvs, np.ndarray):
Expand Down Expand Up @@ -103,7 +103,7 @@ class CtrlSpec:
qdtypes: Tuple[QDType, ...] = attrs.field(
default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt)
)
cvs: Tuple[NDArray[int], ...] = attrs.field(default=1, converter=_cvs_convert)
cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert)

def __attrs_post_init__(self):
assert len(self.qdtypes) == len(self.cvs)
Expand Down
7 changes: 5 additions & 2 deletions qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, TYPE_CHECKING

import attrs
import cirq
Expand Down Expand Up @@ -49,6 +49,9 @@
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox

if TYPE_CHECKING:
from qualtran import SoquetT


def test_ctrl_spec():
cspec1 = CtrlSpec()
Expand Down Expand Up @@ -374,7 +377,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
and_ctrl = [bb.add(one_or_zero[cv]) for cv in self.and_ctrl]

ctrl_soqs = bb.add_d(cbloq, **ctrl_soqs, ctrl=and_ctrl)
out_soqs = [*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]
out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')])

for reg, cvs in zip(cbloq.ctrl_regs, self.ctrl_spec.cvs):
for idx in reg.all_idxs():
Expand Down
24 changes: 17 additions & 7 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

import abc
from enum import Enum
from typing import Any, Iterable, List, Sequence, Union
from typing import Any, cast, Iterable, List, Sequence, Union

import attrs
import numpy as np
Expand Down Expand Up @@ -133,7 +133,9 @@ def from_bits(self, bits: Sequence[int]) -> int:
assert len(bits) == 1
return bits[0]

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
def assert_valid_classical_val_array(
self, val_array: NDArray[np.integer], debug_str: str = 'val'
):
if not np.all((val_array == 0) | (val_array == 1)):
raise ValueError(f"Bad {self} value array in {debug_str}")

Expand Down Expand Up @@ -192,7 +194,7 @@ def get_classical_domain(self) -> Iterable[int]:
def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
mask = (1 << self.bitsize) - 1
mask = (1 << cast(int, self.bitsize)) - 1
return QUInt(self.bitsize).to_bits(int(x) & mask)

def from_bits(self, bits: Sequence[int]) -> int:
Expand All @@ -209,7 +211,9 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if val >= 2 ** (self.bitsize - 1):
raise ValueError(f"Too-large classical {self}: {val} encountered in {debug_str}")

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
def assert_valid_classical_val_array(
self, val_array: NDArray[np.integer], debug_str: str = 'val'
):
if np.any(val_array < -(2 ** (self.bitsize - 1))):
raise ValueError(f"Too-small classical {self}s encountered in {debug_str}")
if np.any(val_array >= 2 ** (self.bitsize - 1)):
Expand Down Expand Up @@ -298,7 +302,9 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if val >= 2**self.bitsize:
raise ValueError(f"Too-large classical value encountered in {debug_str}")

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
def assert_valid_classical_val_array(
self, val_array: NDArray[np.integer], debug_str: str = 'val'
):
if np.any(val_array < 0):
raise ValueError(f"Negative classical values encountered in {debug_str}")
if np.any(val_array >= 2**self.bitsize):
Expand Down Expand Up @@ -391,7 +397,9 @@ def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
return QUInt(self.bitsize).from_bits(bits)

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
def assert_valid_classical_val_array(
self, val_array: NDArray[np.integer], debug_str: str = 'val'
):
if np.any(val_array < 0):
raise ValueError(f"Negative classical values encountered in {debug_str}")
if np.any(val_array >= self.iteration_length):
Expand Down Expand Up @@ -539,7 +547,9 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if val >= 2**self.bitsize:
raise ValueError(f"Too-large classical value encountered in {debug_str}")

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
def assert_valid_classical_val_array(
self, val_array: NDArray[np.integer], debug_str: str = 'val'
):
if np.any(val_array < 0):
raise ValueError(f"Negative classical values encountered in {debug_str}")
if np.any(val_array >= 2**self.bitsize):
Expand Down
5 changes: 4 additions & 1 deletion qualtran/_infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Dict, TYPE_CHECKING

import cirq
import numpy as np
Expand All @@ -33,6 +33,9 @@
from qualtran.bloqs.util_bloqs import Power
from qualtran.testing import execute_notebook

if TYPE_CHECKING:
from qualtran import BloqBuilder


class _TestGate(GateWithRegisters):
@property
Expand Down
5 changes: 3 additions & 2 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import enum
import itertools
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List, overload, Tuple
from typing import Dict, Iterable, Iterator, List, overload, Tuple, Union

import attrs
import numpy as np
import sympy
from attrs import field, frozen

from .data_types import QAny, QBit, QDType
Expand Down Expand Up @@ -128,7 +129,7 @@ def __init__(self, registers: Iterable[Register]):
self._rights = _dedupe((reg.name, reg) for reg in self._registers if reg.side & Side.RIGHT)

@classmethod
def build(cls, **registers: int) -> 'Signature':
def build(cls, **registers: Union[int, sympy.Expr]) -> 'Signature':
"""Construct a Signature comprised of simple thru registers given the register bitsizes.
Args:
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _right_building_block(self, inp, out, anc, depth):
yield from self._right_building_block(inp, out, anc, depth - 1)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> cirq.OP_TREE:
# reverse the order of qubits for big endian-ness.
input_bits = quregs['a'][::-1]
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from qualtran.drawing.musical_score import TextBox

if TYPE_CHECKING:
from qualtran import BloqBuilder
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

Expand Down Expand Up @@ -87,7 +88,7 @@ def __pow__(self, power: int):
return NotImplemented # pragma: no cover

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> cirq.OP_TREE:
"""Decomposes the gate into 4N And and And† operations for a T complexity of 4N.
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/hamming_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _decompose_using_three_to_two_adders(
x = [*y]

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> cirq.OP_TREE:
# Qubit order needs to be reversed because the registers store Big Endian representation
# of integers.
Expand Down
2 changes: 2 additions & 0 deletions qualtran/bloqs/arithmetic/multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from qualtran.resource_counting.symbolic_counting_utils import smax

if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import SoquetT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/basic_gates/cnot.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def add_controlled(
return super().get_ctrl_system(ctrl_spec=ctrl_spec)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', ctrl: 'CirqQuregT', target: 'CirqQuregT'
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]:
self, qubit_manager: 'cirq.QubitManager', ctrl: 'CirqQuregT', target: 'CirqQuregT' # type: ignore[type-var]
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: # type: ignore[type-var]
import cirq

(ctrl,) = ctrl
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/basic_gates/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def add_my_tensors(
)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', q: 'CirqQuregT'
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]:
self, qubit_manager: 'cirq.QubitManager', q: 'CirqQuregT' # type: ignore[type-var]
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: # type: ignore[type-var]
import cirq

(q,) = q
Expand Down
Loading

0 comments on commit 8011d2d

Please sign in to comment.