diff --git a/.ruff.toml b/.ruff.toml index f695d22b1e..b711ca71f0 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -9,6 +9,7 @@ lint.ignore = [ "ARG002", "ARG005", "ASYNC109", + "ASYNC110", "BLE001", "COM812", "COM819", diff --git a/src/run_tribler.py b/src/run_tribler.py index d16e204e56..689014e9c0 100644 --- a/src/run_tribler.py +++ b/src/run_tribler.py @@ -12,11 +12,12 @@ from pathlib import Path import pystray -import tribler from aiohttp import ClientSession from PIL import Image + +import tribler from tribler.core.session import Session -from tribler.tribler_config import TriblerConfigManager +from tribler.tribler_config import VERSION_SUBDIR, TriblerConfigManager logger = logging.getLogger(__name__) @@ -46,7 +47,7 @@ def get_root_state_directory(requested_path: os.PathLike | None) -> Path: Get the default application state directory. """ root_state_dir = (Path(requested_path) if os.path.isabs(requested_path) - else (Path(os.environ.get("APPDATA", "~")) / ".TriblerExperimental").expanduser().absolute()) + else (Path(os.environ.get("APPDATA", "~")) / ".Tribler").expanduser().absolute()) root_state_dir.mkdir(parents=True, exist_ok=True) return root_state_dir @@ -73,8 +74,9 @@ async def main() -> None: logger.info("Run Tribler: %s", parsed_args) root_state_dir = get_root_state_directory(os.environ.get('TSTATEDIR', 'state_directory')) + (root_state_dir / VERSION_SUBDIR).mkdir(exist_ok=True, parents=True) logger.info("Root state dir: %s", root_state_dir) - config = TriblerConfigManager(root_state_dir / "configuration.json") + config = TriblerConfigManager(root_state_dir / VERSION_SUBDIR / "configuration.json") config.set("state_dir", str(root_state_dir)) if "CORE_API_PORT" in os.environ: diff --git a/src/tribler/core/components.py b/src/tribler/core/components.py index e5e5f13caf..b94c74451a 100644 --- a/src/tribler/core/components.py +++ b/src/tribler/core/components.py @@ -142,8 +142,8 @@ def prepare(self, ipv8: IPv8, session: Session) -> None: from tribler.core.database.tribler_database import TriblerDatabase from tribler.core.notifier import Notification - db_path = str(Path(session.config.get("state_dir")) / "sqlite" / "tribler.db") - mds_path = str(Path(session.config.get("state_dir")) / "sqlite" / "metadata.db") + db_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "tribler.db") + mds_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "metadata.db") if session.config.get("memory_db"): db_path = ":memory:" mds_path = ":memory:" @@ -221,7 +221,8 @@ def get_kwargs(self, session: Session) -> dict: from tribler.core.rendezvous.database import RendezvousDatabase out = super().get_kwargs(session) - out["database"] = RendezvousDatabase(db_path=Path(session.config.get("state_dir")) / "sqlite" / "rendezvous.db") + out["database"] = (RendezvousDatabase(db_path=Path(session.config.get_version_state_dir()) / "sqlite" + / "rendezvous.db")) return out @@ -249,7 +250,8 @@ def prepare(self, overlay_provider: IPv8, session: Session) -> None: from tribler.core.torrent_checker.torrent_checker import TorrentChecker from tribler.core.torrent_checker.tracker_manager import TrackerManager - tracker_manager = TrackerManager(state_dir=session.config.get("state_dir"), metadata_store=session.mds) + tracker_manager = TrackerManager(state_dir=Path(session.config.get_version_state_dir()), + metadata_store=session.mds) torrent_checker = TorrentChecker(config=session.config, download_manager=session.download_manager, notifier=session.notifier, @@ -298,7 +300,7 @@ def get_kwargs(self, session: Session) -> dict: from ipv8.dht.provider import DHTCommunityProvider out = super().get_kwargs(session) - out["exitnode_cache"] = Path(session.config.get("state_dir")) / "exitnode_cache.dat" + out["exitnode_cache"] = Path(session.config.get_version_state_dir()) / "exitnode_cache.dat" out["notifier"] = session.notifier out["download_manager"] = session.download_manager out["socks_servers"] = session.socks_servers @@ -336,3 +338,27 @@ def get_kwargs(self, session: Session) -> dict: max_query_history = session.config.get("user_activity/max_query_history") out["manager"] = UserActivityManager(TaskManager(), session, max_query_history) return out + +@precondition('session.config.get("versioning/enabled")') +class VersioningComponent(ComponentLauncher): + """ + Launch instructions for the versioning of Tribler. + """ + + def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None: + """ + When we are done launching, register our REST API. + """ + from tribler.core.versioning.manager import VersioningManager + + session.rest_manager.get_endpoint("/api/versioning").versioning_manager = VersioningManager( + community, session.config + ) + + def get_endpoints(self) -> list[RESTEndpoint]: + """ + Add the database endpoint. + """ + from tribler.core.versioning.restapi.versioning_endpoint import VersioningEndpoint + + return [*super().get_endpoints(), VersioningEndpoint()] diff --git a/src/tribler/core/libtorrent/download_manager/download_config.py b/src/tribler/core/libtorrent/download_manager/download_config.py index aff2e089f5..184b8f1167 100644 --- a/src/tribler/core/libtorrent/download_manager/download_config.py +++ b/src/tribler/core/libtorrent/download_manager/download_config.py @@ -117,7 +117,7 @@ def get_spec_file_name(settings: TriblerConfigManager) -> str: """ Get the file name of the download spec. """ - return str(Path(settings.get("state_dir")) / SPEC_FILENAME) + return str(Path(settings.get_version_state_dir()) / SPEC_FILENAME) @staticmethod def from_defaults(settings: TriblerConfigManager) -> DownloadConfig: @@ -127,6 +127,7 @@ def from_defaults(settings: TriblerConfigManager) -> DownloadConfig: spec_file_name = DownloadConfig.get_spec_file_name(settings) defaults = ConfigObj(StringIO(SPEC_CONTENT)) defaults["filename"] = spec_file_name + Path(spec_file_name).parent.mkdir(parents=True, exist_ok=True) # Required for the next write with open(spec_file_name, "wb") as spec_file: defaults.write(spec_file) defaults = ConfigObj(StringIO(), configspec=spec_file_name) diff --git a/src/tribler/core/libtorrent/download_manager/download_manager.py b/src/tribler/core/libtorrent/download_manager/download_manager.py index 004101fedd..2811241728 100644 --- a/src/tribler/core/libtorrent/download_manager/download_manager.py +++ b/src/tribler/core/libtorrent/download_manager/download_manager.py @@ -89,7 +89,7 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier, super().__init__() self.config = config - self.state_dir = Path(config.get("state_dir")) + self.state_dir = Path(config.get_version_state_dir()) self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session self.ltsessions: dict[int, lt.session] = {} self.dht_health_manager: DHTHealthManager | None = None @@ -176,7 +176,7 @@ def initialize(self) -> None: Initialize the directory structure, launch the periodic tasks and start libtorrent background processes. """ # Create the checkpoints directory - self.checkpoint_directory.mkdir(exist_ok=True) + self.checkpoint_directory.mkdir(exist_ok=True, parents=True) # Start upnp if self.config.get("libtorrent/upnp"): @@ -245,7 +245,7 @@ async def shutdown(self, timeout: int = 30) -> None: if self.has_session(): logger.info("Saving state...") self.notify_shutdown_state("Writing session state to disk.") - with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC101 + with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC230 ltstate_file.write(lt.bencode(self.get_session().save_state())) if self.has_session() and self.config.get("libtorrent/upnp"): diff --git a/src/tribler/core/session.py b/src/tribler/core/session.py index c0cbd8d971..5055e6ed19 100644 --- a/src/tribler/core/session.py +++ b/src/tribler/core/session.py @@ -19,6 +19,7 @@ TorrentCheckerComponent, TunnelComponent, UserActivityComponent, + VersioningComponent, ) from tribler.core.libtorrent.download_manager.download_manager import DownloadManager from tribler.core.libtorrent.restapi.create_torrent_endpoint import CreateTorrentEndpoint @@ -121,7 +122,8 @@ def register_launchers(self) -> None: Register all IPv8 launchers that allow communities to be loaded. """ for launcher_class in [ContentDiscoveryComponent, DatabaseComponent, DHTDiscoveryComponent, KnowledgeComponent, - RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent]: + RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent, + VersioningComponent]: instance = launcher_class() for rest_ep in instance.get_endpoints(): self.rest_manager.add_endpoint(rest_ep) @@ -168,7 +170,8 @@ async def start(self) -> None: self.rest_manager.get_endpoint("/api/ipv8").initialize(self.ipv8) self.rest_manager.get_endpoint("/api/statistics").ipv8 = self.ipv8 if self.config.get("statistics"): - self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None, True) + self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None, + True) async def find_api_server(self) -> str | None: """ diff --git a/src/tribler/core/versioning/__init__.py b/src/tribler/core/versioning/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tribler/core/versioning/manager.py b/src/tribler/core/versioning/manager.py new file mode 100644 index 0000000000..e0b5609505 --- /dev/null +++ b/src/tribler/core/versioning/manager.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging +import os +import platform +import shutil +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import TYPE_CHECKING + +from aiohttp import ClientSession +from packaging.version import Version + +from tribler.tribler_config import TriblerConfigManager +from tribler.upgrade_script import FROM, TO, upgrade + +if TYPE_CHECKING: + from ipv8.taskmanager import TaskManager + +logger = logging.getLogger(__name__) + + +class VersioningManager: + """ + Version related logic. + """ + + def __init__(self, task_manager: TaskManager, config: TriblerConfigManager | None) -> None: + """ + Create a new versioning manager. + """ + super().__init__() + self.task_manager = task_manager + self.config = config or TriblerConfigManager() + + def get_current_version(self) -> str | None: + """ + Get the current release version, or None when running from archive or GIT. + """ + try: + return version("tribler") + except PackageNotFoundError: + return None + + def get_versions(self) -> list[str]: + """ + Get all versions in our state directory. + """ + return [p for p in os.listdir(self.config.get("state_dir")) + if os.path.isdir(os.path.join(self.config.get("state_dir"), p))] + + async def check_version(self) -> str | None: + """ + Check the tribler.org + GitHub websites for a new version. + """ + current_version = self.get_current_version() + if current_version is None: + return None + + headers = { + "User-Agent": (f"Tribler/{current_version} " + f"(machine={platform.machine()}; os={platform.system()} {platform.release()}; " + f"python={platform.python_version()}; executable={platform.architecture()[0]})") + } + urls = [ + f"https://release.tribler.org/releases/latest?current={current_version}", + "https://api.github.com/repos/tribler/tribler/releases/latest" + ] + + for url in urls: + try: + async with ClientSession(raise_for_status=True) as session: + response = await session.get(url, headers=headers, timeout=5.0) + response_dict = await response.json(content_type=None) + response_version = response_dict["name"] + if response_version.startswith("v"): + response_version = response_version[1:] + except Exception as e: + logger.info(e) + continue # Case 1: this failed, but we may still have another URL to check. Continue. + if Version(response_version) > Version(current_version): + return response_version # Case 2: we found a newer version. Stop. + break # Case 3: we got a response, but we are already at a newer or equal version. Stop. + return None # Either Case 3 or repeated Case 1: no URLs responded. No new version available. + + def can_upgrade(self) -> str | bool: + """ + Check if we have old database/download files to port to our current version. + + Returns the version that can be upgraded from. + """ + if os.path.isfile(os.path.join(self.config.get_version_state_dir(), ".upgraded")): + return False # We have the upgraded marker: nothing to do. + + if FROM not in self.get_versions(): + return False # We can't upgrade from this version. + + return FROM if (self.get_current_version() in [None, TO]) else False # Always allow upgrades to git (None). + + def perform_upgrade(self) -> None: + """ + Upgrade old database/download files to our current version. + """ + src_dir = Path(self.config.get("state_dir")) / FROM + dst_dir = Path(self.config.get_version_state_dir()) + self.task_manager.register_executor_task("Upgrade", upgrade, self.config, + str(src_dir.expanduser().absolute()), + str(dst_dir.expanduser().absolute())) + + def remove_version(self, version: str) -> None: + """ + Remove the files for a version. + """ + shutil.rmtree(os.path.join(self.config.get("state_dir"), version), ignore_errors=True) diff --git a/src/tribler/core/versioning/restapi/__init__.py b/src/tribler/core/versioning/restapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tribler/core/versioning/restapi/versioning_endpoint.py b/src/tribler/core/versioning/restapi/versioning_endpoint.py new file mode 100644 index 0000000000..dd02b396d8 --- /dev/null +++ b/src/tribler/core/versioning/restapi/versioning_endpoint.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aiohttp import web +from aiohttp_apispec import docs +from ipv8.REST.schema import schema +from marshmallow.fields import Bool, List, String + +from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse +from tribler.tribler_config import VERSION_SUBDIR + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from tribler.core.restapi.rest_manager import TriblerRequest + from tribler.core.versioning.manager import VersioningManager + RequestType: TypeAlias = TriblerRequest[tuple[VersioningManager]] + + +class VersioningEndpoint(RESTEndpoint): + """ + An endpoint for version determination and upgrading from the previous version. + """ + + path = "/api/versioning" + + def __init__(self, middlewares: tuple = (), client_max_size: int = MAX_REQUEST_SIZE) -> None: + """ + Create a new endpoint to create torrents. + """ + super().__init__(middlewares, client_max_size) + + self.versioning_manager: VersioningManager | None = None + self.required_components = ("versioning_manager",) + + self.app.add_routes([ + web.get("/versions", self.get_versions), + web.get("/versions/current", self.get_current_version), + web.get("/versions/check", self.check_version), + web.delete("/versions/{version}", self.remove_version), + web.post("/upgrade", self.perform_upgrade), + web.get("/upgrade/available", self.can_upgrade), + web.get("/upgrade/working", self.is_upgrading) + ]) + + @docs( + tags=["Versioning"], + summary="Get the current release version or whether we are running from source.", + responses={ + 200: { + "schema": schema(CurrentVersionResponse={"version": String}) + } + } + ) + async def get_current_version(self, request: RequestType) -> RESTResponse: + """ + Get the current release version, or None when running from archive or GIT. + """ + return RESTResponse({"version": request.context[0].get_current_version() or "git"}) + + @docs( + tags=["Versioning"], + summary="Get all versions in our state directory.", + responses={ + 200: { + "schema": schema(GetVersionsResponse={"versions": List(String), "current": String}) + } + } + ) + async def get_versions(self, request: RequestType) -> RESTResponse: + """ + Get all versions in our state directory. + """ + return RESTResponse({"versions": request.context[0].get_versions(), "current": VERSION_SUBDIR}) + + @docs( + tags=["Versioning"], + summary="Check the tribler.org + GitHub websites for a new version.", + responses={ + 200: { + "schema": schema(CheckVersionResponse={"new_version": String, "has_version": Bool}) + } + } + ) + async def check_version(self, request: RequestType) -> RESTResponse: + """ + Check the tribler.org + GitHub websites for a new version. + """ + new_version = await request.context[0].check_version() + return RESTResponse({"new_version": new_version or "", "has_version": new_version is not None}) + + @docs( + tags=["Versioning"], + summary="Check if we have old database/download files to port to our current version.", + responses={ + 200: { + "schema": schema(CanUpgradeResponse={"can_upgrade": String}) + } + } + ) + async def can_upgrade(self, request: RequestType) -> RESTResponse: + """ + Check if we have old database/download files to port to our current version. + """ + return RESTResponse({"can_upgrade": request.context[0].can_upgrade()}) + + @docs( + tags=["Versioning"], + summary="Perform an upgrade.", + responses={ + 200: { + "schema": schema(PerformUpgradeResponse={"success": Bool}) + } + } + ) + async def perform_upgrade(self, request: RequestType) -> RESTResponse: + """ + Perform an upgrade. + """ + request.context[0].perform_upgrade() + return RESTResponse({"success": True}) + + @docs( + tags=["Versioning"], + summary="Check if the upgrade is still running.", + responses={ + 200: { + "schema": schema(IsUpgradingResponse={"running": Bool}) + } + } + ) + async def is_upgrading(self, request: RequestType) -> RESTResponse: + """ + Check if the upgrade is still running. + """ + return RESTResponse({"running": request.context[0].task_manager.get_task("Upgrade") is not None}) + + @docs( + tags=["Versioning"], + summary="Check if the upgrade is still running.", + parameters=[{ + "in": "path", + "name": "version", + "description": "The version to remove.", + "type": "string", + "required": "true" + }], + responses={ + 200: { + "schema": schema(RemoveVersionResponse={"success": Bool}) + }, + HTTP_BAD_REQUEST: { + "schema": schema(RemoveVersionNotFoundResponse={"error": String}) + } + } + ) + async def remove_version(self, request: RequestType) -> RESTResponse: + """ + Remove the files for a version. + """ + version = request.match_info["version"] + if not version: + return RESTResponse({"error": "No version given"}, status=HTTP_BAD_REQUEST) + request.context[0].remove_version(version) + return RESTResponse({"success": True}) diff --git a/src/tribler/test_unit/core/versioning/__init__.py b/src/tribler/test_unit/core/versioning/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tribler/test_unit/core/versioning/restapi/__init__.py b/src/tribler/test_unit/core/versioning/restapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tribler/test_unit/core/versioning/restapi/test_versioning_endpoint.py b/src/tribler/test_unit/core/versioning/restapi/test_versioning_endpoint.py new file mode 100644 index 0000000000..e22e1a402d --- /dev/null +++ b/src/tribler/test_unit/core/versioning/restapi/test_versioning_endpoint.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, Mock, call + +from ipv8.test.base import TestBase + +from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST +from tribler.core.versioning.restapi.versioning_endpoint import VersioningEndpoint +from tribler.test_unit.base_restapi import MockRequest, response_to_json +from tribler.tribler_config import VERSION_SUBDIR + +if TYPE_CHECKING: + from tribler.core.versioning.manager import VersioningManager + + +class GenericRequest(MockRequest): + """ + A MockRequest that mimics generic GET requests for the versioning endpoint. + """ + + def __init__(self, vman: VersioningManager, route: str) -> None: + """ + Create a new request. + """ + super().__init__({}, "GET", f"/versioning/{route}") + self.context = (vman,) + + +class PerformUpgradeRequest(MockRequest): + """ + A MockRequest that mimics PerformUpgrade requests for the versioning endpoint. + """ + + def __init__(self, vman: VersioningManager) -> None: + """ + Create a new request. + """ + super().__init__({}, "POST", "/versioning/upgrade") + self.context = (vman,) + + +class RemoveVersionRequest(MockRequest): + """ + A MockRequest that mimics RemoveVersion requests for the versioning endpoint. + """ + + def __init__(self, vman: VersioningManager, version: str) -> None: + """ + Create a new request. + """ + super().__init__({}, "DELETE", f"/versioning/versions/{version}") + self.context = (vman,) + self.version_str = version + + @property + def match_info(self) -> dict[str, str]: + """ + Return our version info. + """ + return {"version": self.version_str} + + +class TestVersioningEndpoint(TestBase): + """ + Tests for the VersioningEndpoint class. + """ + + def setUp(self) -> None: + """ + Create a new VersioningEndpoint. + """ + super().setUp() + self.vman = Mock() + self.rest_ep = VersioningEndpoint() + self.rest_ep.versioning_manager = self.vman + + async def test_current_version(self) -> None: + """ + Check if the current version is correctly returned. + """ + self.vman.get_current_version = Mock(return_value="1.2.3") + + response = await self.rest_ep.get_current_version(GenericRequest(self.vman, "versions/current")) + response_body_json = await response_to_json(response) + + self.assertEqual("1.2.3", response_body_json["version"]) + + async def test_versions(self) -> None: + """ + Check if the known versions are correctly returned. + """ + self.vman.get_versions = Mock(return_value=["1.2.3", "4.5.6"]) + + response = await self.rest_ep.get_versions(GenericRequest(self.vman, "versions")) + response_body_json = await response_to_json(response) + + self.assertEqual({"1.2.3", "4.5.6"}, set(response_body_json["versions"])) + self.assertEqual(VERSION_SUBDIR, response_body_json["current"]) + + async def test_check_version_available(self) -> None: + """ + Check if the checked version is correctly returned when a version is available. + """ + self.vman.check_version = AsyncMock(return_value="1.2.3") + + response = await self.rest_ep.check_version(GenericRequest(self.vman, "versions/check")) + response_body_json = await response_to_json(response) + + self.assertTrue(response_body_json["has_version"]) + self.assertEqual("1.2.3", response_body_json["new_version"]) + + async def test_check_version_unavailable(self) -> None: + """ + Check if the checked version is correctly returned when a version is not available. + """ + self.vman.check_version = AsyncMock(return_value=None) + + response = await self.rest_ep.check_version(GenericRequest(self.vman, "versions/check")) + response_body_json = await response_to_json(response) + + self.assertFalse(response_body_json["has_version"]) + self.assertEqual("", response_body_json["new_version"]) + + async def test_can_upgrade_no(self) -> None: + """ + Check if the inability to upgrade is correctly returned. + """ + self.vman.can_upgrade = Mock(return_value=False) + + response = await self.rest_ep.can_upgrade(GenericRequest(self.vman, "upgrade/available")) + response_body_json = await response_to_json(response) + + self.assertFalse(response_body_json["can_upgrade"]) + + async def test_can_upgrade(self) -> None: + """ + Check if the ability to upgrade is correctly returned. + """ + self.vman.can_upgrade = Mock(return_value="1.2.3") + + response = await self.rest_ep.can_upgrade(GenericRequest(self.vman, "upgrade/available")) + response_body_json = await response_to_json(response) + + self.assertEqual("1.2.3", response_body_json["can_upgrade"]) + + async def test_is_upgrading(self) -> None: + """ + Check if the upgrading status is correctly returned. + """ + self.vman.task_manager.get_task = Mock(return_value=True) + + response = await self.rest_ep.is_upgrading(GenericRequest(self.vman, "upgrade/working")) + response_body_json = await response_to_json(response) + + self.assertTrue(response_body_json["running"]) + + async def test_is_upgrading_no(self) -> None: + """ + Check if the non-upgrading status is correctly returned. + """ + self.vman.task_manager.get_task = Mock(return_value=None) + + response = await self.rest_ep.is_upgrading(GenericRequest(self.vman, "upgrade/working")) + response_body_json = await response_to_json(response) + + self.assertFalse(response_body_json["running"]) + + async def test_perform_upgrade(self) -> None: + """ + Check if a request to perform an upgrade launches an upgrade task. + """ + response = await self.rest_ep.perform_upgrade(PerformUpgradeRequest(self.vman)) + response_body_json = await response_to_json(response) + + self.assertTrue(response_body_json["success"]) + self.assertEqual(call(), self.vman.perform_upgrade.call_args) + + async def test_remove_version_illegal(self) -> None: + """ + Check if a request without a version returns a BAD REQUEST status. + """ + response = await self.rest_ep.remove_version(RemoveVersionRequest(self.vman, "")) + + self.assertEqual(HTTP_BAD_REQUEST, response.status) + + async def test_remove_version(self) -> None: + """ + Check if a request to remove a given version is forwarded. + """ + response = await self.rest_ep.remove_version(RemoveVersionRequest(self.vman, "1.2.3")) + response_body_json = await response_to_json(response) + + self.assertTrue(response_body_json["success"]) + self.assertEqual(call("1.2.3"), self.vman.remove_version.call_args) diff --git a/src/tribler/test_unit/core/versioning/test_manager.py b/src/tribler/test_unit/core/versioning/test_manager.py new file mode 100644 index 0000000000..6c7a0198f9 --- /dev/null +++ b/src/tribler/test_unit/core/versioning/test_manager.py @@ -0,0 +1,184 @@ +from importlib.metadata import PackageNotFoundError +from unittest.mock import AsyncMock, Mock, patch + +from ipv8.taskmanager import TaskManager +from ipv8.test.base import TestBase + +import tribler +from tribler.core.versioning.manager import VersioningManager +from tribler.tribler_config import TriblerConfigManager +from tribler.upgrade_script import FROM, TO + + +class MockTriblerConfigManager(TriblerConfigManager): + """ + A memory-based TriblerConfigManager. + """ + + def write(self) -> None: + """ + Don't actually write to any file. + """ + + +class TestVersioningManager(TestBase): + """ + Tests for the Notifier class. + """ + + def setUp(self) -> None: + """ + Create a new versioning manager. + """ + super().setUp() + self.task_manager = TaskManager() + self.manager = VersioningManager(self.task_manager, MockTriblerConfigManager()) + + async def tearDown(self) -> None: + """ + Shut down our task manager. + """ + await self.task_manager.shutdown_task_manager() + await super().tearDown() + + def test_get_current_version(self) -> None: + """ + Check if a normal version can be correctly returned. + """ + with patch.dict(tribler.core.versioning.manager.__dict__, {"version": lambda _: "1.2.3"}): + self.assertEqual("1.2.3", self.manager.get_current_version()) + + def test_get_current_version_not_found(self) -> None: + """ + Check if a value of None is returned as the version, when it cannot be found. + """ + with patch.dict(tribler.core.versioning.manager.__dict__, {"version": Mock(side_effect=PackageNotFoundError)}): + self.assertIsNone(self.manager.get_current_version()) + + def test_get_versions(self) -> None: + """ + Check if we can find all three versions in our test directory. + """ + with patch("os.listdir", lambda _: ["1.2.3", "1.3.0", "1.2.4"]), patch("os.path.isdir", lambda _: True): + self.assertEqual({"1.2.3", "1.2.4", "1.3.0"}, set(self.manager.get_versions())) + + def test_get_versions_empty(self) -> None: + """ + Check if an empty list is returned if no versions exist. + """ + with patch("os.listdir", lambda _: []): + self.assertEqual(set(), set(self.manager.get_versions())) + + async def test_check_version_no_version(self) -> None: + """ + Check if the bleeding edge source does not think it needs to be updated. + """ + self.assertIsNone(await self.manager.check_version()) + + async def test_check_version_no_responses(self) -> None: + """ + Check if None is returned when no responses are received. + """ + self.manager.get_current_version = Mock(return_value="1.0.0") + with patch.dict(tribler.core.versioning.manager.__dict__, {"ClientSession": Mock(side_effect=RuntimeError)}): + self.assertIsNone(await self.manager.check_version()) + + async def test_check_version_latest(self) -> None: + """ + Check if None is returned when we are already at the latest version. + """ + self.manager.get_current_version = Mock(return_value="1.0.0") + with patch.dict(tribler.core.versioning.manager.__dict__, {"ClientSession": Mock(return_value=Mock( + __aexit__=AsyncMock(), + __aenter__=AsyncMock(return_value=AsyncMock( + get=AsyncMock(return_value=Mock(json=AsyncMock(return_value={"name": "1.0.0"}))) + + ))))}): + self.assertIsNone(await self.manager.check_version()) + + async def test_check_version_latest_old(self) -> None: + """ + Check if None is returned when we are already at the latest version, in old format. + """ + self.manager.get_current_version = Mock(return_value="1.0.0") + with patch.dict(tribler.core.versioning.manager.__dict__, {"ClientSession": Mock(return_value=Mock( + __aexit__=AsyncMock(), + __aenter__=AsyncMock(return_value=AsyncMock( + get=AsyncMock(return_value=Mock(json=AsyncMock(return_value={"name": "v1.0.0"}))) + + ))))}): + self.assertIsNone(await self.manager.check_version()) + + async def test_check_version_newer(self) -> None: + """ + Check if a newer version is returned when available. + """ + self.manager.get_current_version = Mock(return_value="1.0.0") + with patch.dict(tribler.core.versioning.manager.__dict__, {"ClientSession": Mock(return_value=Mock( + __aexit__=AsyncMock(), + __aenter__=AsyncMock(return_value=AsyncMock( + get=AsyncMock(return_value=Mock(json=AsyncMock(return_value={"name": "1.0.1"}))) + + ))))}): + self.assertEqual("1.0.1", await self.manager.check_version()) + + async def test_check_version_newer_retry(self) -> None: + """ + Check if a newer version is returned when available from the backup url. + """ + self.manager.get_current_version = Mock(return_value="1.0.0") + with patch.dict(tribler.core.versioning.manager.__dict__, {"ClientSession": Mock(side_effect=[ + RuntimeError, + Mock( + __aexit__=AsyncMock(), + __aenter__=AsyncMock(return_value=AsyncMock( + get=AsyncMock(return_value=Mock(json=AsyncMock(return_value={"name": "1.0.1"}))) + + )))])}): + self.assertEqual("1.0.1", await self.manager.check_version()) + + def test_can_upgrade_upgraded(self) -> None: + """ + Check if we cannot upgrade an already upgraded version. + """ + with patch("os.path.isfile", lambda _: True): + self.assertFalse(self.manager.can_upgrade()) + + def test_can_upgrade_unsupported(self) -> None: + """ + Check if we cannot upgrade from an unsupported version. + """ + self.manager.get_versions = Mock(return_value=["0.0.0"]) + + with patch("os.path.isfile", lambda _: False): + self.assertFalse(self.manager.can_upgrade()) + + def test_can_upgrade_to_unsupported(self) -> None: + """ + Check if we cannot upgrade to an unsupported version. + """ + self.manager.get_versions = Mock(return_value=[FROM]) + self.manager.get_current_version = Mock(return_value="0.0.0") + + with patch("os.path.isfile", lambda _: False): + self.assertFalse(self.manager.can_upgrade()) + + def test_can_upgrade_to_current(self) -> None: + """ + Check if we can upgrade to the currently supported version. + """ + self.manager.get_versions = Mock(return_value=[FROM]) + self.manager.get_current_version = Mock(return_value=TO) + + with patch("os.path.isfile", lambda _: False): + self.assertEqual(FROM, self.manager.can_upgrade()) + + def test_can_upgrade_to_git(self) -> None: + """ + Check if we can upgrade to the git version. + """ + self.manager.get_versions = Mock(return_value=[FROM]) + self.manager.get_current_version = Mock(return_value=None) + + with patch("os.path.isfile", lambda _: False): + self.assertEqual(FROM, self.manager.can_upgrade()) diff --git a/src/tribler/tribler_config.py b/src/tribler/tribler_config.py index ef2dcb04e9..27b0e2b3e4 100644 --- a/src/tribler/tribler_config.py +++ b/src/tribler/tribler_config.py @@ -3,12 +3,15 @@ import json import logging import os +from importlib.metadata import PackageNotFoundError, version from json import JSONDecodeError from pathlib import Path from typing import TypedDict from ipv8.configuration import default as ipv8_default_config +from tribler.upgrade_script import TO + logger = logging.getLogger(__name__) @@ -58,6 +61,14 @@ class DatabaseConfig(TypedDict): enabled: bool +class VersioningConfig(TypedDict): + """ + Settings for the versioning component. + """ + + enabled: bool + + class DownloadDefaultsConfig(TypedDict): """ Settings for default downloads, used by libtorrent. @@ -151,6 +162,7 @@ class TriblerConfig(TypedDict): torrent_checker: TorrentCheckerConfig tunnel_community: TunnelCommunityConfig user_activity: UserActivityConfig + versioning: VersioningConfig state_dir: str memory_db: bool @@ -181,8 +193,8 @@ class TriblerConfig(TypedDict): socks_listen_ports=[0, 0, 0, 0, 0], port=0, proxy_type=0, - proxy_server='', - proxy_auth='', + proxy_server="", + proxy_auth="", max_connections_download=-1, max_download_rate=0, max_upload_rate=0, @@ -197,7 +209,7 @@ class TriblerConfig(TypedDict): number_hops=1, safeseeding_enabled=True, saveas=str(Path("~/Downloads").expanduser()), - seeding_mode='forever', + seeding_mode="forever", seeding_ratio=2.0, seeding_time=60, channel_download=False, @@ -207,8 +219,9 @@ class TriblerConfig(TypedDict): "torrent_checker": TorrentCheckerConfig(enabled=True), "tunnel_community": TunnelCommunityConfig(enabled=True, min_circuits=3, max_circuits=8), "user_activity": UserActivityConfig(enabled=True, max_query_history=500, health_check_interval=5.0), + "versioning": VersioningConfig(enabled=True), - "state_dir": str((Path(os.environ.get("APPDATA", "~")) / ".TriblerExperimental").expanduser().absolute()), + "state_dir": str((Path(os.environ.get("APPDATA", "~")) / ".Tribler").expanduser().absolute()), "memory_db": False } @@ -230,6 +243,12 @@ class TriblerConfig(TypedDict): if "file" in key_entry: key_entry["file"] = str(Path(DEFAULT_CONFIG["state_dir"]) / key_entry["file"]) +try: + version("tribler") + VERSION_SUBDIR = TO # We use the latest known version's directory NOT our own version +except PackageNotFoundError: + VERSION_SUBDIR = "git" + class TriblerConfigManager: """ @@ -277,6 +296,12 @@ def get(self, option: os.PathLike | str) -> dict | list | str | float | bool | N break return out + def get_version_state_dir(self) -> str: + """ + Get the state dir for our current version. + """ + return os.path.join(self.get("state_dir"), VERSION_SUBDIR) + def set(self, option: os.PathLike | str, value: dict | list | str | float | bool | None) -> None: """ Set a config option value based on the path-like descriptor. diff --git a/src/tribler/ui/public/locales/en_US.json b/src/tribler/ui/public/locales/en_US.json index 9767e4fa11..25b434b85e 100644 --- a/src/tribler/ui/public/locales/en_US.json +++ b/src/tribler/ui/public/locales/en_US.json @@ -124,5 +124,11 @@ "Socks5Auth": "Socks5 with authentication", "HTTP": "HTTP", "HTTPAuth": "HTTP with authentication", - "WebServerSettings": "Web server settings" + "WebServerSettings": "Web server settings", + "VersionCurrent": "Current version", + "VersionOld": "Old version", + "VersionAvailable": "NEW VERSION AVAILABLE", + "VersionUpgrading": "Upgrading", + "VersionImport": "IMPORT", + "VersionRemove": "REMOVE" } diff --git a/src/tribler/ui/public/locales/es_ES.json b/src/tribler/ui/public/locales/es_ES.json index ad8c9f9e2d..f412adc071 100644 --- a/src/tribler/ui/public/locales/es_ES.json +++ b/src/tribler/ui/public/locales/es_ES.json @@ -124,5 +124,11 @@ "Socks5Auth": "Socks5 con autenticación", "HTTP": "HTTP", "HTTPAuth": "HTTP con autenticación", - "WebServerSettings": "Configurações do servidor web" + "WebServerSettings": "Configurações do servidor web", + "VersionCurrent": "Versión actual", + "VersionOld": "Versión antigua", + "VersionAvailable": "NUEVA VERSIÓN DISPONIBLE", + "VersionUpgrading": "Actualización", + "VersionImport": "IMPORTAR", + "VersionRemove": "ELIMINAR" } diff --git a/src/tribler/ui/public/locales/pt_BR.json b/src/tribler/ui/public/locales/pt_BR.json index 49c8c50595..63d9d62012 100644 --- a/src/tribler/ui/public/locales/pt_BR.json +++ b/src/tribler/ui/public/locales/pt_BR.json @@ -116,5 +116,11 @@ "Socks5Auth": "Socks5 com autenticação", "HTTP": "HTTP", "HTTPAuth": "HTTP com autenticação", - "WebServerSettings": "Configurações do servidor web" + "WebServerSettings": "Configurações do servidor web", + "VersionCurrent": "Versão atual", + "VersionOld": "Versão antiga", + "VersionAvailable": "NOVA VERSÃO DISPONÍVEL", + "VersionUpgrading": "Atualizando", + "VersionImport": "IMPORTAR", + "VersionRemove": "REMOVER" } diff --git a/src/tribler/ui/public/locales/ru_RU.json b/src/tribler/ui/public/locales/ru_RU.json index 9b035c9ea3..8a82d12205 100644 --- a/src/tribler/ui/public/locales/ru_RU.json +++ b/src/tribler/ui/public/locales/ru_RU.json @@ -124,5 +124,11 @@ "Socks5Auth": "Socks5 с аутентификацией", "HTTP": "HTTP", "HTTPAuth": "HTTP с аутентификацией", - "WebServerSettings": "Настройки веб-сервера" + "WebServerSettings": "Настройки веб-сервера", + "VersionCurrent": "Текущая версия", + "VersionOld": "Старая версия", + "VersionAvailable": "ДОСТУПНА НОВАЯ ВЕРСИЯ", + "VersionUpgrading": "Обновление", + "VersionImport": "ИМПОРТ", + "VersionRemove": "УДАЛЯТЬ" } diff --git a/src/tribler/ui/public/locales/zh_CN.json b/src/tribler/ui/public/locales/zh_CN.json index 1b12506d5d..9750d293a7 100644 --- a/src/tribler/ui/public/locales/zh_CN.json +++ b/src/tribler/ui/public/locales/zh_CN.json @@ -123,5 +123,11 @@ "Socks5Auth": "带身份验证的 Socks5", "HTTP": "HTTP", "HTTPAuth": "带身份验证的 HTTP和", - "WebServerSettings": "网络服务器设置" + "WebServerSettings": "网络服务器设置", + "VersionCurrent": "当前版本", + "VersionOld": "旧版", + "VersionAvailable": "新版本上线", + "VersionUpgrading": "升级中", + "VersionImport": "进口", + "VersionRemove": "消除" } diff --git a/src/tribler/ui/src/Router.tsx b/src/tribler/ui/src/Router.tsx index a7324fafbb..56af2239bf 100644 --- a/src/tribler/ui/src/Router.tsx +++ b/src/tribler/ui/src/Router.tsx @@ -12,6 +12,7 @@ import Bandwidth from "./pages/Settings/Bandwidth"; import Seeding from "./pages/Settings/Seeding"; import Anonymity from "./pages/Settings/Anonymity"; import Debugging from "./pages/Settings/Debugging"; +import Versions from "./pages/Settings/Versions"; import GeneralDebug from "./pages/Debug/General"; import IPv8 from "./pages/Debug/IPv8"; import Tunnels from "./pages/Debug/Tunnels"; @@ -81,6 +82,10 @@ export const router = createHashRouter([ path: "settings/debugging", element: , }, + { + path: "settings/versions", + element: , + }, { path: "debug/general", element: , diff --git a/src/tribler/ui/src/config/menu.ts b/src/tribler/ui/src/config/menu.ts index 7d81019ece..4b354337f5 100644 --- a/src/tribler/ui/src/config/menu.ts +++ b/src/tribler/ui/src/config/menu.ts @@ -80,6 +80,10 @@ export const sideMenu: NavItemWithChildren[] = [ title: 'Debug', to: '/settings/debugging', }, + { + title: 'Versions', + to: '/settings/versions', + }, ], }, { diff --git a/src/tribler/ui/src/pages/Settings/Versions.tsx b/src/tribler/ui/src/pages/Settings/Versions.tsx new file mode 100644 index 0000000000..4a7968277b --- /dev/null +++ b/src/tribler/ui/src/pages/Settings/Versions.tsx @@ -0,0 +1,99 @@ +import { Suspense, useEffect, useState } from 'react'; +import { Button } from "@/components/ui/button"; +import { Label } from "@/components/ui/label"; +import { triblerService } from "@/services/tribler.service"; +import { useTranslation } from "react-i18next"; +import { RefreshCw } from 'lucide-react'; + +export default function Versions() { + const { t } = useTranslation(); + + const [version, setVersion] = useState(); + const [versions, setVersions] = useState(new Array()); + const [newVersion, setNewVersion] = useState(false); + const [canUpgrade, setCanUpgrade] = useState(false); + const [isUpgrading, setIsUpgrading] = useState(false); + + const clickedImport = (e, old_version?) => { + triblerService.performUpgrade(); + setIsUpgrading(true); + } + + const clickedRemove = (e, old_version?) => { + triblerService.removeVersion(old_version); + setVersions(versions.filter((v) => v != old_version)); + } + + const useMountEffect = (fun) => useEffect(fun, []) + useMountEffect(() => { + (async () => { + const version = await triblerService.getVersion(); + setVersion(version); + + var allVersions = await triblerService.getVersions(); + const versions = (allVersions.versions).filter((v) => v != allVersions.current); + setVersions(versions); + + const newVersion = await triblerService.getNewVersion(); + setNewVersion(newVersion); + + const canUpgrade = await triblerService.canUpgrade(); + setCanUpgrade(canUpgrade); + })(); + }); + useEffect(() => { + (async () => { + const isUpgrading = await triblerService.isUpgrading(); + setIsUpgrading(isUpgrading) + })(); + }); + + return ( +
+
+ + ...}> + + + }> + {newVersion ? : } + + + + + + + + + + { + versions.reduce((r, e) => r.push(e, e, e, e) && r, []).map(function(old_version, i){ + switch (i % 4){ + case 0: { + return () + } + case 1: { + return () // Blank column to outline with the data above + } + case 2: { + return ( + canUpgrade == old_version ? ( + isUpgrading ?
+ : ) + : + ) + } + default: { + return () + } + } + }) + } +
+
+ ) +} diff --git a/src/tribler/ui/src/services/tribler.service.ts b/src/tribler/ui/src/services/tribler.service.ts index 3e514757e9..62fa1997a9 100644 --- a/src/tribler/ui/src/services/tribler.service.ts +++ b/src/tribler/ui/src/services/tribler.service.ts @@ -172,6 +172,36 @@ export class TriblerService { return (await this.http.get(`/libtorrent/settings?hop=${hops}`)).data.settings; } + // Versions + async getVersion() { + return (await this.http.get(`/versioning/versions/current`)).data.version; + } + + async getNewVersion() { + const version_info_json = (await this.http.get(`/versioning/versions/check`)).data; + return (version_info_json.has_version ? version_info_json.new_version : false); + } + + async getVersions() { + return (await this.http.get(`/versioning/versions`)).data; + } + + async canUpgrade() { + return (await this.http.get(`/versioning/upgrade/available`)).data.can_upgrade; + } + + async isUpgrading() { + return (await this.http.get(`/versioning/upgrade/working`)).data.running; + } + + async performUpgrade() { + return (await this.http.post(`/versioning/upgrade`)) + } + + async removeVersion(version_str) { + return (await this.http.delete(`/versioning/versions/${version_str}`)) + } + // Misc async browseFiles(path: string, showFiles: boolean): Promise<{ current: string, paths: Path[] }> { diff --git a/src/tribler/upgrade_script.py b/src/tribler/upgrade_script.py new file mode 100644 index 0000000000..8a97dc1e13 --- /dev/null +++ b/src/tribler/upgrade_script.py @@ -0,0 +1,152 @@ +""" +UPDATE THIS FILE WHENEVER A NEW VERSION GETS RELEASED. + +Checklist: + + - Have you changed ``FROM`` to the previous version? + - Have you changed ``TO`` to the current version? + - Have you changed ``upgrade()`` to perform the upgrade? +""" +from __future__ import annotations + +import logging +import os +import shutil +import sqlite3 +from pathlib import Path +from typing import TYPE_CHECKING + +from configobj import ConfigObj + +if TYPE_CHECKING: + from tribler.tribler_config import TriblerConfigManager + +FROM: str = "7.14" +TO: str = "8.0" + + +def _copy_if_not_exist(src: str, dst: str) -> None: + """ + Copy a file if it does not exist. + """ + if os.path.exists(src) and not os.path.exists(dst): + shutil.copy(src, dst) + + +def _copy_if_exists(src: ConfigObj, src_path: str, dst: TriblerConfigManager, dst_path: str) -> None: + """ + Check if the src path is set and copy it into the dst if it is. + """ + out = src + for part in Path(src_path).parts: + if part in out: + out = out.get(part) + else: + return + dst.set(dst_path, out) + + +def _import_7_14_settings(src: str, dst: TriblerConfigManager) -> None: + """ + Read the file at the source path and import its settings. + """ + old = ConfigObj(src) + _copy_if_exists(old, "api/key", dst, "api/key") + _copy_if_exists(old, "api/http_enabled", dst, "api/http_enabled") + _copy_if_exists(old, "api/https_enabled", dst, "api/https_enabled") + _copy_if_exists(old, "ipv8/statistics", dst, "statistics") + _copy_if_exists(old, "libtorrent/port", dst, "libtorrent/port") + _copy_if_exists(old, "libtorrent/proxy_type", dst, "libtorrent/proxy_type") + _copy_if_exists(old, "libtorrent/proxy_server", dst, "libtorrent/proxy_server") + _copy_if_exists(old, "libtorrent/proxy_auth", dst, "libtorrent/proxy_auth") + _copy_if_exists(old, "libtorrent/max_connections_download", dst, "libtorrent/max_connections_download") + _copy_if_exists(old, "libtorrent/max_download_rate", dst, "libtorrent/max_download_rate") + _copy_if_exists(old, "libtorrent/max_upload_rate", dst, "libtorrent/max_upload_rate") + _copy_if_exists(old, "libtorrent/utp", dst, "libtorrent/utp") + _copy_if_exists(old, "libtorrent/dht", dst, "libtorrent/dht") + _copy_if_exists(old, "libtorrent/dht_readiness_timeout", dst, "libtorrent/dht_readiness_timeout") + _copy_if_exists(old, "libtorrent/upnp", dst, "libtorrent/upnp") + _copy_if_exists(old, "libtorrent/natpmp", dst, "libtorrent/natpmp") + _copy_if_exists(old, "libtorrent/lsd", dst, "libtorrent/lsd") + _copy_if_exists(old, "download_defaults/anonymity_enabled", dst, "libtorrent/download_defaults/anonymity_enabled") + _copy_if_exists(old, "download_defaults/number_hops", dst, "libtorrent/download_defaults/number_hops") + _copy_if_exists(old, "download_defaults/safeseeding_enabled", + dst, "libtorrent/download_defaults/safeseeding_enabled") + _copy_if_exists(old, "download_defaults/saveas", dst, "libtorrent/download_defaults/saveas") + _copy_if_exists(old, "download_defaults/seeding_mode", dst, "libtorrent/download_defaults/seeding_mode") + _copy_if_exists(old, "download_defaults/seeding_ratio", dst, "libtorrent/download_defaults/seeding_ratio") + _copy_if_exists(old, "download_defaults/seeding_time", dst, "libtorrent/download_defaults/seeding_time") + _copy_if_exists(old, "download_defaults/channel_download", dst, "libtorrent/download_defaults/channel_download") + _copy_if_exists(old, "download_defaults/add_download_to_channel", + dst, "libtorrent/download_defaults/add_download_to_channel") + _copy_if_exists(old, "popularity_community/enabled", dst, "content_discovery_community/enabled") + _copy_if_exists(old, "torrent_checking/enabled", dst, "torrent_checker/enabled") + _copy_if_exists(old, "tunnel_community/enabled", dst, "tunnel_community/enabled") + _copy_if_exists(old, "tunnel_community/min_circuits", dst, "tunnel_community/min_circuits") + _copy_if_exists(old, "tunnel_community/max_circuits", dst, "tunnel_community/max_circuits") + + +def _inject_7_14_tables(src_db: str, dst_db: str) -> None: + """ + Fetch data from the old database and attempt to insert it into a new one. + """ + # If the src does not exist, there is nothing to copy. + if not os.path.exists(src_db): + return + + # If the dst does not exist, simply copy the src over. + if not os.path.exists(dst_db): + shutil.copy(src_db, dst_db) + return + + # If they both exist, we have to inject data. + src_con = sqlite3.connect(os.path.abspath(src_db)) + insert_script = list(src_con.iterdump()) + src_con.close() + + from pony.orm import db_session + dst_con = sqlite3.connect(os.path.abspath(dst_db)) + with db_session: + for line in insert_script: + try: + dst_con.execute(line) + except sqlite3.DatabaseError as e: + logging.exception(e) + dst_con.commit() # This should be part of the dump already but just to be sure. + dst_con.close() + + +def upgrade(config: TriblerConfigManager, source: str, destination: str) -> None: + """ + Perform the upgrade from the previous version to the next version. + When complete, write a ".upgraded" file to the destination path. + + The files in ``source`` should be expected to be in the FROM format. + The files in ``destination`` should be expected to be in the TO format. + + Make sure to deal with corruption and/or missing files! + """ + # Step 1: import settings + os.makedirs(destination, exist_ok=True) + if os.path.exists(os.path.join(source, "triblerd.conf")): + _import_7_14_settings(os.path.join(source, "triblerd.conf"), config) + config.write() + + # Step 2: copy downloads + os.makedirs(os.path.join(destination, "dlcheckpoints"), exist_ok=True) + for checkpoint in os.listdir(os.path.join(source, "dlcheckpoints")): + _copy_if_not_exist(os.path.join(source, "dlcheckpoints", checkpoint), + os.path.join(destination, "dlcheckpoints", checkpoint)) + + # Step 3: Copy tribler db. + os.makedirs(os.path.join(destination, "sqlite"), exist_ok=True) + _inject_7_14_tables(os.path.join(source, "sqlite", "tribler.db"), + os.path.join(destination, "sqlite", "tribler.db")) + + # Step 4: Copy metadata db. + _inject_7_14_tables(os.path.join(source, "sqlite", "metadata.db"), + os.path.join(destination, "sqlite", "metadata.db")) + + # Step 5: Signal that our upgrade is done. + with open(os.path.join(config.get_version_state_dir(), ".upgraded"), "a"): + pass