Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added versioning manager #8092

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ lint.ignore = [
"ARG002",
"ARG005",
"ASYNC109",
"ASYNC110",
"BLE001",
"COM812",
"COM819",
Expand Down
10 changes: 6 additions & 4 deletions src/run_tribler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from pathlib import Path

import pystray
import tribler
from aiohttp import ClientSession
from PIL import Image

import tribler
from tribler.core.session import Session
from tribler.tribler_config import TriblerConfigManager
from tribler.tribler_config import VERSION_SUBDIR, TriblerConfigManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def get_root_state_directory(requested_path: os.PathLike | None) -> Path:
Get the default application state directory.
"""
root_state_dir = (Path(requested_path) if os.path.isabs(requested_path)
else (Path(os.environ.get("APPDATA", "~")) / ".TriblerExperimental").expanduser().absolute())
else (Path(os.environ.get("APPDATA", "~")) / ".Tribler").expanduser().absolute())
root_state_dir.mkdir(parents=True, exist_ok=True)
return root_state_dir

Expand All @@ -73,8 +74,9 @@ async def main() -> None:
logger.info("Run Tribler: %s", parsed_args)

root_state_dir = get_root_state_directory(os.environ.get('TSTATEDIR', 'state_directory'))
(root_state_dir / VERSION_SUBDIR).mkdir(exist_ok=True, parents=True)
logger.info("Root state dir: %s", root_state_dir)
config = TriblerConfigManager(root_state_dir / "configuration.json")
config = TriblerConfigManager(root_state_dir / VERSION_SUBDIR / "configuration.json")
config.set("state_dir", str(root_state_dir))

if "CORE_API_PORT" in os.environ:
Expand Down
36 changes: 31 additions & 5 deletions src/tribler/core/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def prepare(self, ipv8: IPv8, session: Session) -> None:
from tribler.core.database.tribler_database import TriblerDatabase
from tribler.core.notifier import Notification

db_path = str(Path(session.config.get("state_dir")) / "sqlite" / "tribler.db")
mds_path = str(Path(session.config.get("state_dir")) / "sqlite" / "metadata.db")
db_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "tribler.db")
mds_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "metadata.db")
if session.config.get("memory_db"):
db_path = ":memory:"
mds_path = ":memory:"
Expand Down Expand Up @@ -221,7 +221,8 @@ def get_kwargs(self, session: Session) -> dict:
from tribler.core.rendezvous.database import RendezvousDatabase

out = super().get_kwargs(session)
out["database"] = RendezvousDatabase(db_path=Path(session.config.get("state_dir")) / "sqlite" / "rendezvous.db")
out["database"] = (RendezvousDatabase(db_path=Path(session.config.get_version_state_dir()) / "sqlite"
/ "rendezvous.db"))

return out

Expand Down Expand Up @@ -249,7 +250,8 @@ def prepare(self, overlay_provider: IPv8, session: Session) -> None:
from tribler.core.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.torrent_checker.tracker_manager import TrackerManager

tracker_manager = TrackerManager(state_dir=session.config.get("state_dir"), metadata_store=session.mds)
tracker_manager = TrackerManager(state_dir=Path(session.config.get_version_state_dir()),
metadata_store=session.mds)
torrent_checker = TorrentChecker(config=session.config,
download_manager=session.download_manager,
notifier=session.notifier,
Expand Down Expand Up @@ -298,7 +300,7 @@ def get_kwargs(self, session: Session) -> dict:
from ipv8.dht.provider import DHTCommunityProvider

out = super().get_kwargs(session)
out["exitnode_cache"] = Path(session.config.get("state_dir")) / "exitnode_cache.dat"
out["exitnode_cache"] = Path(session.config.get_version_state_dir()) / "exitnode_cache.dat"
out["notifier"] = session.notifier
out["download_manager"] = session.download_manager
out["socks_servers"] = session.socks_servers
Expand Down Expand Up @@ -336,3 +338,27 @@ def get_kwargs(self, session: Session) -> dict:
max_query_history = session.config.get("user_activity/max_query_history")
out["manager"] = UserActivityManager(TaskManager(), session, max_query_history)
return out

@precondition('session.config.get("versioning/enabled")')
class VersioningComponent(ComponentLauncher):
"""
Launch instructions for the versioning of Tribler.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.versioning.manager import VersioningManager

session.rest_manager.get_endpoint("/api/versioning").versioning_manager = VersioningManager(
community, session.config
)

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Add the database endpoint.
"""
from tribler.core.versioning.restapi.versioning_endpoint import VersioningEndpoint

return [*super().get_endpoints(), VersioningEndpoint()]
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_spec_file_name(settings: TriblerConfigManager) -> str:
"""
Get the file name of the download spec.
"""
return str(Path(settings.get("state_dir")) / SPEC_FILENAME)
return str(Path(settings.get_version_state_dir()) / SPEC_FILENAME)

@staticmethod
def from_defaults(settings: TriblerConfigManager) -> DownloadConfig:
Expand All @@ -127,6 +127,7 @@ def from_defaults(settings: TriblerConfigManager) -> DownloadConfig:
spec_file_name = DownloadConfig.get_spec_file_name(settings)
defaults = ConfigObj(StringIO(SPEC_CONTENT))
defaults["filename"] = spec_file_name
Path(spec_file_name).parent.mkdir(parents=True, exist_ok=True) # Required for the next write
with open(spec_file_name, "wb") as spec_file:
defaults.write(spec_file)
defaults = ConfigObj(StringIO(), configspec=spec_file_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier,
super().__init__()
self.config = config

self.state_dir = Path(config.get("state_dir"))
self.state_dir = Path(config.get_version_state_dir())
self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session
self.ltsessions: dict[int, lt.session] = {}
self.dht_health_manager: DHTHealthManager | None = None
Expand Down Expand Up @@ -176,7 +176,7 @@ def initialize(self) -> None:
Initialize the directory structure, launch the periodic tasks and start libtorrent background processes.
"""
# Create the checkpoints directory
self.checkpoint_directory.mkdir(exist_ok=True)
self.checkpoint_directory.mkdir(exist_ok=True, parents=True)

# Start upnp
if self.config.get("libtorrent/upnp"):
Expand Down Expand Up @@ -245,7 +245,7 @@ async def shutdown(self, timeout: int = 30) -> None:
if self.has_session():
logger.info("Saving state...")
self.notify_shutdown_state("Writing session state to disk.")
with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC101
with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC230
ltstate_file.write(lt.bencode(self.get_session().save_state()))

if self.has_session() and self.config.get("libtorrent/upnp"):
Expand Down
7 changes: 5 additions & 2 deletions src/tribler/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TorrentCheckerComponent,
TunnelComponent,
UserActivityComponent,
VersioningComponent,
)
from tribler.core.libtorrent.download_manager.download_manager import DownloadManager
from tribler.core.libtorrent.restapi.create_torrent_endpoint import CreateTorrentEndpoint
Expand Down Expand Up @@ -121,7 +122,8 @@ def register_launchers(self) -> None:
Register all IPv8 launchers that allow communities to be loaded.
"""
for launcher_class in [ContentDiscoveryComponent, DatabaseComponent, DHTDiscoveryComponent, KnowledgeComponent,
RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent]:
RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent,
VersioningComponent]:
instance = launcher_class()
for rest_ep in instance.get_endpoints():
self.rest_manager.add_endpoint(rest_ep)
Expand Down Expand Up @@ -168,7 +170,8 @@ async def start(self) -> None:
self.rest_manager.get_endpoint("/api/ipv8").initialize(self.ipv8)
self.rest_manager.get_endpoint("/api/statistics").ipv8 = self.ipv8
if self.config.get("statistics"):
self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None, True)
self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None,
True)

async def find_api_server(self) -> str | None:
"""
Expand Down
Empty file.
114 changes: 114 additions & 0 deletions src/tribler/core/versioning/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import logging
import os
import platform
import shutil
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import TYPE_CHECKING

from aiohttp import ClientSession
from packaging.version import Version

from tribler.tribler_config import TriblerConfigManager
from tribler.upgrade_script import FROM, TO, upgrade

if TYPE_CHECKING:
from ipv8.taskmanager import TaskManager

logger = logging.getLogger(__name__)


class VersioningManager:
"""
Version related logic.
"""

def __init__(self, task_manager: TaskManager, config: TriblerConfigManager | None) -> None:
"""
Create a new versioning manager.
"""
super().__init__()
self.task_manager = task_manager
self.config = config or TriblerConfigManager()

def get_current_version(self) -> str | None:
"""
Get the current release version, or None when running from archive or GIT.
"""
try:
return version("tribler")
except PackageNotFoundError:
return None

def get_versions(self) -> list[str]:
"""
Get all versions in our state directory.
"""
return [p for p in os.listdir(self.config.get("state_dir"))
if os.path.isdir(os.path.join(self.config.get("state_dir"), p))]

async def check_version(self) -> str | None:
"""
Check the tribler.org + GitHub websites for a new version.
"""
current_version = self.get_current_version()
if current_version is None:
return None

headers = {
"User-Agent": (f"Tribler/{current_version} "
f"(machine={platform.machine()}; os={platform.system()} {platform.release()}; "
f"python={platform.python_version()}; executable={platform.architecture()[0]})")
}
urls = [
f"https://release.tribler.org/releases/latest?current={current_version}",
"https://api.github.com/repos/tribler/tribler/releases/latest"
]

for url in urls:
try:
async with ClientSession(raise_for_status=True) as session:
response = await session.get(url, headers=headers, timeout=5.0)
response_dict = await response.json(content_type=None)
response_version = response_dict["name"]
if response_version.startswith("v"):
response_version = response_version[1:]
except Exception as e:
logger.info(e)
continue # Case 1: this failed, but we may still have another URL to check. Continue.
if Version(response_version) > Version(current_version):
return response_version # Case 2: we found a newer version. Stop.
break # Case 3: we got a response, but we are already at a newer or equal version. Stop.
return None # Either Case 3 or repeated Case 1: no URLs responded. No new version available.

def can_upgrade(self) -> str | bool:
"""
Check if we have old database/download files to port to our current version.
Returns the version that can be upgraded from.
"""
if os.path.isfile(os.path.join(self.config.get_version_state_dir(), ".upgraded")):
return False # We have the upgraded marker: nothing to do.

if FROM not in self.get_versions():
return False # We can't upgrade from this version.

return FROM if (self.get_current_version() in [None, TO]) else False # Always allow upgrades to git (None).

def perform_upgrade(self) -> None:
"""
Upgrade old database/download files to our current version.
"""
src_dir = Path(self.config.get("state_dir")) / FROM
dst_dir = Path(self.config.get_version_state_dir())
self.task_manager.register_executor_task("Upgrade", upgrade, self.config,
str(src_dir.expanduser().absolute()),
str(dst_dir.expanduser().absolute()))

def remove_version(self, version: str) -> None:
"""
Remove the files for a version.
"""
shutil.rmtree(os.path.join(self.config.get("state_dir"), version), ignore_errors=True)
Empty file.
Loading