Skip to content

Upgrade pydantic 1 to pydantic 2 #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ package_dir =
# For more information, check out https://semver.org/.
install_requires =
aiohttp
aleph-sdk-python~=0.9.0
aleph-sdk-python~=2.0.1
hexbytes
fastapi>=0.95.1
importlib-metadata; python_version<"3.8"
Expand Down
4 changes: 2 additions & 2 deletions src/aleph_vrf/coordinator/executor_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def _get_corechannel_aggregate() -> Dict[str, Any]:
Returns the "corechannel" aleph.im aggregate.
This aggregate contains an up-to-date list of staked nodes on the network.
"""
async with aiohttp.ClientSession(settings.API_HOST) as session:
async with aiohttp.ClientSession(str(settings.API_HOST)) as session:
url = (
f"/api/v0/aggregates/{settings.CORECHANNEL_AGGREGATE_ADDRESS}.json?"
f"keys={settings.CORECHANNEL_AGGREGATE_KEY}"
Expand All @@ -53,7 +53,7 @@ async def _get_unauthorized_node_list_aggregate(aggregate_address: str) -> List[
Returns the "vrf_unauthorized_nodes" list aggregate.
This aggregate contains an up-to-date list of nodes not allowed to run a VRF request.
"""
async with aiohttp.ClientSession(settings.API_HOST) as session:
async with aiohttp.ClientSession(str(settings.API_HOST)) as session:
url = (
f"/api/v0/aggregates/{aggregate_address}.json?"
f"keys={settings.VRF_AGGREGATE_KEY}"
Expand Down
2 changes: 1 addition & 1 deletion src/aleph_vrf/coordinator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


class VRFRequest(BaseModel):
request_id: Optional[str]
request_id: Optional[str] = None


@app.get("/")
Expand Down
23 changes: 15 additions & 8 deletions src/aleph_vrf/coordinator/vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from aleph.sdk.chains.ethereum import ETHAccount
from aleph.sdk.client import AuthenticatedAlephHttpClient
from aleph.sdk.query.filters import MessageFilter
from aleph_message.models import ItemHash, MessageType, PostMessage
from aleph_message.models import ItemHash, MessageType, PostMessage, AlephMessage
from aleph_message.status import MessageStatus
from hexbytes import HexBytes
from pydantic import BaseModel
Expand Down Expand Up @@ -60,7 +60,7 @@ async def post_executor_api_request(url: str, model: Type[M]) -> M:

response = await resp.json()

return model.parse_obj(response["data"])
return model.model_validate(response["data"])


async def prepare_executor_api_request(url: str) -> bool:
Expand Down Expand Up @@ -187,7 +187,7 @@ async def generate_vrf(

async with AuthenticatedAlephHttpClient(
account=account,
api_server=aleph_api_server or settings.API_HOST,
api_server=aleph_api_server or str(settings.API_HOST),
# Avoid going through the VM connector on aleph.im CRNs
allow_unix_sockets=False,
) as aleph_client:
Expand Down Expand Up @@ -358,9 +358,11 @@ async def get_existing_vrf_message(
if messages.messages:
if len(messages.messages) > 1:
logger.warning(f"Multiple VRF messages found for request id {request_id}")
return messages.messages[
-1
] # Always fetch the last VRF message in case there is more than 1.
# Ensure we're returning a PostMessage
last_message = messages.messages[-1]
if isinstance(last_message, PostMessage):
return last_message
return None
else:
logger.debug(f"Existing VRF message for request id {request_id} not found")
return None
Expand All @@ -374,13 +376,18 @@ async def get_existing_message(
f"Getting VRF message on {aleph_client.api_server} for item_hash {item_hash}"
)

message = await aleph_client.get_message(
message: AlephMessage = await aleph_client.get_message(
item_hash=item_hash,
)

if not message:
raise AlephNetworkError(
f"Message could not be read for item_hash {message.item_hash}"
f"Message could not be read for item_hash {item_hash}"
)

if not isinstance(message, PostMessage):
raise AlephNetworkError(
f"Message for item_hash {item_hash} is not a PostMessage"
)

return message
Expand Down
5 changes: 3 additions & 2 deletions src/aleph_vrf/executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Annotated
else:
from typing import Annotated
from typing import AsyncGenerator

import fastapi
from aleph.sdk.exceptions import MessageNotFoundError, MultipleMessagesError
Expand Down Expand Up @@ -57,11 +58,11 @@
app = AlephApp(http_app=http_app)


async def authenticated_aleph_client() -> AuthenticatedAlephHttpClient:
async def authenticated_aleph_client() -> AsyncGenerator[AuthenticatedAlephHttpClient, None]:
account = settings.aleph_account()
async with AuthenticatedAlephHttpClient(
account=account,
api_server=settings.API_HOST,
api_server=str(settings.API_HOST),
# Avoid going through the VM connector on aleph.im CRNs
allow_unix_sockets=False,
) as client:
Expand Down
3 changes: 1 addition & 2 deletions src/aleph_vrf/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from aleph_message.models import ItemHash, PostMessage
from aleph_message.models.abstract import HashableModel
from pydantic import BaseModel, ValidationError
from pydantic.generics import GenericModel
from typing_extensions import TypeAlias

from aleph_vrf.types import ExecutionId, Nonce, RequestId
Expand Down Expand Up @@ -225,7 +224,7 @@ class APIError(BaseModel):
error: str


class APIResponse(GenericModel, Generic[M]):
class APIResponse(BaseModel, Generic[M]):
data: M


Expand Down
17 changes: 7 additions & 10 deletions src/aleph_vrf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
from aleph.sdk.chains.common import get_fallback_private_key
from aleph.sdk.chains.ethereum import ETHAccount
from hexbytes import HexBytes
from pydantic import BaseSettings, Field, HttpUrl
from pydantic import Field, HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
API_HOST: HttpUrl = Field(
default="https://api3.aleph.im",
default=HttpUrl("https://api3.aleph.im"),
description="URL of the reference aleph.im Core Channel Node.",
)
CORECHANNEL_AGGREGATE_ADDRESS = Field(
CORECHANNEL_AGGREGATE_ADDRESS: str = Field(
default="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10",
description="Address posting the `corechannel` aggregate.",
)
CORECHANNEL_AGGREGATE_KEY = Field(
CORECHANNEL_AGGREGATE_KEY: str = Field(
default="corechannel", description="Key for the `corechannel` aggregate."
)
FUNCTION: str = Field(
Expand All @@ -26,7 +27,7 @@ class Settings(BaseSettings):
default=None,
description="Address posting the `corechannel` aggregate.",
)
VRF_AGGREGATE_KEY = Field(
VRF_AGGREGATE_KEY: str = Field(
default="vrf", description="Key for the VRF aggregate."
)
NB_EXECUTORS: int = Field(default=16, description="Number of executors to use.")
Expand All @@ -45,11 +46,7 @@ def private_key(self) -> HexBytes:

def aleph_account(self) -> ETHAccount:
return ETHAccount(self.private_key())

class Config:
env_prefix = "ALEPH_VRF_"
case_sensitive = False
env_file = ".env"
model_config = SettingsConfigDict(env_prefix="ALEPH_VRF_", case_sensitive=False, env_file=".env")


settings = Settings()
3 changes: 1 addition & 2 deletions tests/coordinator/test_integration_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def assert_aleph_message_matches_vrf_response(
vrf_response.message_hash, message_type=PostMessage
)

message_vrf_response = VRFResponse.parse_obj(message.content.content)
message_vrf_response = VRFResponse.model_validate(message.content.content)
assert_vrf_response_equal(message_vrf_response, vrf_response)

return message
Expand Down Expand Up @@ -211,7 +211,6 @@ async def send_generate_requests_and_call_publish(
)
# We're only interested in one response for this test
break

return generate_response


Expand Down
12 changes: 6 additions & 6 deletions tests/executor/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def make_post_message(
content=vrf_object,
)

item_content = content.json()
item_content = content.model_dump_json()
item_hash = sha256(item_content.encode()).hexdigest()

return {
Expand Down Expand Up @@ -161,7 +161,7 @@ async def assert_aleph_message_matches_random_number_hash(
random_number_hash.message_hash, message_type=PostMessage
)

message_random_number_hash = VRFRandomNumberHash.parse_obj(message.content.content)
message_random_number_hash = VRFRandomNumberHash.model_validate(message.content.content)
assert_vrf_random_number_hash_equal(message_random_number_hash, random_number_hash)

return message
Expand All @@ -187,7 +187,7 @@ async def assert_aleph_message_matches_random_number(
random_number.message_hash, message_type=PostMessage
)

message_random_number = VRFRandomNumber.parse_obj(message.content.content)
message_random_number = VRFRandomNumber.model_validate(message.content.content)
assert_vrf_random_number_equal(message_random_number, random_number)

return message
Expand All @@ -213,7 +213,7 @@ async def test_normal_request_flow(
assert resp.status == 200, await resp.text()
response_json = await resp.json()

random_number_hash = PublishedVRFRandomNumberHash.parse_obj(response_json["data"])
random_number_hash = PublishedVRFRandomNumberHash.model_validate(response_json["data"])

assert_vrf_hash_matches_request(random_number_hash, vrf_request, item_hash)
random_number_hash_message = await assert_aleph_message_matches_random_number_hash(
Expand All @@ -224,7 +224,7 @@ async def test_normal_request_flow(
assert resp.status == 200, await resp.text()
response_json = await resp.json()

random_number = PublishedVRFRandomNumber.parse_obj(response_json["data"])
random_number = PublishedVRFRandomNumber.model_validate(response_json["data"])
assert_random_number_matches_request(
random_number=random_number,
random_number_hash=random_number_hash,
Expand Down Expand Up @@ -256,7 +256,7 @@ async def test_call_publish_twice(
assert resp.status == 200, await resp.text()
response_json = await resp.json()

random_number_hash = PublishedVRFRandomNumberHash.parse_obj(response_json["data"])
random_number_hash = PublishedVRFRandomNumberHash.model_validate(response_json["data"])

# Call POST /publish a first time
resp = await executor_client.post(f"/publish/{random_number_hash.message_hash}")
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_ccn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class PublicationStatus(BaseModel):

class PubMessageResponse(BaseModel):
publication_status: PublicationStatus
message_status: Optional[MessageStatus]
message_status: Optional[MessageStatus] = None


def format_message(message_dict: Dict[str, Any]):
Expand Down