Skip to content

Commit

Permalink
Fix tests/fmt and merge with latest main
Browse files Browse the repository at this point in the history
  • Loading branch information
sbihel committed Apr 29, 2024
1 parent f0aad95 commit 668233f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
8 changes: 3 additions & 5 deletions siwe/siwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import string
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Iterable, List, Optional

import eth_utils
from eth_account.messages import SignableMessage, _hash_eip191_message, encode_defunct
Expand Down Expand Up @@ -218,16 +218,14 @@ def address_is_checksum_address(cls, v: str) -> str:

@classmethod
def from_message(cls, message: str, abnf: bool = True) -> "SiweMessage":
"""Parse a message in its EIP-4361 format."""
if abnf:
parsed_message = ABNFParsedMessage(message=message)
else:
parsed_message = RegExpParsedMessage(message=message)

# TODO There is some redundancy in the checks when deserialising a message.
try:
return cls(**parsed_message.__dict__)
except ValidationError as e:
raise ValueError from e
return cls(**parsed_message.__dict__)

def prepare_message(self) -> str:
"""Serialize to the EIP-4361 format for signing.
Expand Down
18 changes: 9 additions & 9 deletions tests/test_siwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TestMessageParsing:
[(test_name, test) for test_name, test in parsing_positive.items()],
)
def test_valid_message(self, abnf, test_name, test):
siwe_message = SiweMessage(message=test["message"], abnf=abnf)
siwe_message = SiweMessage.from_message(message=test["message"], abnf=abnf)
for key, value in test["fields"].items():
v = getattr(siwe_message, key)
if not (isinstance(v, int) or isinstance(v, list) or v is None):
Expand All @@ -44,15 +44,15 @@ def test_valid_message(self, abnf, test_name, test):
)
def test_invalid_message(self, abnf, test_name, test):
with pytest.raises(ValueError):
SiweMessage(message=test, abnf=abnf)
SiweMessage.from_message(message=test, abnf=abnf)

@pytest.mark.parametrize(
"test_name,test",
[(test_name, test) for test_name, test in parsing_negative_objects.items()],
)
def test_invalid_object_message(self, test_name, test):
with pytest.raises(ValidationError):
SiweMessage(message=test)
SiweMessage(**test)


class TestMessageGeneration:
Expand All @@ -61,7 +61,7 @@ class TestMessageGeneration:
[(test_name, test) for test_name, test in parsing_positive.items()],
)
def test_valid_message(self, test_name, test):
siwe_message = SiweMessage(message=test["fields"])
siwe_message = SiweMessage(**test["fields"])
assert siwe_message.prepare_message() == test["message"]


Expand All @@ -71,7 +71,7 @@ class TestMessageVerification:
[(test_name, test) for test_name, test in verification_positive.items()],
)
def test_valid_message(self, test_name, test):
siwe_message = SiweMessage(message=test)
siwe_message = SiweMessage(**test)
timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None
siwe_message.verify(test["signature"], timestamp=timestamp)

Expand All @@ -81,7 +81,7 @@ def test_valid_message(self, test_name, test):
)
def test_eip1271_message(self, test_name, test):
provider = HTTPProvider(endpoint_uri="https://cloudflare-eth.com")
siwe_message = SiweMessage(message=test["message"])
siwe_message = SiweMessage.from_message(message=test["message"])
siwe_message.verify(test["signature"], provider=provider)

@pytest.mark.parametrize(
Expand All @@ -98,9 +98,9 @@ def test_invalid_message(self, provider, test_name, test):
"invalidissued_at",
]:
with pytest.raises(ValidationError):
siwe_message = SiweMessage(message=test)
siwe_message = SiweMessage(**test)
return
siwe_message = SiweMessage(message=test)
siwe_message = SiweMessage(**test)
domain_binding = test.get("domain_binding")
match_nonce = test.get("match_nonce")
timestamp = datetime_from_iso8601_string(test["time"]) if "time" in test else None
Expand All @@ -122,7 +122,7 @@ class TestMessageRoundTrip:
[(test_name, test) for test_name, test in parsing_positive.items()],
)
def test_message_round_trip(self, test_name, test):
message = SiweMessage(test["fields"])
message = SiweMessage(**test["fields"])
message.address = self.account.address
signature = self.account.sign_message(
messages.encode_defunct(text=message.prepare_message())
Expand Down

0 comments on commit 668233f

Please sign in to comment.