Skip to content

Commit

Permalink
Use async libraries for requests against Pseudo Service. Add exponential backoff retry (#412)
Browse files Browse the repository at this point in the history

* Use asyncio for requests. Add exponential backoff

* Add pytest-asyncio

* Delete old endpoint method

* Amend tests for new method

* make exponential retry explicit

* Make mypy happy
  • Loading branch information
mallport authored Nov 14, 2024
1 parent 4c1a4af commit 211a392
Show file tree
Hide file tree
Showing 8 changed files with 2,931 additions and 2,417 deletions.
5,001 changes: 2,730 additions & 2,271 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pyarrow = ">=14.0.2, <15"
orjson = "^3.10.1"
wcmatch = "^8.5.1"
msgspec = ">=0.18.6"
aiohttp = ">=3.10.5"
aiohttp-retry = ">=2.9.1"
pytest-asyncio = "^0.24.0"

[tool.poetry.group.test.dependencies]
typeguard = ">=2.13.3"
Expand Down
2 changes: 1 addition & 1 deletion src/dapla_pseudo/globberator/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def from_path(schema_path: str, rules_path: str) -> "SchemaTraverser":
def write_rules(file_path: str, rules: list[PseudoRule]) -> None:
    """Serialize pseudo rules as JSON and write them to a file.

    Args:
        file_path: Destination path. Paths starting with "gs://" are written
            through GCSFileSystem; anything else goes to the local filesystem.
        rules: The pseudo rules to serialize.
    """
    opener = GCSFileSystem.open if file_path.startswith("gs://") else open
    # msgspec.json.encode produces bytes, so the file must be opened in
    # binary write mode.
    with opener(file_path, mode="wb") as rules_file:
        rules_file.write(msgspec.json.encode(rules))

def match_rules(self, separator: str = "/") -> list[PseudoRule]:
Expand Down
62 changes: 13 additions & 49 deletions src/dapla_pseudo/v1/baseclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
and descriptive than the user-friendly methods that are exposed.
"""

import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from datetime import date
from typing import cast

import msgspec
import polars as pl
import requests

from dapla_pseudo.constants import Env
from dapla_pseudo.constants import MapFailureStrategy
Expand Down Expand Up @@ -127,56 +124,23 @@ def _pseudonymize_field(
) -> PseudoFieldResponse:
"""Pseudonymizes the specified fields in the DataFrame using the provided pseudonymization function.
The pseudonymization is performed in parallel. After the parallel processing is finished,
The pseudonymization is performed concurrently. After the processing is finished,
the pseudonymized fields replace the original fields in the DataFrame stored in `self._dataframe`.
"""

def pseudonymize_field_runner(
request: PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest,
) -> tuple[str, list[str], RawPseudoMetadata]:
"""Function that performs the pseudonymization on a Polars Series."""
if (
type(request) == PseudoFieldRequest
and request.pseudo_func.function_type == PseudoFunctionTypes.REDACT
):
## If we redact, we do this inside this library
## to avoid making API calls towards Pseudo Service
name, data, metadata = _BasePseudonymizer._redact_field(request)
return name, data, metadata

else:
response: requests.Response = (
self._pseudo_client._post_to_field_endpoint(
f"{self._pseudo_operation.value}/field",
request,
timeout,
stream=True,
)
)
payload = msgspec.json.decode(response.content.decode("utf-8"))
data = payload["data"]
metadata = RawPseudoMetadata(
field_name=request.name,
logs=payload["logs"],
metrics=payload["metrics"],
datadoc=payload["datadoc_metadata"]["pseudo_variables"],
)

return request.name, data, metadata

# type narrowing isn't carried over from caller function
assert isinstance(self._dataset, MutableDataFrame)
# Execute the pseudonymization API calls in parallel
with ThreadPoolExecutor() as executor:
raw_metadata_fields: list[RawPseudoMetadata] = []
futures = [
executor.submit(pseudonymize_field_runner, request)
for request in pseudo_requests
]
for future in as_completed(futures):
field_name, data, raw_metadata = future.result()
self._dataset.update(field_name, data)
raw_metadata_fields.append(raw_metadata)

raw_metadata_fields: list[RawPseudoMetadata] = []
for field_name, data, raw_metadata in asyncio.run(
self._pseudo_client.post_to_field_endpoint(
path=f"{self._pseudo_operation.value}/field",
timeout=timeout,
pseudo_requests=pseudo_requests,
)
):
self._dataset.update(field_name, data)
raw_metadata_fields.append(raw_metadata)

return PseudoFieldResponse(
data=self._dataset.to_polars(), raw_metadata=raw_metadata_fields
Expand Down
111 changes: 84 additions & 27 deletions src/dapla_pseudo/v1/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Module that implements a client abstraction that makes it easy to communicate with the Dapla Pseudo Service REST API."""

import asyncio
import os
import typing as t
from datetime import date

import google.auth.transport.requests
import google.oauth2.id_token
import requests
from aiohttp import ClientResponse
from aiohttp import ClientSession
from aiohttp import TCPConnector
from aiohttp_retry import ExponentialRetry
from aiohttp_retry import RetryClient
from dapla import AuthClient
from ulid import ULID

Expand All @@ -15,6 +21,7 @@
from dapla_pseudo.types import FileSpecDecl
from dapla_pseudo.v1.models.api import DepseudoFieldRequest
from dapla_pseudo.v1.models.api import PseudoFieldRequest
from dapla_pseudo.v1.models.api import RawPseudoMetadata
from dapla_pseudo.v1.models.api import RepseudoFieldRequest
from dapla_pseudo.v1.models.core import Mimetypes

Expand Down Expand Up @@ -51,13 +58,87 @@ def __auth_token(self) -> str:
else str(self.static_auth_token)
)

async def post_to_field_endpoint(
    self,
    path: str,
    timeout: int,
    pseudo_requests: list[
        PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest
    ],
) -> list[tuple[str, list[str], RawPseudoMetadata]]:
    """Post requests concurrently to the Pseudo Service field endpoint.

    One HTTP POST is issued per entry in ``pseudo_requests``. All requests run
    concurrently via ``asyncio.gather`` and each individual request is retried
    with exponential backoff on failure.

    Args:
        path (str): Endpoint path, appended to ``self.pseudo_service_url``
            (e.g. "pseudonymize/field") — not a full URL.
        timeout (int): Per-request timeout passed to aiohttp.
        pseudo_requests: Pseudo requests, one per field.

    Returns:
        list[tuple[str, list[str], RawPseudoMetadata]]: One tuple of
        (field_name, data, metadata) per request, in ``pseudo_requests`` order
        (``asyncio.gather`` preserves input order).
    """

    async def _post(
        client: RetryClient,
        path: str,
        timeout: int,
        request: PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest,
    ) -> tuple[str, list[str], RawPseudoMetadata]:
        # Single field request: POST the serialized pseudo request and unpack
        # the JSON payload into (field_name, data, metadata).
        async with client.post(
            url=f"{self.pseudo_service_url}/{path}",
            headers={
                "Authorization": f"Bearer {self.__auth_token()}",
                "Content-Type": Mimetypes.JSON.value,
                # Fresh correlation id per request, for server-side tracing.
                "X-Correlation-Id": PseudoClient._generate_new_correlation_id(),
            },
            json={"request": request.model_dump(by_alias=True)},
            timeout=timeout,
        ) as response:
            # Raises on non-2xx after printing diagnostics.
            await PseudoClient._handle_response_error(response)
            response_json = await response.json()
            data = response_json["data"]
            metadata = RawPseudoMetadata(
                field_name=request.name,
                logs=response_json["logs"],
                metrics=response_json["metrics"],
                datadoc=response_json["datadoc_metadata"]["pseudo_variables"],
            )

            return request.name, data, metadata

    # Cap concurrent TCP connections at 200 so a large number of field
    # requests does not exhaust sockets or overwhelm the service.
    aio_session = ClientSession(connector=TCPConnector(limit=200))
    # Retry each request up to 5 times: 0.1s initial delay, doubling each
    # attempt, capped at 30s.
    # NOTE(review): assumes RetryClient closes the wrapped ClientSession on
    # __aexit__ — verify against aiohttp_retry docs, otherwise the session
    # leaks.
    async with RetryClient(
        client_session=aio_session,
        retry_options=ExponentialRetry(
            attempts=5, start_timeout=0.1, max_timeout=30, factor=2
        ),
    ) as client:
        results = await asyncio.gather(
            *[
                _post(client=client, path=path, timeout=timeout, request=req)
                for req in pseudo_requests
            ]
        )

    return results

@staticmethod
def _generate_new_correlation_id() -> str:
    """Produce a unique correlation ID (a ULID) used to tag one API request."""
    correlation_id = ULID()
    return str(correlation_id)

@staticmethod
def _handle_response_error(response: requests.Response) -> None:
async def _handle_response_error(response: ClientResponse) -> None:
"""Report error messages in response object."""
match response.status:
case status if status in range(200, 300):
pass
case _:
print(response.headers)
print(await response.text())
response.raise_for_status()

@staticmethod
def _handle_response_error_sync(response: requests.Response) -> None:
"""Report error messages in response object. For synchronous callers."""
match response.status_code:
case status if status in range(200, 300):
pass
Expand Down Expand Up @@ -92,31 +173,7 @@ def _post_to_file_endpoint(
timeout=timeout,
)

PseudoClient._handle_response_error(response)
return response

def _post_to_field_endpoint(
self,
path: str,
pseudo_field_request: (
PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest
),
timeout: int,
stream: bool = False,
) -> requests.Response:
response = requests.post(
url=f"{self.pseudo_service_url}/{path}",
headers={
"Authorization": f"Bearer {self.__auth_token()}",
"Content-Type": Mimetypes.JSON.value,
"X-Correlation-Id": PseudoClient._generate_new_correlation_id(),
},
json={"request": pseudo_field_request.model_dump(by_alias=True)},
stream=stream,
timeout=timeout,
)

PseudoClient._handle_response_error(response)
PseudoClient._handle_response_error_sync(response)
return response

def _post_to_sid_endpoint(
Expand All @@ -140,7 +197,7 @@ def _post_to_sid_endpoint(
timeout=TIMEOUT_DEFAULT, # seconds
)

PseudoClient._handle_response_error(response)
PseudoClient._handle_response_error_sync(response)
return response


Expand Down
5 changes: 0 additions & 5 deletions src/dapla_pseudo/v1/result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Common API models for builder packages."""

from collections import Counter
from io import BufferedWriter
from pathlib import Path
from typing import Any
from typing import cast
Expand Down Expand Up @@ -161,10 +160,6 @@ def to_file(self, file_path: str, **kwargs: Any) -> None:
datadoc_file_path = Path(file_path).parent.joinpath(Path(datadoc_file_name))
datadoc_file_handle = datadoc_file_path.open(mode="w")

file_handle = cast(
BufferedWriter, file_handle
) # file handle is always BufferedWriter when opening with "wb"

match self._pseudo_data:
case pl.DataFrame() as df:
write_from_df(df, file_format, file_handle, **kwargs)
Expand Down
61 changes: 32 additions & 29 deletions tests/v1/unit/test_base_pseudonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dapla_pseudo.v1.models.api import PseudoFieldResponse
from dapla_pseudo.v1.models.api import PseudoFileRequest
from dapla_pseudo.v1.models.api import PseudoFileResponse
from dapla_pseudo.v1.models.api import RawPseudoMetadata
from dapla_pseudo.v1.models.core import File
from dapla_pseudo.v1.models.core import Mimetypes
from dapla_pseudo.v1.models.core import PseudoConfig
Expand Down Expand Up @@ -87,35 +88,36 @@ def test_pseudonymize_field(
values=list(df_personer["fnr"]),
)

expected_json = {
"data": ["jJuuj0i", "ylc9488", "yeLfkaL"],
"logs": [],
"metrics": [{"MAPPED_SID": 3}],
"datadoc_metadata": {
"pseudo_variables": [
{
"short_name": "fnr",
"data_element_path": "fnr",
"data_element_pattern": "fnr*",
"stable_identifier_type": "FREG_SNR",
"stable_identifier_version": "2023-08-31",
"encryption_algorithm": "TINK-FPE",
"encryption_key_reference": "papis-common-key-1",
"encryption_algorithm_parameters": [
{"keyId": "papis-common-key-1"},
{"strategy": "skip"},
],
mocked_data = ["jJuuj0i", "ylc9488", "yeLfkaL"]
mocked_metadata = RawPseudoMetadata(
field_name="fnr",
logs=[],
metrics=[{"MAPPED_SID": 3}],
datadoc=[
{
"datadoc_metadata": {
"pseudo_variables": {
"short_name": "fnr",
"data_element_path": "fnr",
"data_element_pattern": "fnr*",
"stable_identifier_type": "FREG_SNR",
"stable_identifier_version": "2023-08-31",
"encryption_algorithm": "TINK-FPE",
"encryption_key_reference": "papis-common-key-1",
"encryption_algorithm_parameters": [
{"keyId": "papis-common-key-1"},
{"strategy": "skip"},
],
}
}
]
},
}

mocked_response = Mock(content=bytes(json.dumps(expected_json), encoding="utf-8"))
}
],
)

mocked_post_to_field = mocker.patch(
"dapla_pseudo.v1.client.PseudoClient._post_to_field_endpoint",
mocked_asyncio_run = mocker.patch(
"dapla_pseudo.v1.baseclasses.asyncio.run",
)
mocked_post_to_field.return_value = mocked_response
mocked_asyncio_run.return_value = [("fnr", mocked_data, mocked_metadata)]
base = _BasePseudonymizer(
pseudo_operation=PseudoOperation.PSEUDONYMIZE,
dataset=df_personer,
Expand All @@ -125,9 +127,10 @@ def test_pseudonymize_field(
response = base._pseudonymize_field([sid_req], timeout=ANY)
metadata = response.raw_metadata[0]
assert isinstance(response, PseudoFieldResponse)
assert metadata.datadoc == expected_json["datadoc_metadata"]["pseudo_variables"] # type: ignore[index]
assert metadata.logs == expected_json["logs"]
assert metadata.metrics == expected_json["metrics"]

assert metadata.datadoc == mocked_metadata.datadoc
assert metadata.logs == mocked_metadata.logs
assert metadata.metrics == mocked_metadata.metrics


def test_pseudonymize_dataset(
Expand Down
Loading

0 comments on commit 211a392

Please sign in to comment.