Skip to content

Commit

Permalink
Fix some types for mypy 1.13
Browse files Browse the repository at this point in the history
Signed-off-by: mimir-d <[email protected]>
  • Loading branch information
mimir-d committed Oct 25, 2024
1 parent 32a3c58 commit 9f9b9c3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/ocptv/output/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
JSON = ty.Union[ty.Dict[str, "JSON"], ty.List["JSON"], Primitive]


def _is_optional(field: ty.Type):
def _is_optional(field: ty.Type) -> bool:
# type hackery incoming
# ty.Optional[T] == ty.Union[T, None]
# since ty.Union[ty.Union[T,U]] = ty.Union[T,U] we can the
Expand All @@ -30,7 +30,7 @@ class ArtifactEmitter:
Uses the low level dataclass models for the spec, but should not be used in user code.
"""

def __init__(self, writer: Writer):
def __init__(self, writer: Writer) -> None:
self._seq_lock = threading.Lock()
self._seq = 0

Expand All @@ -41,7 +41,7 @@ def __init__(self, writer: Writer):
self._version_emitted = threading.Event()

@staticmethod
def _serialize(artifact: ArtifactType):
def _serialize(artifact: ArtifactType) -> str:
def visit(
value: ty.Union[ArtifactType, ty.Dict, ty.List, Primitive],
formatter: ty.Optional[ty.Callable[[ty.Any], str]] = None,
Expand All @@ -56,7 +56,7 @@ def visit(
val = getattr(value, field.name)

if val is None:
if not _is_optional(field.type):
if not _is_optional(ty.cast(ty.Type, field.type)):
# TODO: fix exception text/type
raise RuntimeError("unacceptable none where not optional")

Expand Down
6 changes: 3 additions & 3 deletions src/ocptv/output/runtime_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def __str__(self):


def _check_type_any(obj: CheckedValue, hint: ty.Type, trace: ty.List[str]):
type_origin = get_origin(hint)
type_args = get_args(hint)
type_origin = ty.cast(ty.Type | None, get_origin(hint))
type_args = ty.cast(ty.Tuple[ty.Type, ...], get_args(hint))

if type_origin is list:
# generic type: typ == ty.List[...]
Expand Down Expand Up @@ -186,7 +186,7 @@ def _check_type_any(obj: CheckedValue, hint: ty.Type, trace: ty.List[str]):
elif dc.is_dataclass(obj):
for field in dc.fields(obj):
subtrace = trace + [f"{obj.__class__.__name__}.{field.name}"]
_check_type_any(getattr(obj, field.name), field.type, subtrace)
_check_type_any(getattr(obj, field.name), ty.cast(ty.Type, field.type), subtrace)

elif not isinstance(obj, hint):
raise TypeCheckError(obj, expected=hint.__name__, trace=trace)
Expand Down

0 comments on commit 9f9b9c3

Please sign in to comment.