Skip to content

Commit

Permalink
Make it possible to serialize models
Browse files Browse the repository at this point in the history
  • Loading branch information
jennydaman committed Jul 25, 2024
1 parent a7d3626 commit a09bc24
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 33 deletions.
27 changes: 24 additions & 3 deletions src/aiochris/link/linked.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,35 @@ def mark_to_check(mcs, method: Callable, link_name: str) -> None:
setattr(method, mcs.__DECORATED_METHOD_MARK, link_name)


@serde.deserialize
_S = TypeVar("_S")


def _deserialize_noop_hack(s: _S) -> _S:
setattr(s, "_AIOCHRIS_LINKED_DESERIALIZED", True)
return s


def _skip_hacked(s: any) -> bool:
if getattr(s, "_AIOCHRIS_LINKED_DESERIALIZED"):
return False
return True


@serde.serde
@dataclasses.dataclass(frozen=True)
class Linked(abc.ABC, metaclass=LinkedMeta):
"""
A `Linked` is an object which can make HTTP requests to links from an API.
"""

s: aiohttp.ClientSession = serde.field(deserializer=lambda s: s)
# The functions `_deserialize_noop_hack` and `_skip_hacked` are used with `serde.field`
# so that `s` is included in deserialization, but excluded when serialized.
s: aiohttp.ClientSession = serde.field(
deserializer=_deserialize_noop_hack,
serializer=lambda s: s,
skip_if=_skip_hacked,
)

max_search_requests: int
"""
Maximum number of requests to make for pagination.
Expand All @@ -103,7 +124,7 @@ def _has_link(cls, name: str) -> bool: ...
def _get_link(self, name: str) -> yarl.URL: ...


@serde.deserialize
@serde.serde
@dataclasses.dataclass(frozen=True)
class LinkedModel(Linked, abc.ABC):
"""
Expand Down
14 changes: 7 additions & 7 deletions src/aiochris/models/collection_links.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import dataclasses
import functools
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Optional
from collections.abc import Iterator

from serde import deserialize
from serde import serde

from aiochris.types import ApiUrl, UserUrl, AdminUrl


@deserialize
@serde
@dataclass(frozen=True)
class AbstractCollectionLinks:
@classmethod
Expand All @@ -33,7 +33,7 @@ def _dict(self) -> dict[str, str]:
return dataclasses.asdict(self)


@deserialize
@serde
@dataclass(frozen=True)
class AnonymousCollectionLinks(AbstractCollectionLinks):
chrisinstance: ApiUrl
Expand All @@ -52,7 +52,7 @@ class AnonymousCollectionLinks(AbstractCollectionLinks):
pipeline_instances: Optional[ApiUrl] # removed in CUBE version 6


@deserialize
@serde
@dataclass(frozen=True)
class CollectionLinks(AnonymousCollectionLinks):
user: UserUrl
Expand All @@ -70,13 +70,13 @@ def useruploadedfiles(self) -> ApiUrl:
return self.userfiles


@deserialize
@serde
@dataclass(frozen=True)
class AdminCollectionLinks(CollectionLinks):
admin: AdminUrl


@deserialize
@serde
@dataclass(frozen=True)
class AdminApiCollectionLinks(AbstractCollectionLinks):
compute_resources: ApiUrl
10 changes: 5 additions & 5 deletions src/aiochris/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import datetime
from typing import Optional

from serde import deserialize
from serde import serde

from aiochris.link.linked import LinkedModel
from aiochris.enums import PluginType, Status
from aiochris.types import *


@deserialize
@serde
@dataclass(frozen=True)
class UserData:
"""A *CUBE* user."""
Expand All @@ -28,7 +28,7 @@ class UserData:


# TODO It'd be better to use inheritance instead of optionals
@deserialize
@serde
@dataclass(frozen=True)
class PluginInstanceData(LinkedModel):
"""
Expand Down Expand Up @@ -90,7 +90,7 @@ class PluginInstanceData(LinkedModel):
"""


@deserialize
@serde
@dataclass(frozen=True)
class FeedData(LinkedModel):
url: FeedUrl
Expand All @@ -116,7 +116,7 @@ class FeedData(LinkedModel):
plugin_instances: PluginInstancesUrl


@deserialize
@serde
@dataclass(frozen=True)
class FeedNoteData(LinkedModel):
url: FeedUrl
Expand Down
20 changes: 10 additions & 10 deletions src/aiochris/models/logged_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@

import asyncio
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional
from collections.abc import Sequence

from serde import deserialize
from serde import serde

from aiochris.enums import PluginType, Status
from aiochris.link import http
from aiochris.link.linked import LinkedModel
from aiochris.models.data import PluginInstanceData, FeedData, UserData, FeedNoteData
from aiochris.enums import PluginType, Status
from aiochris.models.public import PublicPlugin
from aiochris.types import *


@deserialize
@serde
@dataclass(frozen=True)
class User(UserData, LinkedModel):
pass # TODO change_email, change_password


@deserialize
@serde
@dataclass(frozen=True)
class File(LinkedModel):
"""
Expand Down Expand Up @@ -59,7 +59,7 @@ def parent(self) -> str:
# TODO download methods


@deserialize
@serde
@dataclass(frozen=True)
class PACSFile(File):
"""
Expand All @@ -86,7 +86,7 @@ class PACSFile(File):
pacs_identifier: str


@deserialize
@serde
@dataclass(frozen=True)
class PluginInstance(PluginInstanceData):
@http.get("feed")
Expand Down Expand Up @@ -158,7 +158,7 @@ async def wait(
return (time.monotonic_ns() - start) / 1e9, cur


@deserialize
@serde
@dataclass(frozen=True)
class FeedNote(FeedNoteData):
@http.get("feed")
Expand All @@ -174,7 +174,7 @@ async def set(
...


@deserialize
@serde
@dataclass(frozen=True)
class Feed(FeedData):
"""
Expand All @@ -201,7 +201,7 @@ async def set(
async def get_note(self) -> FeedNote: ...


@deserialize
@serde
@dataclass(frozen=True)
class Plugin(PublicPlugin):
"""
Expand Down
7 changes: 3 additions & 4 deletions src/aiochris/models/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Optional, Literal, TextIO

import serde
from serde import deserialize

from aiochris.enums import PluginType
from aiochris.link import http
Expand All @@ -16,7 +15,7 @@
from aiochris.util.search import Search


@deserialize
@serde.serde
@dataclass(frozen=True)
class ComputeResource:
url: ApiUrl
Expand All @@ -30,7 +29,7 @@ class ComputeResource:
max_job_exec_seconds: int


@deserialize
@serde.serde
@dataclass(frozen=True)
class PluginParameter(LinkedModel):
"""
Expand All @@ -51,7 +50,7 @@ class PluginParameter(LinkedModel):
plugin: PluginUrl


@deserialize
@serde.serde
@dataclass(frozen=True)
class PublicPlugin(LinkedModel):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/aiochris/util/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
from collections.abc import AsyncIterable, AsyncIterator, AsyncGenerator
from dataclasses import dataclass
from typing import (
Optional,
Expand All @@ -8,25 +9,24 @@
Any,
Generic,
)
from collections.abc import AsyncIterable, AsyncIterator, AsyncGenerator

import yarl
from serde import deserialize
from serde import serde
from serde.json import from_json

from aiochris.link.linked import deserialize_linked, Linked
from aiochris.errors import (
BaseClientError,
raise_for_status,
NonsenseResponseError,
)
from aiochris.link.linked import deserialize_linked, Linked

logger = logging.getLogger(__name__)

T = TypeVar("T")


@deserialize
@serde
class _Paginated:
"""
Response from a paginated endpoint.
Expand Down
10 changes: 10 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, Awaitable

import pytest
import serde
from aiohttp.client_exceptions import ClientConnectorError

import tests.examples.plugin_description as example_descriptions
Expand Down Expand Up @@ -143,6 +144,15 @@ async def test_added_plugin(
assert plinst.compute_resource_name == new_compute_resource.name


async def test_serialize(dircopy_instance: PluginInstance):
"""
Make sure it's possible to use `serde.to_dict` on models.
"""
serialized = serde.to_dict(dircopy_instance)
assert not hasattr(serialized, "s")
assert serialized["id"] == dircopy_instance.id


async def test_add_plugin_compute_resources_serialization(
new_compute_resource: ComputeResource,
):
Expand Down

0 comments on commit a09bc24

Please sign in to comment.