Skip to content

Commit

Permalink
Merge pull request #31 from taskiq-python/feature/skip-undefined
Browse files Browse the repository at this point in the history
Skip functions with undefined types.
  • Loading branch information
s3rius authored Dec 9, 2024
2 parents 6caa86e + 0b7518b commit 3622b62
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 6 deletions.
12 changes: 12 additions & 0 deletions taskiq_dependencies/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,15 @@ def __eq__(self, rhs: object) -> bool:
if not isinstance(rhs, Dependency):
return False
return self._id == rhs._id

def __repr__(self) -> str:
func_name = str(self.dependency)
if self.dependency is not None and hasattr(self.dependency, "__name__"):
func_name = self.dependency.__name__
return (
f"Dependency({func_name}, "
f"use_cache={self.use_cache}, "
f"kwargs={self.kwargs}, "
f"parent={self.parent}"
")"
)
57 changes: 52 additions & 5 deletions taskiq_dependencies/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
import sys
import warnings
from collections import defaultdict, deque
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints

from graphlib import TopologicalSorter
Expand Down Expand Up @@ -171,19 +173,64 @@ def _build_graph(self) -> None: # noqa: C901
if inspect.isclass(origin):
# If this is a class, we need to get signature of
# an __init__ method.
hints = get_type_hints(origin.__init__)
try:
hints = get_type_hints(origin.__init__)
except NameError:
_, src_lineno = inspect.getsourcelines(origin)
src_file = Path(inspect.getfile(origin)).relative_to(
Path.cwd(),
)
warnings.warn(
"Cannot resolve type hints for "
f"a class {origin.__name__} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(
origin.__init__,
**signature_kwargs,
)
elif inspect.isfunction(dep.dependency):
# If this is function or an instance of a class, we get it's type hints.
hints = get_type_hints(dep.dependency)
try:
hints = get_type_hints(dep.dependency)
except NameError:
_, src_lineno = inspect.getsourcelines(dep.dependency) # type: ignore
src_file = Path(inspect.getfile(dep.dependency)).relative_to(
Path.cwd(),
)
warnings.warn(
"Cannot resolve type hints for "
f"a function {dep.dependency.__name__} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(origin, **signature_kwargs) # type: ignore
else:
hints = get_type_hints(
dep.dependency.__call__, # type: ignore
)
try:
hints = get_type_hints(
dep.dependency.__call__, # type: ignore
)
except NameError:
_, src_lineno = inspect.getsourcelines(dep.dependency.__class__)
src_file = Path(
inspect.getfile(dep.dependency.__class__),
).relative_to(
Path.cwd(),
)
cls_name = dep.dependency.__class__.__name__
warnings.warn(
"Cannot resolve type hints for "
f"an object of class {cls_name} defined "
f"at {src_file}:{src_lineno}.",
RuntimeWarning,
stacklevel=2,
)
continue
sign = inspect.signature(origin, **signature_kwargs) # type: ignore

# Now we need to iterate over parameters, to
Expand Down
63 changes: 62 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import re
import uuid
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator, Generic, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Tuple,
TypeVar,
)

import pytest

Expand Down Expand Up @@ -891,3 +899,56 @@ def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
assert info.name == ""
assert info.definition is None
assert info.graph == graph


def test_skip_type_checking_function() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

def target(unknown: "A") -> None:
pass

with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*function target.*"):
graph = DependencyGraph(target=target)
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()


def test_skip_type_checking_class() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

class Target:
def __init__(self, unknown: "A") -> None:
pass

with pytest.warns(RuntimeWarning, match=r"Cannot resolve.*class Target.*"):
graph = DependencyGraph(target=Target)
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()


def test_skip_type_checking_object() -> None:
"""Test if we can skip type only for type checking for the function."""
if TYPE_CHECKING:

class A:
pass

class Target:
def __call__(self, unknown: "A") -> None:
pass

with pytest.warns(
RuntimeWarning,
match=r"Cannot resolve.*object of class Target.*",
):
graph = DependencyGraph(target=Target())
with graph.sync_ctx() as ctx:
assert "unknown" not in ctx.resolve_kwargs()

0 comments on commit 3622b62

Please sign in to comment.