Skip to content

Commit

Permalink
add type annotations for function arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
helylle committed Sep 27, 2024
1 parent 43e48d4 commit 941532f
Show file tree
Hide file tree
Showing 116 changed files with 498 additions and 383 deletions.
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ line-length = 120
target-version = "py310"

[lint]
select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA"]
select = ["E", "F", "W", "I", "ASYNC", "UP", "FLY", "PERF", "FURB", "ERA", "ANN001"]

ignore = ["E501"]
2 changes: 1 addition & 1 deletion src/eduid/common/clients/amapi_client/amapi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class AMAPIClient(GNAPClient):
def __init__(self, amapi_url: str, auth_data=GNAPClientAuthData, verify_tls: bool = True, **kwargs):
def __init__(self, amapi_url: str, auth_data: GNAPClientAuthData, verify_tls: bool = True, **kwargs):
super().__init__(auth_data=auth_data, verify=verify_tls, **kwargs)
self.amapi_url = amapi_url

Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/config/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class BadConfiguration(Exception):
def __init__(self, message):
def __init__(self, message: str):
Exception.__init__(self)
self.value = message

Expand Down
6 changes: 3 additions & 3 deletions src/eduid/common/config/parsers/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from functools import wraps
from string import Template
from typing import Any
Expand All @@ -11,7 +11,7 @@
from eduid.common.config.parsers.exceptions import SecretKeyException


def decrypt(f):
def decrypt(f: Callable):
@wraps(f)
def decrypt_decorator(*args, **kwargs):
config_dict = f(*args, **kwargs)
Expand Down Expand Up @@ -83,7 +83,7 @@ def decrypt_config(config_dict: Mapping[str, Any]) -> Mapping[str, Any]:
return new_config_dict


def interpolate(f):
def interpolate(f: Callable):
@wraps(f)
def interpolation_decorator(*args, **kwargs):
config_dict = f(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/config/parsers/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class ParserException(Exception):
def __init__(self, message):
def __init__(self, message: str):
Exception.__init__(self)
self.value = message

Expand Down
5 changes: 3 additions & 2 deletions src/eduid/common/decorators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import inspect
import warnings
from collections.abc import Callable
from functools import wraps


# https://stackoverflow.com/questions/2536307/how-do-i-deprecate-python-functions/40301488#40301488
def deprecated(reason):
def deprecated(reason: str | type | Callable):
"""
This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
Expand All @@ -20,7 +21,7 @@ def deprecated(reason):
# def old_function(x, y):
# pass

def decorator(func1):
def decorator(func1: Callable):
if inspect.isclass(func1):
fmt1 = "Call to deprecated class {name} ({reason})."
else:
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/fastapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def reset_failure_info(req: ContextRequest, key: str) -> None:
req.app.context.logger.info(f"Check {key} back to normal. Resetting info {info}")


def check_restart(key, restart: int, terminate: int) -> bool:
def check_restart(key: str, restart: int, terminate: int) -> bool:
res = False # default to no restart
info = FAILURE_INFO.get(key)
if not info:
Expand Down
8 changes: 4 additions & 4 deletions src/eduid/common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@

# Default to RFC3339/ISO 8601 with tz
class EduidFormatter(logging.Formatter):
def __init__(self, relative_time: bool = False, fmt=None):
def __init__(self, relative_time: bool = False, fmt: str | None = None):
super().__init__(fmt=fmt, style="{")
self._relative_time = relative_time

def formatTime(self, record: logging.LogRecord, datefmt=None) -> str:
def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
if self._relative_time:
# Relative time makes much more sense than absolute time when running tests for example
_seconds = record.relativeCreated / 1000
Expand All @@ -52,7 +52,7 @@ def formatTime(self, record: logging.LogRecord, datefmt=None) -> str:
class AppFilter(logging.Filter):
"""Add `system_hostname`, `hostname` and `app_name` to records being logged."""

def __init__(self, app_name):
def __init__(self, app_name: str):
super().__init__()
self.app_name = app_name
# TODO: I guess it could be argued that these should be put in the LocalContext and not evaluated at runtime.
Expand Down Expand Up @@ -139,7 +139,7 @@ def filter(self, record: logging.LogRecord) -> bool:
def merge_config(base_config: dict[str, Any], new_config: dict[str, Any]) -> dict[str, Any]:
"""Recursively merge two dictConfig dicts."""

def merge(node, key, value):
def merge(node: dict[str, Any], key: str, value: Any):
if isinstance(value, dict):
for item in value:
if key in node:
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/common/models/scim_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def is_group(self):
return self.ref and "/Groups/" in self.ref

@classmethod
def from_mapping(cls, data):
def from_mapping(cls, data: Any):
return cls.model_validate(data)


Expand Down
11 changes: 6 additions & 5 deletions src/eduid/common/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
__author__ = "ft"

from abc import ABC, abstractmethod
from logging import Logger

from eduid.common.config.base import StatsConfigMixin

Expand All @@ -23,7 +24,7 @@ class AppStats(ABC):
def count(self, name: str, value: int = 1) -> None:
pass

def gauge(self, name: str, value: int, rate=1, delta=False):
def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False):
pass


Expand All @@ -35,7 +36,7 @@ class NoOpStats(AppStats):
configured allows us to not check if current_app.stats is set everywhere.
"""

def __init__(self, logger=None, prefix=None):
def __init__(self, logger: Logger | None = None, prefix: str | None = None):
self.logger = logger
self.prefix = prefix

Expand All @@ -45,15 +46,15 @@ def count(self, name: str, value: int = 1) -> None:
name = f"{self.prefix!s}.{name!s}"
self.logger.info(f"No-op stats count: {name!r} {value!r}")

def gauge(self, name: str, value: int, rate=1, delta=False):
def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False):
if self.logger:
if self.prefix:
name = f"{self.prefix!s}.{name!s}"
self.logger.info(f"No-op stats gauge: {name} {value}")


class Statsd(AppStats):
def __init__(self, host, port, prefix=None):
def __init__(self, host: str, port: int, prefix: str | None = None):
import statsd

self.client = statsd.StatsClient(host, port, prefix=prefix)
Expand All @@ -64,7 +65,7 @@ def count(self, name: str, value: int = 1) -> None:
# for .count
self.client.incr(f"{name}.count", count=value)

def gauge(self, name: str, value: int, rate=1, delta=False):
def gauge(self, name: str, value: int, rate: int = 1, delta: bool = False):
self.client.gauge(f"{name}.gauge", value=value, rate=rate, delta=delta)


Expand Down
26 changes: 19 additions & 7 deletions src/eduid/graphdb/groupdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,28 @@ def _add_or_update_users_and_groups(

for user_member in group.member_users:
res = self._add_user_to_group(tx, group=group, member=user_member, role=Role.MEMBER)
members.add(User.from_mapping(res.data()))
if res:
members.add(User.from_mapping(res.data()))
else:
logger.info(f"User {user_member.identifier} not added to group {group.identifier}.")
for group_member in group.member_groups:
res = self._add_group_to_group(tx, group=group, member=group_member, role=Role.MEMBER)
members.add(self._load_group(res.data()))
if res:
members.add(self._load_group(res.data()))
else:
logger.info(f"Group {group_member.identifier} not added to group {group.identifier}.")
for user_owner in group.owner_users:
res = self._add_user_to_group(tx, group=group, member=user_owner, role=Role.OWNER)
owners.add(User.from_mapping(res.data()))
if res:
owners.add(User.from_mapping(res.data()))
else:
logger.info(f"User {user_owner.identifier} not added to group {group.identifier}.")
for group_owner in group.owner_groups:
res = self._add_group_to_group(tx, group=group, member=group_owner, role=Role.OWNER)
owners.add(self._load_group(res.data()))
if res:
owners.add(self._load_group(res.data()))
else:
logger.info(f"Group {group_owner.identifier} not added to group {group.identifier}.")
return members, owners

def _remove_missing_users_and_groups(self, tx: Transaction, group: Group, role: Role) -> None:
Expand Down Expand Up @@ -161,7 +173,7 @@ def _remove_user_from_group(self, tx: Transaction, group: Group, user_identifier
"""
tx.run(q, scope=self.scope, identifier=group.identifier, user_identifier=user_identifier)

def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Record:
def _add_group_to_group(self, tx: Transaction, group: Group, member: Group, role: Role) -> Record | None:
q = f"""
MATCH (g:Group {{scope: $scope, identifier: $group_identifier}})
MERGE (m:Group {{scope: $scope, identifier: $identifier}})
Expand All @@ -187,7 +199,7 @@ def _add_group_to_group(self, tx, group: Group, member: Group, role: Role) -> Re
display_name=member.display_name,
).single()

def _add_user_to_group(self, tx, group: Group, member: User, role: Role) -> Record:
def _add_user_to_group(self, tx: Transaction, group: Group, member: User, role: Role) -> Record | None:
q = f"""
MATCH (g:Group {{scope: $scope, identifier: $group_identifier}})
MERGE (m:User {{identifier: $identifier}})
Expand Down Expand Up @@ -276,7 +288,7 @@ def remove_group(self, identifier: str) -> None:
with self.db.driver.session(default_access_mode=WRITE_ACCESS) as session:
session.run(q, scope=self.scope, identifier=identifier)

def get_groups_by_property(self, key: str, value: str, skip=0, limit=100):
def get_groups_by_property(self, key: str, value: str, skip: int = 0, limit: int = 100):
res: list[Group] = []
q = f"""
MATCH (g: Group {{scope: $scope}})
Expand Down
4 changes: 3 additions & 1 deletion src/eduid/graphdb/tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from neo4j import basic_auth

from eduid.graphdb.db import BaseGraphDB
Expand All @@ -17,7 +19,7 @@ def test_create_db(self):

class TestBaseGraphDB(Neo4jTestCase):
class TestDB(BaseGraphDB):
def __init__(self, db_uri, config=None):
def __init__(self, db_uri: str, config: dict[str, Any] | None = None):
super().__init__(db_uri, config=config)

def db_setup(self):
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/graphdb/tests/test_groupdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUp(self) -> None:
self.user2: dict[str, str] = {"identifier": "user2", "display_name": "Namn Namnsson"}

@staticmethod
def _assert_group(expected: Group, testing: Group, modified=False):
def _assert_group(expected: Group, testing: Group, modified: bool = False):
assert expected.identifier == testing.identifier
assert expected.display_name == testing.display_name
assert testing.created_ts is not None
Expand Down
3 changes: 2 additions & 1 deletion src/eduid/maccapi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jwcrypto.common import JWException
from pydantic import ValidationError
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp

from eduid.common.fastapi.context_request import ContextRequestMixin
from eduid.common.models.bearer_token import AuthnBearerToken, RequestedAccessDenied
Expand All @@ -22,7 +23,7 @@ def return_error_response(status_code: int, detail: str) -> JSONResponse:


class AuthenticationMiddleware(BaseHTTPMiddleware, ContextRequestMixin):
def __init__(self, app, context: Context):
def __init__(self, app: ASGIApp, context: Context):
super().__init__(app)
self.context = context

Expand Down
2 changes: 1 addition & 1 deletion src/eduid/queue/db/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def parse_queue_item(self, doc: Mapping, parse_payload: bool = True):
return item
return replace(item, payload=self._load_payload(item))

async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab=False) -> QueueItem | None:
async def grab_item(self, item_id: str | ObjectId, worker_name: str, regrab: bool = False) -> QueueItem | None:
"""
:param item_id: document id
:param worker_name: current workers name
Expand Down
21 changes: 13 additions & 8 deletions src/eduid/queue/decorators.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from collections.abc import Callable
from inspect import isclass
from typing import Any

from pymongo.synchronous.collection import Collection

from eduid.userdb.db import MongoDB

# TODO: Refactor but keep transaction audit document structure
from eduid.userdb.db.base import TUserDbDocument
from eduid.userdb.util import utc_now


class TransactionAudit:
enabled = False

def __init__(self, db_uri, db_name="eduid_queue", collection_name="transaction_audit"):
self._conn = None
self.db_uri = db_uri
self.db_name = db_name
self.collection_name = collection_name
self.collection = None
def __init__(self, db_uri: str, db_name: str = "eduid_queue", collection_name: str = "transaction_audit"):
self._conn: MongoDB | None = None
self.db_uri: str = db_uri
self.db_name: str = db_name
self.collection_name: str = collection_name
self.collection: Collection[TUserDbDocument] | None = None

def __call__(self, f):
def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]:
if not self.enabled:
return f

Expand Down Expand Up @@ -47,7 +52,7 @@ def disable(cls):
cls.enabled = False

@staticmethod
def _filter(func, data, *args, **kwargs):
def _filter(func: str, data: Any, *args, **kwargs):
if data is False:
return data
if func == "_get_navet_data":
Expand Down
6 changes: 3 additions & 3 deletions src/eduid/queue/tests/test_mail_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from datetime import timedelta
from os import environ
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from aiosmtplib import SMTPResponse

Expand Down Expand Up @@ -82,7 +82,7 @@ async def test_eduid_signup_mail_from_stream(self):
await self._assert_item_gets_processed(queue_item)

@patch("aiosmtplib.SMTP.sendmail")
async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail):
async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_sendmail: MagicMock):
"""
Test that saved queue items are handled by the handle_new_item method
"""
Expand All @@ -99,7 +99,7 @@ async def test_eduid_signup_mail_from_stream_unrecoverable_error(self, mock_send
await self._assert_item_gets_processed(queue_item)

@patch("aiosmtplib.SMTP.sendmail")
async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail):
async def test_eduid_signup_mail_from_stream_error_retry(self, mock_sendmail: MagicMock):
"""
Test that saved queue items are handled by the handle_new_item method
"""
Expand Down
2 changes: 1 addition & 1 deletion src/eduid/queue/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)


def cancel_task(signame, task):
def cancel_task(signame: str, task: Task):
logger.info(f"got signal {signame}: exit")
task.cancel()

Expand Down
Loading

0 comments on commit 941532f

Please sign in to comment.