Skip to content
This repository has been archived by the owner on Aug 19, 2024. It is now read-only.

Commit

Permalink
Merge pull request #67 from qstokkink/upd_restep_loading
Browse files Browse the repository at this point in the history
Register REST Endpoints before initializing them
  • Loading branch information
qstokkink authored Jun 7, 2024
2 parents 195d581 + 3b826c1 commit 197a251
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 127 deletions.
58 changes: 45 additions & 13 deletions src/tribler/core/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,23 @@
from ipv8.peer import Peer
from ipv8.types import IPv8

from tribler.core.restapi.rest_endpoint import RESTEndpoint
from tribler.core.session import Session


class BaseLauncher(CommunityLauncher):
class CommunityLauncherWEndpoints(CommunityLauncher):
"""
A CommunityLauncher that can supply endpoints.
"""

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Get a list of endpoints that should be loaded.
"""
return []


class BaseLauncher(CommunityLauncherWEndpoints):
"""
The base class for all Tribler Community launchers.
"""
Expand Down Expand Up @@ -69,7 +82,7 @@ def __init__(self, settings: SettingsClass) -> None:
self.settings = settings


class ComponentLauncher(CommunityLauncher):
class ComponentLauncher(CommunityLauncherWEndpoints):
"""
A launcher for components that simply need a TaskManager, not a full Community.
"""
Expand Down Expand Up @@ -103,10 +116,15 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.content_discovery.community import ContentDiscoveryCommunity
session.rest_manager.get_endpoint("/search").content_discovery_community = community

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Add the search endpoint.
"""
from tribler.core.content_discovery.restapi.search_endpoint import SearchEndpoint

session.rest_manager.add_endpoint(SearchEndpoint(cast(ContentDiscoveryCommunity, community)))
return [*super().get_endpoints(), SearchEndpoint()]


@precondition('session.config.get("database/enabled")')
Expand Down Expand Up @@ -142,14 +160,21 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.database.restapi.database_endpoint import DatabaseEndpoint

session.rest_manager.get_endpoint("/downloads").mds = session.mds
session.rest_manager.get_endpoint("/statistics").mds = session.mds
session.rest_manager.add_endpoint(DatabaseEndpoint(session.download_manager,
torrent_checker=None,
metadata_store=session.mds,
tribler_db=session.db))

db_endpoint = session.rest_manager.get_endpoint("/metadata")
db_endpoint.download_manager = session.download_manager
db_endpoint.mds = session.mds
db_endpoint.tribler_db = session.db

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Add the database endpoint.
"""
from tribler.core.database.restapi.database_endpoint import DatabaseEndpoint

return [*super().get_endpoints(), DatabaseEndpoint()]


@set_in_session("knowledge_community")
Expand All @@ -158,7 +183,7 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
@precondition('session.config.get("knowledge_community/enabled")')
@overlay("tribler.core.knowledge.community", "KnowledgeCommunity")
@kwargs(db="session.db", key='session.ipv8.keys["secondary"].key')
class KnowledgeComponent(CommunityLauncher):
class KnowledgeComponent(CommunityLauncherWEndpoints):
"""
Launch instructions for the knowledge community.
"""
Expand All @@ -167,10 +192,17 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.knowledge.community import KnowledgeCommunity
endpoint = session.rest_manager.get_endpoint("/knowledge")
endpoint.db = session.db
endpoint.community = community

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Add the knowledge endpoint.
"""
from tribler.core.knowledge.restapi.knowledge_endpoint import KnowledgeEndpoint

session.rest_manager.add_endpoint(KnowledgeEndpoint(session.db, cast(KnowledgeCommunity, community)))
return [*super().get_endpoints(), KnowledgeEndpoint()]


@after("DatabaseComponent")
Expand Down
20 changes: 11 additions & 9 deletions src/tribler/core/content_discovery/restapi/search_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
from aiohttp_apispec import docs, querystring_schema
from ipv8.REST.schema import schema
from marshmallow.fields import Integer, List, String
from typing_extensions import TypeAlias

from tribler.core.database.queries import to_fts_query
from tribler.core.database.restapi.database_endpoint import DatabaseEndpoint
from tribler.core.database.restapi.schema import MetadataParameters
from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse

if TYPE_CHECKING:
from aiohttp.abc import Request

from tribler.core.content_discovery.community import ContentDiscoveryCommunity
from tribler.core.restapi.rest_manager import TriblerRequest

RequestType: TypeAlias = TriblerRequest[tuple[ContentDiscoveryCommunity]]


class RemoteQueryParameters(MetadataParameters):
Expand All @@ -36,15 +38,15 @@ class SearchEndpoint(RESTEndpoint):

path = "/search"

def __init__(self,
content_discovery_community: ContentDiscoveryCommunity,
middlewares: tuple = (),
client_max_size: int = MAX_REQUEST_SIZE) -> None:
def __init__(self, middlewares: tuple = (), client_max_size: int = MAX_REQUEST_SIZE) -> None:
"""
Create a new search endpoint.
"""
super().__init__(middlewares, client_max_size)
self.content_discovery_community = content_discovery_community

self.content_discovery_community = None
self.required_components = ("content_discovery_community", )

self.app.add_routes([web.put("/remote", self.remote_search)])

@docs(
Expand All @@ -64,7 +66,7 @@ def __init__(self,
},
)
@querystring_schema(RemoteQueryParameters)
async def remote_search(self, request: Request) -> RESTResponse:
async def remote_search(self, request: RequestType) -> RESTResponse:
"""
Perform a search for a given query.
"""
Expand All @@ -85,7 +87,7 @@ async def remote_search(self, request: Request) -> RESTResponse:
self._logger.info("Parameters: %s", str(sanitized))
self._logger.info("FTS: %s", fts)

request_uuid, peers_list = self.content_discovery_community.send_search_request(**sanitized)
request_uuid, peers_list = request.context[0].send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]

return RESTResponse({"request_uuid": str(request_uuid), "peers": peers_mid_list})
51 changes: 26 additions & 25 deletions src/tribler/core/database/restapi/database_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ipv8.REST.schema import schema
from marshmallow.fields import Boolean, Integer, String
from pony.orm import db_session
from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from tribler.core.database.layers.knowledge import ResourceType
from tribler.core.database.queries import to_fts_query
Expand All @@ -26,14 +26,16 @@
)

if typing.TYPE_CHECKING:
from aiohttp.abc import Request
from multidict import MultiDictProxy, MultiMapping

from tribler.core.database.store import MetadataStore
from tribler.core.database.tribler_database import TriblerDatabase
from tribler.core.libtorrent.download_manager.download_manager import DownloadManager
from tribler.core.restapi.rest_manager import TriblerRequest
from tribler.core.torrent_checker.torrent_checker import TorrentChecker

RequestType: TypeAlias = TriblerRequest[tuple[MetadataStore]]

TORRENT_CHECK_TIMEOUT = 20

# This dict is used to translate JSON fields into the columns used in Pony for _sorting_.
Expand Down Expand Up @@ -73,21 +75,19 @@ class DatabaseEndpoint(RESTEndpoint):

path = "/metadata"

def __init__(self, # noqa: PLR0913
download_manager: DownloadManager,
torrent_checker: TorrentChecker | None,
metadata_store: MetadataStore,
tribler_db: TriblerDatabase | None = None,
middlewares: tuple = (),
client_max_size: int = MAX_REQUEST_SIZE) -> None:
def __init__(self, middlewares: tuple = (), client_max_size: int = MAX_REQUEST_SIZE) -> None:
"""
Create a new database endpoint.
"""
super().__init__(middlewares, client_max_size)
self.download_manager = download_manager
self.torrent_checker = torrent_checker
self.mds = metadata_store
self.tribler_db: TriblerDatabase | None = tribler_db

self.mds: MetadataStore | None = None
self.required_components = ("mds", )

self.download_manager: DownloadManager | None = None
self.torrent_checker: TorrentChecker | None = None
self.tribler_db: TriblerDatabase | None = None

self.app.add_routes(
[
web.get("/torrents/{infohash}/health", self.get_torrent_health),
Expand Down Expand Up @@ -173,7 +173,7 @@ def add_statements_to_metadata_list(self, contents_list: list[dict]) -> None:
}
},
)
async def get_torrent_health(self, request: Request) -> RESTResponse:
async def get_torrent_health(self, request: RequestType) -> RESTResponse:
"""
Fetch the swarm health of a specific torrent.
"""
Expand All @@ -194,11 +194,12 @@ def add_download_progress_to_metadata_list(self, contents_list: list[dict]) -> N
"""
Retrieve the download status from libtorrent and attach it to the torrent descriptions in the content list.
"""
for torrent in contents_list:
if torrent["type"] == REGULAR_TORRENT:
dl = self.download_manager.get_download(unhexlify(torrent["infohash"]))
if dl is not None and dl.tdef.infohash not in self.download_manager.metainfo_requests:
torrent["progress"] = dl.get_state().get_progress()
if self.download_manager is not None:
for torrent in contents_list:
if torrent["type"] == REGULAR_TORRENT:
dl = self.download_manager.get_download(unhexlify(torrent["infohash"]))
if dl is not None and dl.tdef.infohash not in self.download_manager.metainfo_requests:
torrent["progress"] = dl.get_state().get_progress()

@docs(
tags=["Metadata"],
Expand All @@ -215,7 +216,7 @@ def add_download_progress_to_metadata_list(self, contents_list: list[dict]) -> N
}
},
)
async def get_popular_torrents(self, request: Request) -> RESTResponse:
async def get_popular_torrents(self, request: RequestType) -> RESTResponse:
"""
Get the list of most popular torrents.
"""
Expand All @@ -226,7 +227,7 @@ async def get_popular_torrents(self, request: Request) -> RESTResponse:
sanitized["txt_filter"] = t_filter

with db_session:
contents_list = [entry.to_simple_dict() for entry in self.mds.get_entries(**sanitized)]
contents_list = [entry.to_simple_dict() for entry in request.context[0].get_entries(**sanitized)]

self.add_download_progress_to_metadata_list(contents_list)
self.add_statements_to_metadata_list(contents_list)
Expand Down Expand Up @@ -257,7 +258,7 @@ async def get_popular_torrents(self, request: Request) -> RESTResponse:
},
)
@querystring_schema(SearchMetadataParameters)
async def local_search(self, request: Request) -> RESTResponse: # noqa: C901
async def local_search(self, request: RequestType) -> RESTResponse: # noqa: C901
"""
Perform a search for a given query.
"""
Expand All @@ -281,7 +282,7 @@ async def local_search(self, request: Request) -> RESTResponse: # noqa: C901
sanitized["txt_filter"] = fts
self._logger.info("FTS: %s", fts)

mds: MetadataStore = self.mds
mds: MetadataStore = request.context[0]

def search_db() -> tuple[list[dict], int, int]:
with db_session:
Expand Down Expand Up @@ -344,7 +345,7 @@ def search_db() -> tuple[list[dict], int, int]:
}
},
)
async def completions(self, request: Request) -> RESTResponse:
async def completions(self, request: RequestType) -> RESTResponse:
"""
Return auto-completion suggestions for a given query.
"""
Expand All @@ -353,5 +354,5 @@ async def completions(self, request: Request) -> RESTResponse:
return RESTResponse({"error": "query parameter missing"}, status=HTTP_BAD_REQUEST)

keywords = args["q"].strip().lower()
results = self.mds.get_auto_complete_terms(keywords, max_terms=5)
results = request.context[0].get_auto_complete_terms(keywords, max_terms=5)
return RESTResponse({"completions": results})
Loading

0 comments on commit 197a251

Please sign in to comment.