Skip to content

Commit

Permalink
Add an OPDS1 feed downloader (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen authored Aug 12, 2024
1 parent 88412bb commit cbc5c66
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 57 deletions.
21 changes: 18 additions & 3 deletions src/palace_tools/cli/download_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import typer
import xmltodict

from palace_tools.feeds import axis, opds, overdrive
from palace_tools.feeds.opds import write_json
from palace_tools.feeds import axis, opds, opds1, overdrive
from palace_tools.utils.typer import run_typer_app_as_main

app = typer.Typer()
Expand Down Expand Up @@ -102,7 +101,23 @@ def download_opds(
"""Download OPDS 2 feed."""
publications = opds.fetch(url, username, password, authentication)
with output_file.open("w") as file:
write_json(file, publications)
opds.write_json(file, publications)


@app.command("opds1")
def download_opds1(
username: str = typer.Option(None, "--username", "-u", help="Username"),
password: str = typer.Option(None, "--password", "-p", help="Password"),
authentication: opds.AuthType = typer.Option(
opds.AuthType.NONE, "--auth", "-a", help="Authentication type"
),
url: str = typer.Argument(..., help="URL of feed", metavar="URL"),
output_file: Path = typer.Argument(
..., help="Output file", writable=True, file_okay=True, dir_okay=False
),
) -> None:
"""Download OPDS 1.x feed."""
opds1.fetch(url, username, password, authentication, output_file)


def main() -> None:
Expand Down
164 changes: 110 additions & 54 deletions src/palace_tools/feeds/opds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import math
import sys
from base64 import b64encode
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from enum import Enum
from typing import Any, TextIO
from typing import Any, NamedTuple, TextIO

import httpx
from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn
Expand All @@ -18,6 +18,12 @@ class AuthType(Enum):
NONE = "none"


class OpdsLinkTuple(NamedTuple):
type: str
href: str
rel: str


class OAuthAuth(httpx.Auth):
# Implementation of OPDS auth document OAuth client credentials flow for httpx
# See:
Expand All @@ -26,40 +32,49 @@ class OAuthAuth(httpx.Auth):

requires_response_body = True

def __init__(self, username: str, password: str) -> None:
def __init__(
self,
username: str,
password: str,
*,
feed_url: str,
parse_links: Callable[[str], dict[str, OpdsLinkTuple]] | None = None,
) -> None:
self.username = username
self.password = password
self.feed_url = feed_url
self.parse_links = parse_links

self.token: str | None = None
self.oauth_url: str | None = None

@staticmethod
def _get_oauth_url_from_auth_document(
url: str, auth_document: Mapping[str, Any]
) -> str:
def _get_oauth_url_from_auth_document(auth_document: Mapping[str, Any]) -> str:
auth_types: list[dict[str, Any]] = auth_document.get("authentication", [])
oauth_authentication = [
tlinks
for t in auth_types
if t.get("type") == "http://opds-spec.org/auth/oauth/client_credentials"
and (tlinks := t.get("links")) is not None
]
if not oauth_authentication:
print(f"Unable to find supported authentication type ({url})")
print(f"Auth document: {json.dumps(auth_document)}")
try:
[links] = [
tlinks
for t in auth_types
if t.get("type") == "http://opds-spec.org/auth/oauth/client_credentials"
and (tlinks := t.get("links")) is not None
]
except (ValueError, TypeError):
print("Unable to find supported authentication type")
print(f"Auth document: {json.dumps(auth_document, indent=2)}")
sys.exit(-1)

links = oauth_authentication[0]
auth_links: list[str] = [
lhref
for l in links
if l.get("rel") == "authenticate" and (lhref := l.get("href")) is not None
]
if len(auth_links) != 1:
print(f"Unable to find valid authentication link ({url})")
print(
f"Found {len(auth_links)} authentication links. Auth document: {json.dumps(auth_document)}"
)
try:
[auth_link] = [
lhref
for l in links
if l.get("rel") == "authenticate"
and (lhref := l.get("href")) is not None
]
except (ValueError, TypeError):
print("Unable to find valid authentication link")
print(f"Auth document: {json.dumps(auth_document, indent=2)}")
sys.exit(-1)
return auth_links[0]
return auth_link # type: ignore[no-any-return]

@staticmethod
def _oauth_token_request(url: str, username: str, password: str) -> httpx.Request:
Expand All @@ -70,43 +85,84 @@ def _oauth_token_request(url: str, username: str, password: str) -> httpx.Reques
"POST", url, headers=headers, data={"grant_type": "client_credentials"}
)

def refresh_auth_url(self) -> Generator[httpx.Request, httpx.Response, None]:
response = yield httpx.Request("GET", self.feed_url)
if response.status_code == 200 and self.parse_links is not None:
links = self.parse_links(response.text)
auth_doc_url = links.get("http://opds-spec.org/auth/document")
if auth_doc_url is None:
print("No auth document link found")
print(links)
sys.exit(-1)
auth_doc_response = yield httpx.Request("GET", auth_doc_url.href)
if auth_doc_response.status_code != 200:
error_and_exit(auth_doc_response)
elif response.status_code == 401:
auth_doc_response = response
else:
error_and_exit(response)

if (
auth_doc_response.headers.get("Content-Type")
!= "application/vnd.opds.authentication.v1.0+json"
):
error_and_exit(auth_doc_response, "Invalid content type")

self.oauth_url = self._get_oauth_url_from_auth_document(
auth_doc_response.json()
)

def refresh_token(self) -> Generator[httpx.Request, httpx.Response, None]:
if self.oauth_url is None:
yield from self.refresh_auth_url()

# This should never happen, but we assert for sanity and mypy
assert self.oauth_url is not None

response = yield self._oauth_token_request(
self.oauth_url, self.username, self.password
)
if response.status_code != 200:
error_and_exit(response)
if (access_token := response.json().get("access_token")) is None:
print("No access token in response")
print(response.text)
sys.exit(-1)
self.token = access_token

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
if self.token is not None:
request.headers["Authorization"] = f"Bearer {self.token}"
token_refreshed = False
if self.oauth_url is None or self.token is None:
yield from self.refresh_token()
token_refreshed = True

# This should never happen, but we assert it for mypy and our sanity
assert self.token is not None

request.headers["Authorization"] = f"Bearer {self.token}"
response = yield request
if (
response.status_code == 401
and response.headers.get("Content-Type")
== "application/vnd.opds.authentication.v1.0+json"
):
oauth_url = self._get_oauth_url_from_auth_document(
str(request.url), response.json()
)
response = yield self._oauth_token_request(
oauth_url, self.username, self.password
)
if response.status_code != 200:
print(f"Error: {response.status_code}")
print(response.text)
sys.exit(-1)
if (access_token := response.json().get("access_token")) is None:
print("No access token in response")
print(response.text)
sys.exit(-1)
self.token = access_token

if response.status_code == 401 and not token_refreshed:
yield from self.refresh_token()
request.headers["Authorization"] = f"Bearer {self.token}"
yield request


def error_and_exit(response: httpx.Response, detail: str = "") -> None:
print(f"Error: {detail}")
print(f"Request: {response.request.method} {response.request.url}")
print(f"Status code: {response.status_code}")
print(f"Headers: {json.dumps(dict(response.headers), indent=4)}")
print(f"Body: {response.text}")
sys.exit(-1)


def make_request(session: httpx.Client, url: str) -> dict[str, Any]:
response = session.get(url)
if response.status_code != 200:
print(f"Error: {response.status_code}")
print(f"Headers: {json.dumps(dict(response.headers), indent=4)}")
print(response.text)
sys.exit(-1)
error_and_exit(response)
return response.json() # type: ignore[no-any-return]


Expand All @@ -132,7 +188,7 @@ def fetch(
if auth_type == AuthType.BASIC:
client.auth = httpx.BasicAuth(username, password)
elif auth_type == AuthType.OAUTH:
client.auth = OAuthAuth(username, password)
client.auth = OAuthAuth(username, password, feed_url=url)
elif auth_type != AuthType.NONE:
print("Username and password are required for authentication")
sys.exit(-1)
Expand Down
69 changes: 69 additions & 0 deletions src/palace_tools/feeds/opds1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
from pathlib import Path
from xml.etree import ElementTree

import httpx
from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn

from palace_tools.feeds.opds import AuthType, OAuthAuth, OpdsLinkTuple, error_and_exit


def parse_links(feed: str) -> dict[str, OpdsLinkTuple]:
feed_element = ElementTree.fromstring(feed)
return {
rel: OpdsLinkTuple(type=link_type, href=href, rel=rel)
for link in feed_element.findall("{http://www.w3.org/2005/Atom}link")
if (rel := link.get("rel")) is not None
and (link_type := link.get("type")) is not None
and (href := link.get("href")) is not None
}


def make_request(session: httpx.Client, url: str) -> str:
response = session.get(url)
if response.status_code != 200:
error_and_exit(response)
return response.text


def fetch(
url: str,
username: str | None,
password: str | None,
auth_type: AuthType,
output_file: Path,
) -> None:
# Create a session to fetch the documents
client = httpx.Client()

client.headers.update(
{
"Accept": "application/atom+xml;profile=opds-catalog;kind=acquisition,application/atom+xml;q=0.9,application/xml;q=0.8,*/*;q=0.1",
"User-Agent": "Palace",
}
)
client.timeout = httpx.Timeout(30.0)

if username and password:
if auth_type == AuthType.BASIC:
client.auth = httpx.BasicAuth(username, password)
elif auth_type == AuthType.OAUTH:
client.auth = OAuthAuth(
username, password, feed_url=url, parse_links=parse_links
)
elif auth_type != AuthType.NONE:
print("Username and password are required for authentication")
sys.exit(-1)

next_url: str | None = url
with output_file.open("w") as file:
with Progress(
SpinnerColumn(), *Progress.get_default_columns(), MofNCompleteColumn()
) as progress:
download_task = progress.add_task(f"Downloading Feed", total=None)
while next_url is not None:
response = make_request(client, next_url)
file.write(response)
links = parse_links(response)
next_url = links.get("next") and links["next"].href
progress.update(download_task, advance=1)

0 comments on commit cbc5c66

Please sign in to comment.