From 668233f7238867f982f32ce770deca56efb3f479 Mon Sep 17 00:00:00 2001 From: Simon Bihel Date: Mon, 29 Apr 2024 17:27:56 +0100 Subject: [PATCH] Fix tests/fmt and merge with latest main --- siwe/siwe.py | 8 +++----- tests/test_siwe.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/siwe/siwe.py b/siwe/siwe.py index b8ca1fe..c4afab5 100644 --- a/siwe/siwe.py +++ b/siwe/siwe.py @@ -4,7 +4,7 @@ import string from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Iterable, List, Optional import eth_utils from eth_account.messages import SignableMessage, _hash_eip191_message, encode_defunct @@ -218,16 +218,14 @@ def address_is_checksum_address(cls, v: str) -> str: @classmethod def from_message(cls, message: str, abnf: bool = True) -> "SiweMessage": + """Parse a message in its EIP-4361 format.""" if abnf: parsed_message = ABNFParsedMessage(message=message) else: parsed_message = RegExpParsedMessage(message=message) # TODO There is some redundancy in the checks when deserialising a message. - try: - return cls(**parsed_message.__dict__) - except ValidationError as e: - raise ValueError from e + return cls(**parsed_message.__dict__) def prepare_message(self) -> str: """Serialize to the EIP-4361 format for signing. diff --git a/tests/test_siwe.py b/tests/test_siwe.py index edb9d08..82efe54 100644 --- a/tests/test_siwe.py +++ b/tests/test_siwe.py @@ -30,7 +30,7 @@ class TestMessageParsing: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_valid_message(self, abnf, test_name, test): - siwe_message = SiweMessage(message=test["message"], abnf=abnf) + siwe_message = SiweMessage.from_message(message=test["message"], abnf=abnf) for key, value in test["fields"].items(): v = getattr(siwe_message, key) if not (isinstance(v, int) or isinstance(v, list) or v is None): @@ -44,7 +44,7 @@ def test_valid_message(self, abnf, test_name, test): ) def test_invalid_message(self, abnf, test_name, test): with pytest.raises(ValueError): - SiweMessage(message=test, abnf=abnf) + SiweMessage.from_message(message=test, abnf=abnf) @pytest.mark.parametrize( "test_name,test", @@ -52,7 +52,7 @@ def test_invalid_message(self, abnf, test_name, test): ) def test_invalid_object_message(self, test_name, test): with pytest.raises(ValidationError): - SiweMessage(message=test) + SiweMessage(**test) class TestMessageGeneration: @@ -61,7 +61,7 @@ class TestMessageGeneration: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_valid_message(self, test_name, test): - siwe_message = SiweMessage(message=test["fields"]) + siwe_message = SiweMessage(**test["fields"]) assert siwe_message.prepare_message() == test["message"] @@ -71,7 +71,7 @@ class TestMessageVerification: [(test_name, test) for test_name, test in verification_positive.items()], ) def test_valid_message(self, test_name, test): - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None siwe_message.verify(test["signature"], timestamp=timestamp) @@ -81,7 +81,7 @@ def test_valid_message(self, test_name, test): ) def test_eip1271_message(self, test_name, test): provider = HTTPProvider(endpoint_uri="https://cloudflare-eth.com") - siwe_message = SiweMessage(message=test["message"]) + siwe_message = SiweMessage.from_message(message=test["message"]) siwe_message.verify(test["signature"], provider=provider) @pytest.mark.parametrize( @@ -98,9 +98,9 @@ def test_invalid_message(self, provider, test_name, test): "invalidissued_at", ]: with pytest.raises(ValidationError): - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) return - siwe_message = SiweMessage(message=test) + siwe_message = SiweMessage(**test) domain_binding = test.get("domain_binding") match_nonce = test.get("match_nonce") timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None @@ -122,7 +122,7 @@ class TestMessageRoundTrip: [(test_name, test) for test_name, test in parsing_positive.items()], ) def test_message_round_trip(self, test_name, test): - message = SiweMessage(test["fields"]) + message = SiweMessage(**test["fields"]) message.address = self.account.address signature = self.account.sign_message( messages.encode_defunct(text=message.prepare_message())