Skip to content

Commit

Permalink
feat: add follow_redirect, timeout, useragent (#33)
Browse files Browse the repository at this point in the history
* feat: add follow_redirect, timeout, useragent
- follow redirect on requests
- add timeout to request, default: 5
- add customaizable useragent
* check if returned in voice matches the amount
* await handle
* improvements
  • Loading branch information
dni authored May 8, 2024
1 parent 68be68d commit 8a95b54
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 25 deletions.
94 changes: 74 additions & 20 deletions lnurl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,19 @@

from .exceptions import InvalidLnurl, InvalidUrl, LnurlResponseException
from .helpers import lnurlauth_signature, url_encode
from .models import LnurlAuthResponse, LnurlPayResponse, LnurlResponse, LnurlResponseModel, LnurlWithdrawResponse
from .models import (
LnurlAuthResponse,
LnurlPayActionResponse,
LnurlPayResponse,
LnurlResponse,
LnurlResponseModel,
LnurlWithdrawResponse,
)
from .types import ClearnetUrl, DebugUrl, LnAddress, Lnurl, OnionUrl

USER_AGENT = "lnbits/lnurl"
TIMEOUT = 5


def decode(bech32_lnurl: str) -> Union[OnionUrl, ClearnetUrl, DebugUrl]:
try:
Expand All @@ -25,13 +35,20 @@ def encode(url: str) -> Lnurl:
raise InvalidUrl


async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlResponseModel:
async with httpx.AsyncClient() as client:
async def get(
url: str,
*,
response_class: Optional[Any] = None,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
try:
res = await client.get(url)
res = await client.get(url, timeout=timeout or TIMEOUT)
res.raise_for_status()
except Exception as e:
raise LnurlResponseException(str(e))
except Exception as exc:
raise LnurlResponseException(str(exc)) from exc

if response_class:
assert issubclass(response_class, LnurlResponseModel), "Use a valid `LnurlResponseModel` subclass."
Expand All @@ -43,73 +60,108 @@ async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlRespons
async def handle(
bech32_lnurl: str,
response_class: Optional[LnurlResponseModel] = None,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
if "@" in bech32_lnurl:
lnaddress = LnAddress(bech32_lnurl)
return await get(lnaddress.url, response_class=response_class)
return await get(lnaddress.url, response_class=response_class, user_agent=user_agent, timeout=timeout)
lnurl = Lnurl(bech32_lnurl)
except (ValidationError, ValueError):
raise InvalidLnurl

if lnurl.is_login:
return LnurlAuthResponse(callback=lnurl.url, k1=lnurl.url.query_params["k1"])

return await get(lnurl.url, response_class=response_class)
return await get(lnurl.url, response_class=response_class, user_agent=user_agent, timeout=timeout)


async def execute(bech32_or_address: str, value: str) -> LnurlResponseModel:
async def execute(
bech32_or_address: str,
value: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
res = handle(bech32_or_address)
res = await handle(bech32_or_address, user_agent=user_agent, timeout=timeout)
except Exception as exc:
raise LnurlResponseException(str(exc))

if isinstance(res, LnurlPayResponse) and res.tag == "payRequest":
return await execute_pay_request(res, value)
return await execute_pay_request(res, value, user_agent=user_agent, timeout=timeout)
elif isinstance(res, LnurlAuthResponse) and res.tag == "login":
return await execute_login(res, value)
return await execute_login(res, value, user_agent=user_agent, timeout=timeout)
elif isinstance(res, LnurlWithdrawResponse) and res.tag == "withdrawRequest":
return await execute_withdraw(res, value)
return await execute_withdraw(res, value, user_agent=user_agent, timeout=timeout)

raise LnurlResponseException(f"{res.tag} not implemented") # type: ignore


async def execute_pay_request(res: LnurlPayResponse, msat: str) -> LnurlResponseModel:
async def execute_pay_request(
res: LnurlPayResponse,
msat: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
if not res.min_sendable <= MilliSatoshi(msat) <= res.max_sendable:
raise LnurlResponseException(f"Amount {msat} not in range {res.min_sendable} - {res.max_sendable}")

try:
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"amount": msat,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
pay_res = LnurlResponse.from_dict(res2.json())
assert isinstance(pay_res, LnurlPayActionResponse), "Invalid response in execute_pay_request."
invoice = bolt11_decode(pay_res.pr)
if invoice.amount_msat != int(msat):
raise LnurlResponseException(
f"{res.callback.host} returned an invalid invoice."
f"Excepted `{msat}` msat, got `{invoice.amount_msat}`."
)
return pay_res
except Exception as exc:
raise LnurlResponseException(str(exc))


async def execute_login(res: LnurlAuthResponse, secret: str) -> LnurlResponseModel:
async def execute_login(
res: LnurlAuthResponse,
secret: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
assert res.callback.host, "LNURLauth host does not exist"
key, sig = lnurlauth_signature(res.callback.host, secret, res.k1)
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"key": key,
"sig": sig,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
except Exception as e:
raise LnurlResponseException(str(e))


async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponseModel:
async def execute_withdraw(
res: LnurlWithdrawResponse,
pr: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
invoice = bolt11_decode(pr)
except Bolt11Exception as exc:
Expand All @@ -119,13 +171,15 @@ async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponse
if not res.min_withdrawable <= MilliSatoshi(amount) <= res.max_withdrawable:
raise LnurlResponseException(f"Amount {amount} not in range {res.min_withdrawable} - {res.max_withdrawable}")
try:
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"k1": res.k1,
"pr": pr,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
Expand Down
1 change: 1 addition & 0 deletions lnurl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class LnurlPayActionResponse(LnurlResponseModel):
pr: LightningInvoice
success_action: Optional[Union[MessageAction, UrlAction, AesAction]] = Field(None, alias="successAction")
routes: List[List[LnurlPayRouteHop]] = []
verify: Optional[str] = None


class LnurlWithdrawResponse(LnurlResponseModel):
Expand Down
12 changes: 7 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def test_encode_nohttps(self, url):


class TestHandle:
"""Responses from the LNURL: https://legend.lnbits.com/"""
"""Responses from the LNURL: https://demo.lnbits.com/"""

@pytest.mark.xfail(reason="legend.lnbits.com is down")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"bech32",
Expand All @@ -86,7 +87,7 @@ async def test_handle_withdraw(self, bech32):
res = await handle(bech32)
assert isinstance(res, LnurlWithdrawResponse)
assert res.tag == "withdrawRequest"
assert res.callback.host == "legend.lnbits.com"
assert res.callback.host == "demo.lnbits.com"
assert res.default_description == "sample withdraw"
assert res.max_withdrawable >= res.min_withdrawable

Expand All @@ -104,8 +105,9 @@ async def test_get_requests_error(self, url):


class TestPayFlow:
"""Full LNURL-pay flow interacting with https://legend.lnbits.com/"""
"""Full LNURL-pay flow interacting with https://demo.lnbits.com/"""

@pytest.mark.xfail(reason="legend.lnbits.com is down")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"bech32, amount",
Expand All @@ -114,14 +116,14 @@ class TestPayFlow:
"LNURL1DP68GURN8GHJ7MR9VAJKUEPWD3HXY6T5WVHXXMMD9AKXUATJD3CZ7JN9F4EHQJQC25ZZY",
"1000",
),
("donate@legend.lnbits.com", "100000"),
("donate@demo.lnbits.com", "100000"),
],
)
async def test_pay_flow(self, bech32: str, amount: str):
res = await handle(bech32)
assert isinstance(res, LnurlPayResponse)
assert res.tag == "payRequest"
assert res.callback.host == "legend.lnbits.com"
assert res.callback.host == "demo.lnbits.com"
assert len(res.metadata.list()) >= 1
assert res.metadata.text != ""

Expand Down

0 comments on commit 8a95b54

Please sign in to comment.