Skip to content

Commit

Permalink
Mypy type issues part 2 (quantumlib#891)
Browse files Browse the repository at this point in the history
* Mypy type issues part 2

- Continue fixing mypy type issues.

Some highlights include:
- Fix bloq_example decorator to pass through types
- Add custom converters for attrs tuple conversion to aid typing
of attrs tuples with automatic converters.

* Add more stuff

* Address comments

* Remove cirq.inverse

* Still don't quite have types right for bloq_example

* Fix typing in bloq_examples

* Fix tests.

* Got rid of most of the redefintions.

* Fix typop
  • Loading branch information
dstrain115 authored Apr 25, 2024
1 parent 9dd3558 commit 34340ed
Show file tree
Hide file tree
Showing 64 changed files with 327 additions and 172 deletions.
2 changes: 1 addition & 1 deletion dev_tools/bloq-method-overrides-report.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _call_graph(bc: Type[Bloq]):
)
if annot['ssa'] != 'SympySymbolAllocator':
print(f"{bc}.build_call_graph `ssa: 'SympySymbolAllocator'`")
if annot['return'] != Set[ForwardRef('BloqCountT')]:
if annot['return'] != Set[ForwardRef('BloqCountT')]: # type: ignore[misc]
print(f"{bc}.build_call_graph -> 'BloqCountT'")


Expand Down
1 change: 1 addition & 0 deletions dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
show_error_codes = true
plugins = duet.typing, numpy.typing.mypy_plugin
allow_redefinition = true
check_untyped_defs = true
# Disabling function override checking
# Qualtran has many places where kwargs are used
# with the intention to override in subclasses in ways mypy does not like
Expand Down
4 changes: 2 additions & 2 deletions dev_tools/qualtran_dev_tools/bloq_report_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Optional, Set, Type
from typing import Any, Dict, Iterable, List, Optional, Set, Type

import pandas as pd
import pandas.io.formats.style
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_bloq_report_card(
if bexamples is None:
bexamples = get_bloq_examples()

records = []
records: List[Dict[str, Any]] = []
missing_bclasses = bloq_classes_with_no_examples(bclasses, bexamples)
records.extend(record_for_class_with_no_examples(k) for k in missing_bclasses)
records.extend(record_for_bloq_example(be) for be in bexamples)
Expand Down
8 changes: 4 additions & 4 deletions dev_tools/qualtran_dev_tools/reference_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def __init__(self, page_info):
# you pass in, so we can't do this sorting where it would make the most sense in
# MyClassPageInfo.collect_docs()
methods = _filter_and_sort_members(
self.page_info.py_object, self.methods.info_dict.values()
self.page_info.py_object, self.methods.info_dict.values() # type: ignore[has-type]
)
self.methods = Methods(
info_dict={meth.short_name: meth for meth in methods},
constructor=self.methods.constructor,
constructor=self.methods.constructor, # type: ignore[has-type]
)


Expand All @@ -135,7 +135,7 @@ class MyModulePageInfo(ModulePageInfo):

def collect_docs(self):
ret = super().collect_docs() # pylint: disable=assignment-from-no-return
self._classes = _filter_and_sort_members(self.py_object, self._classes)
self._classes = _filter_and_sort_members(self.py_object, self._classes) # type: ignore[has-type]
return ret


Expand All @@ -149,7 +149,7 @@ def collect_docs(self):
# Note: currently the following sort is un-done by the class page builder.
# If the upstream page builder changes to respect the member order (like for the other
# page types), we should sort them here.
self._methods = _filter_and_sort_members(self.py_object, self._methods)
self._methods = _filter_and_sort_members(self.py_object, self._methods) # type: ignore[has-type]
return ret


Expand Down
11 changes: 7 additions & 4 deletions dev_tools/qualtran_dev_tools/shell_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,16 @@ def test_run_returns_string_output():

def test_run_with_command_logging():
catch_stderr = io.StringIO()
kw = {"stdout": subprocess.DEVNULL}
with contextlib.redirect_stderr(catch_stderr):
shell_tools.run(["echo", "-n", "a", "b"], **kw)
shell_tools.run(["echo", "-n", "a", "b"], stdout=subprocess.DEVNULL)
assert catch_stderr.getvalue() == "run: ('echo', '-n', 'a', 'b')\n"
catch_stderr = io.StringIO()
with contextlib.redirect_stderr(catch_stderr):
shell_tools.run(["echo", "-n", "a", "b"], abbreviate_non_option_arguments=True, **kw)
shell_tools.run(
["echo", "-n", "a", "b"],
abbreviate_non_option_arguments=True,
stdout=subprocess.DEVNULL,
)
assert catch_stderr.getvalue() == "run: ('echo', '-n', '[...]')\n"


Expand All @@ -64,5 +67,5 @@ def test_output_of():
assert shell_tools.output_of(["echo", "test"]) == "test"
# filtering of the None arguments was removed. check this now fails
with pytest.raises(TypeError):
_ = shell_tools.output_of(["echo", "test", None, "duck"])
_ = shell_tools.output_of(["echo", "test", None, "duck"]) # type: ignore[list-item]
assert shell_tools.output_of("pwd", cwd="/tmp") in ["/tmp", "/private/tmp"]
9 changes: 6 additions & 3 deletions qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sympy

import qualtran.testing as qlt_testing
from qualtran import Adjoint, Bloq, CompositeBloq, Side, Signature, Soquet
from qualtran import Adjoint, Bloq, BloqInstance, CompositeBloq, Side, Signature, Soquet
from qualtran._infra.adjoint import _adjoint_cbloq
from qualtran.bloqs.basic_gates import CNOT, CSwap, ZeroState
from qualtran.bloqs.for_testing.atom import TestAtom
Expand Down Expand Up @@ -162,8 +162,11 @@ def test_wire_symbol():
(reg,) = zero.signature
adj = Adjoint(zero) # specifically use the Adjoint wrapper for testing

ws = zero.wire_symbol(Soquet(None, reg))
adj_ws = adj.wire_symbol(Soquet(None, reg.adjoint()))
# TODO: Remove binst variable. These BloqInstances are for typing only
# and are not really used by the function.
# See https://github.com/quantumlib/Qualtran/issues/608
ws = zero.wire_symbol(Soquet(BloqInstance(CNOT(), 1), reg))
adj_ws = adj.wire_symbol(Soquet(BloqInstance(CNOT(), 2), reg.adjoint()))
assert isinstance(ws, LarrowTextBox)
assert isinstance(adj_ws, RarrowTextBox)

Expand Down
43 changes: 23 additions & 20 deletions qualtran/_infra/bloq_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Generic, Iterable, Optional, Sequence, Type, TypeVar, Union

from attrs import field, frozen

from qualtran.resource_counting import GeneralizerT

from .bloq import Bloq

_BloqType = TypeVar('_BloqType', bound=Bloq)
_GeneralizerType = Union[GeneralizerT, Sequence[GeneralizerT]]


@frozen
class BloqExample:
class BloqExample(Generic[_BloqType]):
"""An instantiation of a bloq and its metadata.
In particular, this class wraps a callable that returns a bloq instantiation with
Expand All @@ -36,18 +41,18 @@ class BloqExample:
generalizer: Passed to `get_bloq_counts_graph` calls for bloq-counts equivalence checking.
"""

_func: Callable[[], Bloq] = field(repr=False, hash=False)
_func: Callable[[], _BloqType] = field(repr=False, hash=False)
name: str
bloq_cls: Type[Bloq]
generalizer: Callable[[Bloq], Optional[Bloq]] = field(
generalizer: _GeneralizerType = field(
converter=lambda x: tuple(x) if isinstance(x, Sequence) else x, default=lambda x: x
)

def make(self) -> Bloq:
def make(self) -> _BloqType:
"""Make the bloq."""
return self._func()

def __call__(self) -> Bloq:
def __call__(self) -> _BloqType:
"""This class is callable: it will make the bloq.
This makes the `bloq_example` decorator make sense: we wrap a function, so this
Expand All @@ -56,12 +61,12 @@ def __call__(self) -> Bloq:
return self.make()


def _name_from_func_name(func: Callable[[], Bloq]) -> str:
def _name_from_func_name(func: Callable[[], _BloqType]) -> str:
"""Use the name of the function as the `BloqExample.name` when using the decorator."""
return func.__name__.lstrip('_')


def _bloq_cls_from_func_annotation(func: Callable[[], Bloq]) -> Type[Bloq]:
def _bloq_cls_from_func_annotation(func: Callable[[], _BloqType]) -> Type[_BloqType]:
"""Use the function return type annotation as the `BloqExample.bloq_cls` with the decorator."""
anno = func.__annotations__
if 'return' not in anno:
Expand All @@ -73,28 +78,28 @@ def _bloq_cls_from_func_annotation(func: Callable[[], Bloq]) -> Type[Bloq]:


@typing.overload
def bloq_example(_func: Callable[[], Bloq], **kwargs: Any) -> BloqExample:
def bloq_example(_func: Callable[[], _BloqType], **kwargs: Any) -> BloqExample[_BloqType]:
...


@typing.overload
def bloq_example(
_func: None, *, generalizer: Callable[[Bloq], Optional[Bloq]] = lambda x: x
) -> Callable[[Callable[[], Bloq]], BloqExample]:
_func: None, *, generalizer: _GeneralizerType = lambda x: x
) -> Callable[[Callable[[], _BloqType]], BloqExample[_BloqType]]:
...


def bloq_example(
_func: Callable[[], Bloq] = None, *, generalizer: Callable[[Bloq], Optional[Bloq]] = lambda x: x
):
_func: Optional[Callable[[], _BloqType]] = None, *, generalizer: _GeneralizerType = lambda x: x
) -> BloqExample[_BloqType]:
"""Decorator to turn a function into a `BloqExample`.
This will set `name` to the name of the function and `bloq_cls` according to the return-type
annotation. You can also call the decorator with keyword arguments, which will be passed
through to the `BloqExample` constructor.
"""

def _inner(func: Callable[[], Bloq]) -> BloqExample:
def _inner(func: Callable[[], _BloqType]) -> BloqExample:
return BloqExample(
func=func,
name=_name_from_func_name(func),
Expand All @@ -109,11 +114,9 @@ def _inner(func: Callable[[], Bloq]) -> BloqExample:
return _inner(_func)


def _to_tuple(T: Type):
def _t(x: Iterable[T]) -> Tuple[T, ...]:
return tuple(x)

return _t
def _to_tuple(x: Iterable[BloqExample]) -> Sequence[BloqExample]:
"""mypy compatible converter for BloqDocSpec.examples"""
return tuple(x)


@frozen(kw_only=True)
Expand Down Expand Up @@ -141,7 +144,7 @@ class BloqDocSpec:
"""

bloq_cls: Type[Bloq]
examples: Sequence[BloqExample] = field(converter=_to_tuple(BloqExample), factory=tuple)
examples: Sequence[BloqExample] = field(converter=_to_tuple, factory=tuple)
import_line: str = field()
call_graph_example: Union[BloqExample, None] = field()

Expand Down
25 changes: 19 additions & 6 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

Expand All @@ -49,11 +50,13 @@
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

# NDArrays must be bound to np.generic
_SoquetType = TypeVar('_SoquetType', bound=np.generic)

SoquetT = Union[Soquet, NDArray[Soquet]]
SoquetT = Union[Soquet, NDArray[_SoquetType]]
"""A `Soquet` or array of soquets."""

SoquetInT = Union[Soquet, NDArray[Soquet], Sequence[Soquet]]
SoquetInT = Union[Soquet, NDArray[_SoquetType], Sequence[Soquet]]
"""A soquet or array-like of soquets.
This type alias is used for input argument to parts of the library that are more
Expand All @@ -62,6 +65,16 @@
"""


def _to_tuple(x: Iterable[Connection]) -> Sequence[Connection]:
"""mypy-compatible attrs converter for CompositeBloq.connections"""
return tuple(x)


def _to_set(x: Iterable[BloqInstance]) -> FrozenSet[BloqInstance]:
"""mypy-compatible attrs converter for CompositeBloq.bloq_instances"""
return frozenset(x)


@attrs.frozen
class CompositeBloq(Bloq):
"""A bloq defined by a collection of sub-bloqs and dataflows between them
Expand All @@ -83,9 +96,9 @@ class CompositeBloq(Bloq):
should correspond to the dangling `Soquets` in the `cxns`.
"""

connections: Tuple[Connection, ...] = attrs.field(converter=tuple)
connections: Tuple[Connection, ...] = attrs.field(converter=_to_tuple)
signature: Signature
bloq_instances: FrozenSet[BloqInstance] = attrs.field(converter=frozenset)
bloq_instances: FrozenSet[BloqInstance] = attrs.field(converter=_to_set)

@bloq_instances.default
def _default_bloq_instances(self):
Expand Down Expand Up @@ -1081,7 +1094,7 @@ def free(self, soq: Soquet) -> None:

self.add(Free(dtype=soq.reg.dtype), reg=soq)

def split(self, soq: Soquet) -> NDArray[Soquet]:
def split(self, soq: Soquet) -> NDArray[Soquet]: # type: ignore[type-var]
"""Add a Split bloq to split up a register."""
from qualtran.bloqs.util_bloqs import Split

Expand All @@ -1090,7 +1103,7 @@ def split(self, soq: Soquet) -> NDArray[Soquet]:

return self.add(Split(dtype=soq.reg.dtype), reg=soq)

def join(self, soqs: NDArray[Soquet], dtype: Optional[QDType] = None) -> Soquet:
def join(self, soqs: NDArray[Soquet], dtype: Optional[QDType] = None) -> Soquet: # type: ignore[type-var]
from qualtran.bloqs.util_bloqs import Join

try:
Expand Down
Loading

0 comments on commit 34340ed

Please sign in to comment.