Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] update to pydantic2 #21

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
all: black ruff mypy test
format: black ruff isort
lint: ruff mypy
all: black ruff pyright mypy test
format: black ruff
lint: pyright mypy checkruff
check: checkblack checkruff

black:
Expand All @@ -15,6 +15,9 @@ checkruff:
checkblack:
poetry run black --check .

pyright:
poetry run pyright .

mypy:
poetry run mypy .

Expand Down
2 changes: 1 addition & 1 deletion lnurl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def handle(
if lnurl.is_login:
return LnurlAuthResponse(callback=lnurl.url, k1=lnurl.url.query_params["k1"])

return get(lnurl.url, response_class=response_class)
return get(str(lnurl.url), response_class=response_class, verify=verify)
19 changes: 9 additions & 10 deletions lnurl/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from .exceptions import LnurlResponseException
from .types import (
Expand Down Expand Up @@ -45,16 +45,15 @@ class UrlAction(LnurlPaySuccessAction):


class LnurlResponseModel(BaseModel):
class Config:
allow_population_by_field_name = True
model_config = ConfigDict(populate_by_name=True)

def dict(self, **kwargs):
kwargs.setdefault("by_alias", True)
return super().dict(**kwargs)
return super().model_dump(**kwargs)

def json(self, **kwargs):
kwargs.setdefault("by_alias", True)
return super().json(**kwargs)
return super().model_dump_json(**kwargs)

@property
def ok(self) -> bool:
Expand Down Expand Up @@ -95,7 +94,7 @@ class LnurlHostedChannelResponse(LnurlResponseModel):
tag: Literal["hostedChannelRequest"] = "hostedChannelRequest"
uri: LightningNodeUri
k1: str
alias: Optional[str]
alias: Optional[str] = None


class LnurlPayResponse(LnurlResponseModel):
Expand All @@ -105,8 +104,8 @@ class LnurlPayResponse(LnurlResponseModel):
max_sendable: MilliSatoshi = Field(..., alias="maxSendable")
metadata: LnurlPayMetadata

@validator("max_sendable")
def max_less_than_min(cls, value, values, **kwargs): # noqa
@field_validator("max_sendable")
def max_less_than_min(cls, value, values, **_):
if "min_sendable" in values and value < values["min_sendable"]:
raise ValueError("`max_sendable` cannot be less than `min_sendable`.")
return value
Expand Down Expand Up @@ -147,8 +146,8 @@ class LnurlWithdrawResponse(LnurlResponseModel):
max_withdrawable: MilliSatoshi = Field(..., alias="maxWithdrawable")
default_description: str = Field("", alias="defaultDescription")

@validator("max_withdrawable")
def max_less_than_min(cls, value, values, **kwargs): # noqa
@field_validator("max_withdrawable")
def max_less_than_min(cls, value, values, **_):
if "min_withdrawable" in values and value < values["min_withdrawable"]:
raise ValueError("`max_withdrawable` cannot be less than `min_withdrawable`.")
return value
Expand Down
205 changes: 95 additions & 110 deletions lnurl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,15 @@
import os
import re
from hashlib import sha256
from typing import List, Optional, Tuple, Union
from urllib.parse import parse_qs
from typing import Annotated, List, Optional, Tuple, Union

from pydantic import (
ConstrainedStr,
HttpUrl,
Json,
PositiveInt,
ValidationError,
parse_obj_as,
)
from pydantic.errors import UrlHostTldError, UrlSchemeError
from pydantic.networks import Parts
from pydantic.validators import str_validator
from pydantic import Field, HttpUrl, Json, PositiveInt, TypeAdapter, UrlConstraints, ValidationError
from pydantic.functional_validators import AfterValidator

from .exceptions import InvalidLnurlPayMetadata
from .helpers import _bech32_decode, _lnurl_clean, _lnurl_decode


def ctrl_characters_validator(value: str) -> str:
"""Checks for control characters (unicode blocks C0 and C1, plus DEL)."""
if re.compile(r"[\u0000-\u001f\u007f-\u009f]").search(value):
raise ValueError
return value


def strict_rfc3986_validator(value: str) -> str:
"""Checks for RFC3986 compliance."""
if re.compile(r"[^]a-zA-Z0-9._~:/?#[@!$&'()*+,;=-]").search(value):
raise ValueError
return value


class ReprMixin:
def __repr__(self) -> str:
attrs = [ # type: ignore
Expand Down Expand Up @@ -62,69 +38,80 @@ def __init__(self, bech32: str, *, hrp: Optional[str] = None, data: Optional[Lis
def __get_data__(cls, bech32: str) -> Tuple[str, List[int]]:
return _bech32_decode(bech32)

@classmethod
def __get_validators__(cls):
yield str_validator
yield cls.validate
# @classmethod
# def __get_validators__(cls):
# # yield str_validator
# yield cls.validate

@classmethod
def validate(cls, value: str) -> "Bech32":
hrp, data = cls.__get_data__(value)
return cls(value, hrp=hrp, data=data)


class Url(HttpUrl):
"""URL with extra validations over pydantic's `HttpUrl`."""

max_length = 2047 # https://stackoverflow.com/questions/417142/

@classmethod
def __get_validators__(cls):
yield ctrl_characters_validator
if os.environ.get("LNURL_STRICT_RFC3986", "0") == "1":
yield strict_rfc3986_validator
yield cls.validate

@property
def base(self) -> str:
hostport = f"{self.host}:{self.port}" if self.port else self.host
return f"{self.scheme}://{hostport}{self.path}"
def ctrl_characters_validator(value: str) -> str:
"""Checks for control characters (unicode blocks C0 and C1, plus DEL)."""
if re.compile(r"[\u0000-\u001f\u007f-\u009f]").search(value):
raise ValidationError
return value

@property
def query_params(self) -> dict:
return {k: v[0] for k, v in parse_qs(self.query).items()}

def strict_rfc3986_validator(value: str) -> str:
"""Checks for RFC3986 compliance."""
if os.environ.get("LNURL_STRICT_RFC3986", "0") == "1":
if re.compile(r"[^]a-zA-Z0-9._~:/?#[@!$&'()*+,;=-]").search(value):
raise ValidationError
return value

class DebugUrl(Url):
"""Unsecure web URL, to make developers life easier."""

allowed_schemes = {"http"}
# Secure web URL
ClearnetUrl = Annotated[
HttpUrl,
UrlConstraints(
max_length=2047, # https://stackoverflow.com/questions/417142/
allowed_schemes=["https"],
),
AfterValidator(ctrl_characters_validator),
AfterValidator(strict_rfc3986_validator),
]

@classmethod
def validate_host(cls, parts: Parts) -> Tuple[str, Optional[str], str, bool]:
host, tld, host_type, rebuild = super().validate_host(parts)
if host not in ["127.0.0.1", "0.0.0.0"]:
raise UrlSchemeError()
return host, tld, host_type, rebuild

def onion_validator(value: str) -> None:
"""checks if it is a valid onion address"""
if not value.endswith(".onion"):
raise ValidationError

class ClearnetUrl(Url):
"""Secure web URL."""

allowed_schemes = {"https"}
# Tor anonymous onion service
OnionUrl = Annotated[
HttpUrl,
UrlConstraints(
max_length=2047, # https://stackoverflow.com/questions/417142/
allowed_schemes=["http"],
),
AfterValidator(ctrl_characters_validator),
AfterValidator(strict_rfc3986_validator),
AfterValidator(onion_validator),
]


class OnionUrl(Url):
"""Tor anonymous onion service."""
def localhost_validator(value: str) -> None:
# host, tld, host_type, rebuild = super().validate_host(parts)
if not value.find("127.0.0.1") or not value.find("0.0.0.0"):
raise ValidationError

allowed_schemes = {"https", "http"}

@classmethod
def validate_host(cls, parts: Parts) -> Tuple[str, Optional[str], str, bool]:
host, tld, host_type, rebuild = super().validate_host(parts)
if tld != "onion":
raise UrlHostTldError()
return host, tld, host_type, rebuild
# Unsecure web URL, to make developers life easier
DebugUrl = Annotated[
HttpUrl,
UrlConstraints(
max_length=2047, # https://stackoverflow.com/questions/417142/
allowed_schemes=["http"],
),
AfterValidator(ctrl_characters_validator),
AfterValidator(strict_rfc3986_validator),
AfterValidator(localhost_validator),
]


class LightningInvoice(Bech32):
Expand All @@ -146,31 +133,31 @@ def h(self):
class LightningNodeUri(ReprMixin, str):
"""Remote node address of form `node_key@ip_address:port_number`."""

__slots__ = ("key", "ip", "port")
# __slots__ = ("key", "ip", "port")

def __new__(cls, uri: str, **kwargs) -> "LightningNodeUri":
return str.__new__(cls, uri)
# def __new__(cls, uri: str, **_) -> "LightningNodeUri":
# return str.__new__(cls, uri)

def __init__(self, uri: str, *, key: Optional[str] = None, ip: Optional[str] = None, port: Optional[str] = None):
str.__init__(uri)
self.key = key
self.ip = ip
self.port = port
# def __init__(self, uri: str, *, key: Optional[str] = None, ip: Optional[str] = None, port: Optional[str] = None):
# str.__init__(uri)
# self.key = key
# self.ip = ip
# self.port = port

@classmethod
def __get_validators__(cls):
yield str_validator
yield cls.validate
# @classmethod
# def __get_validators__(cls):
# yield str_validator
# yield cls.validate

@classmethod
def validate(cls, value: str) -> "LightningNodeUri":
try:
key, netloc = value.split("@")
ip, port = netloc.split(":")
except Exception:
raise ValueError
# @classmethod
# def validate(cls, value: str) -> "LightningNodeUri":
# try:
# key, netloc = value.split("@")
# ip, port = netloc.split(":")
# except Exception:
# raise ValueError

return cls(value, key=key, ip=ip, port=port)
# return cls(value, key=key, ip=ip, port=port)


class Lnurl(ReprMixin, str):
Expand All @@ -188,28 +175,30 @@ def __init__(self, lightning: str, *, url: Optional[Union[OnionUrl, ClearnetUrl,
@classmethod
def __get_url__(cls, bech32: str) -> Union[OnionUrl, ClearnetUrl, DebugUrl]:
url: str = _lnurl_decode(bech32)
return parse_obj_as(Union[OnionUrl, ClearnetUrl, DebugUrl], url) # type: ignore
adapter = TypeAdapter(Union[OnionUrl, ClearnetUrl, DebugUrl])
return adapter.validate_python(url)

@classmethod
def __get_validators__(cls):
yield str_validator
yield cls.validate
# @classmethod
# def __get_validators__(cls):
# yield str_validator
# yield cls.validate

@classmethod
def validate(cls, value: str) -> "Lnurl":
return cls(value, url=cls.__get_url__(value))

@property
def is_login(self) -> bool:
return "tag" in self.url.query_params and self.url.query_params["tag"] == "login"
params = {k: v for k, v in self.url.query_params()}
return params.get("tag") == "login"


class LnurlPayMetadata(ReprMixin, str):
valid_metadata_mime_types = {"text/plain", "image/png;base64", "image/jpeg;base64"}

__slots__ = ("_list",)

def __new__(cls, json_str: str, **kwargs) -> "LnurlPayMetadata":
def __new__(cls, json_str: str, **_) -> "LnurlPayMetadata":
return str.__new__(cls, json_str)

def __init__(self, json_str: str, *, json_obj: Optional[List] = None):
Expand All @@ -219,7 +208,8 @@ def __init__(self, json_str: str, *, json_obj: Optional[List] = None):
@classmethod
def __validate_metadata__(cls, json_str: str) -> List[Tuple[str, str]]:
try:
parse_obj_as(Json[List[Tuple[str, str]]], json_str)
adapter = TypeAdapter(Json[List[Tuple[str, str]]])
adapter.validate_python(json_str)
data = [(str(item[0]), str(item[1])) for item in json.loads(json_str)]
except ValidationError:
raise InvalidLnurlPayMetadata
Expand All @@ -233,10 +223,10 @@ def __validate_metadata__(cls, json_str: str) -> List[Tuple[str, str]]:

return clean_data

@classmethod
def __get_validators__(cls):
yield str_validator
yield cls.validate
# @classmethod
# def __get_validators__(cls):
# yield str_validator
# yield cls.validate

@classmethod
def validate(cls, value: str) -> "LnurlPayMetadata":
Expand Down Expand Up @@ -265,13 +255,8 @@ def list(self) -> List[Tuple[str, str]]:
return self._list


class InitializationVector(ConstrainedStr):
min_length = 24
max_length = 24


class Max144Str(ConstrainedStr):
max_length = 144
InitializationVector = Annotated[str, Field(max_length=24, min_length=24)]
Max144Str = Annotated[str, Field(max_length=144)]


class MilliSatoshi(PositiveInt):
Expand Down
Loading