diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index c02cbf966b..f26a16c59b 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -666,3 +666,12 @@ def __init__(self) -> None: class ChildNotFoundError(HTTPException): def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) + + +class BadArchiveContent(Exception): + """ + Exception raised when the archive file is corrupted (or unknown). + """ + + def __init__(self, message: str = "Unsupported archive format") -> None: + super().__init__(message) diff --git a/antarest/core/utils/archives.py b/antarest/core/utils/archives.py new file mode 100644 index 0000000000..a128cd4438 --- /dev/null +++ b/antarest/core/utils/archives.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. 
+import logging +import os +import shutil +import tempfile +import typing as t +import zipfile +from enum import Enum +from pathlib import Path + +import py7zr + +from antarest.core.exceptions import BadArchiveContent, ShouldNotHappenException + +logger = logging.getLogger(__name__) + + +class ArchiveFormat(str, Enum): + ZIP = ".zip" + SEVEN_ZIP = ".7z" + + +def is_archive_format(suffix: str) -> bool: + return suffix in {ArchiveFormat.ZIP, ArchiveFormat.SEVEN_ZIP} + + +def archive_dir( + src_dir_path: Path, + target_archive_path: Path, + remove_source_dir: bool = False, + archive_format: t.Optional[ArchiveFormat] = None, +) -> None: + if archive_format is not None and target_archive_path.suffix != archive_format: + raise ShouldNotHappenException( + f"Non matching archive format {archive_format} and target archive suffix {target_archive_path.suffix}" + ) + if target_archive_path.suffix == ArchiveFormat.SEVEN_ZIP: + with py7zr.SevenZipFile(target_archive_path, mode="w") as szf: + szf.writeall(src_dir_path, arcname="") + elif target_archive_path.suffix == ArchiveFormat.ZIP: + with zipfile.ZipFile(target_archive_path, mode="w", compression=zipfile.ZIP_DEFLATED, compresslevel=2) as zipf: + len_dir_path = len(str(src_dir_path)) + for root, _, files in os.walk(src_dir_path): + for file in files: + file_path = os.path.join(root, file) + zipf.write(file_path, file_path[len_dir_path:]) + else: + raise ShouldNotHappenException(f"Unsupported archive format {target_archive_path.suffix}") + if remove_source_dir: + shutil.rmtree(src_dir_path) + + +def unzip(dir_path: Path, zip_path: Path, remove_source_zip: bool = False) -> None: + with zipfile.ZipFile(zip_path, mode="r") as zipf: + zipf.extractall(dir_path) + if remove_source_zip: + zip_path.unlink() + + +def is_zip(path: Path) -> bool: + return path.name.endswith(".zip") + + +def read_in_zip( + zip_path: Path, + inside_zip_path: Path, + read: t.Callable[[t.Optional[Path]], None], +) -> None: + tmp_dir = None + try: + path, 
tmp_dir = extract_file_to_tmp_dir(zip_path, inside_zip_path) + read(path) + except KeyError: + logger.warning(f"{inside_zip_path} not found in {zip_path}") + read(None) + finally: + if tmp_dir is not None: + tmp_dir.cleanup() + + +def extract_archive(stream: t.BinaryIO, target_dir: Path) -> None: + """ + Extract a ZIP or 7z archive to a given destination. + + Args: + stream: The stream containing the archive. + target_dir: The directory where to extract the archive. + + Raises: + BadArchiveContent: If the archive is corrupted or in an unknown format. + """ + + # Read the first few bytes to identify the file format + file_format = stream.read(4) + stream.seek(0) + + if file_format[:4] == b"PK\x03\x04": + try: + with zipfile.ZipFile(stream) as zf: + zf.extractall(path=target_dir) + except zipfile.BadZipFile as error: + raise BadArchiveContent("Unsupported ZIP format") from error + + elif file_format[:2] == b"7z": + try: + with py7zr.SevenZipFile(stream, "r") as zf: + zf.extractall(target_dir) + except py7zr.exceptions.Bad7zFile as error: + raise BadArchiveContent("Unsupported 7z format") from error + + else: + raise BadArchiveContent + + +def extract_file_to_tmp_dir(archive_path: Path, inside_archive_path: Path) -> t.Tuple[Path, t.Any]: + str_inside_archive_path = str(inside_archive_path).replace("\\", "/") + tmp_dir = tempfile.TemporaryDirectory() + try: + if archive_path.suffix == ArchiveFormat.ZIP: + with zipfile.ZipFile(archive_path) as zip_obj: + zip_obj.extract(str_inside_archive_path, tmp_dir.name) + elif archive_path.suffix == ArchiveFormat.SEVEN_ZIP: + with py7zr.SevenZipFile(archive_path, mode="r") as szf: + szf.extract(path=tmp_dir.name, targets=[str_inside_archive_path]) + else: + raise ValueError(f"Unsupported archive format for {archive_path}") + except Exception as e: + logger.warning( + f"Failed to extract {str_inside_archive_path} in archive {archive_path}", + exc_info=e, + ) + tmp_dir.cleanup() + raise + path = Path(tmp_dir.name) / str_inside_archive_path 
+ return path, tmp_dir + + +def read_file_from_archive(archive_path: Path, posix_path: str) -> str: + """ + Read a file from an archive. + + Args: + archive_path: the path to the archive file. + posix_path: path to the file inside the archive. + + Returns: + The content of the file as a string. + """ + + if archive_path.suffix == ArchiveFormat.ZIP: + with zipfile.ZipFile(archive_path) as zip_obj: + with zip_obj.open(posix_path) as f: + return f.read().decode("utf-8") + elif archive_path.suffix == ArchiveFormat.SEVEN_ZIP: + with py7zr.SevenZipFile(archive_path, mode="r") as szf: + file_text: str = szf.read([posix_path])[posix_path].read().decode("utf-8") + return file_text + else: + raise ValueError(f"Unsupported {archive_path.suffix} archive format for {archive_path}") + + +def extract_lines_from_archive(root: Path, posix_path: str) -> t.List[str]: + """ + Extract text lines from various types of files. + + Args: + root: 7zip or ZIP file containing the study. + posix_path: Relative path to the file to extract. + + Returns: + list of lines + """ + try: + text = read_file_from_archive(root, posix_path) + return text.splitlines(keepends=False) + # File not found in the archive + except KeyError: + return [] diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index 63576fc1a4..c748420549 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -14,16 +14,11 @@ import glob import http import logging -import os import re -import shutil -import tempfile import time import typing as t -import zipfile from pathlib import Path -import py7zr from fastapi import HTTPException from antarest.core.exceptions import ShouldNotHappenException @@ -65,49 +60,6 @@ def sanitize_string(string: str) -> str: return str(glob.escape(string)) -class BadArchiveContent(Exception): - """ - Exception raised when the archive file is corrupted (or unknown). 
- """ - - def __init__(self, message: str = "Unsupported archive format") -> None: - super().__init__(message) - - -def extract_zip(stream: t.BinaryIO, target_dir: Path) -> None: - """ - Extract a ZIP archive to a given destination. - - Args: - stream: The stream containing the archive. - target_dir: The directory where to extract the archive. - - Raises: - BadArchiveContent: If the archive is corrupted or in an unknown format. - """ - - # Read the first few bytes to identify the file format - file_format = stream.read(4) - stream.seek(0) - - if file_format[:4] == b"PK\x03\x04": - try: - with zipfile.ZipFile(stream) as zf: - zf.extractall(path=target_dir) - except zipfile.BadZipFile as error: - raise BadArchiveContent("Unsupported ZIP format") from error - - elif file_format[:2] == b"7z": - try: - with py7zr.SevenZipFile(stream, "r") as zf: - zf.extractall(target_dir) - except py7zr.exceptions.Bad7zFile as error: - raise BadArchiveContent("Unsupported 7z format") from error - - else: - raise BadArchiveContent - - def get_default_config_path() -> t.Optional[Path]: config = Path("config.yaml") if config.exists(): @@ -181,62 +133,6 @@ def concat_files_to_str(files: t.List[Path]) -> str: return concat_str -def zip_dir(dir_path: Path, zip_path: Path, remove_source_dir: bool = False) -> None: - with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED, compresslevel=2) as zipf: - len_dir_path = len(str(dir_path)) - for root, _, files in os.walk(dir_path): - for file in files: - file_path = os.path.join(root, file) - zipf.write(file_path, file_path[len_dir_path:]) - if remove_source_dir: - shutil.rmtree(dir_path) - - -def unzip(dir_path: Path, zip_path: Path, remove_source_zip: bool = False) -> None: - with zipfile.ZipFile(zip_path, mode="r") as zipf: - zipf.extractall(dir_path) - if remove_source_zip: - zip_path.unlink() - - -def is_zip(path: Path) -> bool: - return path.name.endswith(".zip") - - -def extract_file_to_tmp_dir(zip_path: Path, 
inside_zip_path: Path) -> t.Tuple[Path, t.Any]: - str_inside_zip_path = str(inside_zip_path).replace("\\", "/") - tmp_dir = tempfile.TemporaryDirectory() - try: - with zipfile.ZipFile(zip_path) as zip_obj: - zip_obj.extract(str_inside_zip_path, tmp_dir.name) - except Exception as e: - logger.warning( - f"Failed to extract {str_inside_zip_path} in zip {zip_path}", - exc_info=e, - ) - tmp_dir.cleanup() - raise - path = Path(tmp_dir.name) / inside_zip_path - return path, tmp_dir - - -def read_in_zip( - zip_path: Path, - inside_zip_path: Path, - read: t.Callable[[t.Optional[Path]], None], -) -> None: - tmp_dir = None - try: - path, tmp_dir = extract_file_to_tmp_dir(zip_path, inside_zip_path) - read(path) - except KeyError: - logger.warning(f"{inside_zip_path} not found in {zip_path}") - read(None) - finally: - if tmp_dir is not None: - tmp_dir.cleanup() - - def suppress_exception( callback: t.Callable[[], T], logger_: t.Callable[[Exception], None], diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py index 577ee031b5..d4f5c30da8 100644 --- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py +++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py @@ -34,7 +34,8 @@ from antarest.core.interfaces.eventbus import Event, EventType, IEventBus from antarest.core.model import PermissionInfo, PublicMode from antarest.core.requests import RequestParameters -from antarest.core.utils.utils import assert_this, unzip +from antarest.core.utils.archives import unzip +from antarest.core.utils.utils import assert_this from antarest.launcher.adapters.abstractlauncher import AbstractLauncher, LauncherCallbacks, LauncherInitException from antarest.launcher.adapters.log_manager import LogTailManager from antarest.launcher.model import JobStatus, LauncherParametersDTO, LogType, XpansionParametersDTO diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 
c03902eecb..573aaf3b9d 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -34,8 +34,9 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.tasks.model import TaskResult, TaskType from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier +from antarest.core.utils.archives import ArchiveFormat, archive_dir, is_zip, read_in_zip from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.core.utils.utils import StopWatch, concat_files, concat_files_to_str, is_zip, read_in_zip, zip_dir +from antarest.core.utils.utils import StopWatch, concat_files, concat_files_to_str from antarest.launcher.adapters.abstractlauncher import LauncherCallbacks from antarest.launcher.adapters.factory_launcher import FactoryLauncher from antarest.launcher.extensions.adequacy_patch.extension import AdequacyPatchExtension @@ -548,7 +549,7 @@ def _import_output( if not output_is_zipped and job_launch_params.archive_output: logger.info("Re zipping output for transfer") zip_path = output_true_path.parent / f"{output_true_path.name}.zip" - zip_dir(output_true_path, zip_path=zip_path) + archive_dir(output_true_path, target_archive_path=zip_path, archive_format=ArchiveFormat.ZIP) stopwatch.log_elapsed(lambda x: logger.info(f"Zipped output for job {job_id} in {x}s")) final_output_path = zip_path or output_true_path @@ -600,7 +601,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F def export_task(_: TaskUpdateNotifier) -> TaskResult: try: # - zip_dir(output_path, export_path) + archive_dir(output_path, export_path, archive_format=ArchiveFormat.ZIP) self.file_transfer_manager.set_ready(export_id) return TaskResult(success=True, message="") except Exception as e: diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index aae9a52cca..3a10ed09df 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -33,8 +33,9 @@ 
from antarest.core.serialization import from_json from antarest.core.tasks.model import TaskResult, TaskType from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier +from antarest.core.utils.archives import ArchiveFormat, archive_dir from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.core.utils.utils import StopWatch, zip_dir +from antarest.core.utils.utils import StopWatch from antarest.login.service import LoginService from antarest.matrixstore.exceptions import MatrixDataSetNotFound from antarest.matrixstore.model import ( @@ -469,7 +470,7 @@ def create_matrix_files(self, matrix_ids: t.Sequence[str], export_path: Path) -> else: # noinspection PyTypeChecker np.savetxt(filepath, array, delimiter="\t", fmt="%.18f") - zip_dir(Path(tmpdir), export_path) + archive_dir(Path(tmpdir), export_path, archive_format=ArchiveFormat.ZIP) stopwatch.log_elapsed(lambda x: logger.info(f"Matrix dataset exported (zipped mode) in {x}s")) return str(export_path) diff --git a/antarest/study/service.py b/antarest/study/service.py index 9e3ad837ed..28afa70858 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -58,6 +58,7 @@ from antarest.core.serialization import to_json from antarest.core.tasks.model import TaskListFilter, TaskResult, TaskStatus, TaskType from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier, noop_notifier +from antarest.core.utils.archives import ArchiveFormat, is_archive_format from antarest.core.utils.fastapi_sqlalchemy import db from antarest.core.utils.utils import StopWatch from antarest.login.model import Group @@ -168,7 +169,7 @@ def get_disk_usage(path: t.Union[str, Path]) -> int: """Calculate the total disk usage (in bytes) of a study in a compressed file or directory.""" path = Path(path) - if path.suffix.lower() in {".zip", ".7z"}: + if is_archive_format(path.suffix.lower()): return os.path.getsize(path) total_size = 0 with os.scandir(path) as it: @@ -1046,7 +1047,7 @@ def 
export_study( logger.info("Exporting study %s", uuid) export_name = f"Study {study.name} ({uuid}) export" export_file_download = self.file_transfer_manager.request_download( - f"{study.name}-{uuid}.zip", export_name, params.user + f"{study.name}-{uuid}{ArchiveFormat.ZIP}", export_name, params.user ) export_path = Path(export_file_download.path) export_id = export_file_download.id @@ -1111,7 +1112,7 @@ def export_output( logger.info(f"Exporting {output_uuid} from study {study_uuid}") export_name = f"Study output {study.name}/{output_uuid} export" export_file_download = self.file_transfer_manager.request_download( - f"{study.name}-{study_uuid}-{output_uuid}.zip", + f"{study.name}-{study_uuid}-{output_uuid}{ArchiveFormat.ZIP}", export_name, params.user, ) @@ -1210,7 +1211,7 @@ def delete_study(self, uuid: str, children: bool, params: RequestParameters) -> self.storage_service.get_storage(study).delete(study) else: if isinstance(study, RawStudy): - os.unlink(self.storage_service.raw_study_service.get_archive_path(study)) + os.unlink(self.storage_service.raw_study_service.find_archive_path(study)) logger.info("study %s deleted by user %s", uuid, params.get_user_id()) @@ -1469,7 +1470,7 @@ def import_output( remove_from_cache(cache=self.cache_service, root_id=study.id) logger.info("output added to study %s by user %s", uuid, params.get_user_id()) - if output_id and isinstance(output, Path) and output.suffix == ".zip" and auto_unzip: + if output_id and isinstance(output, Path) and output.suffix == ArchiveFormat.ZIP and auto_unzip: self.unarchive_output(uuid, output_id, not is_managed(study), params) return output_id @@ -2056,7 +2057,7 @@ def unarchive_task(notifier: TaskUpdateNotifier) -> TaskResult: self.storage_service.raw_study_service.unarchive(study_to_archive) study_to_archive.archived = False - os.unlink(self.storage_service.raw_study_service.get_archive_path(study_to_archive)) + os.unlink(self.storage_service.raw_study_service.find_archive_path(study_to_archive)) 
self.repository.save(study_to_archive) self.event_bus.push( Event( @@ -2446,7 +2447,7 @@ def unarchive_output_task( workspace = getattr(study, "workspace", DEFAULT_WORKSPACE_NAME) if workspace != DEFAULT_WORKSPACE_NAME: dest = Path(study.path) / "output" / output_id - src = Path(study.path) / "output" / f"{output_id}.zip" + src = Path(study.path) / "output" / f"{output_id}{ArchiveFormat.ZIP}" task_id = self.task_service.add_worker_task( TaskType.UNARCHIVE, f"unarchive_{workspace}", diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 3de3c3a3aa..782eb5516f 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -18,12 +18,15 @@ from pathlib import Path from uuid import uuid4 +import py7zr + from antarest.core.config import Config from antarest.core.exceptions import BadOutputError, StudyOutputNotFoundError from antarest.core.interfaces.cache import CacheConstants, ICache from antarest.core.model import JSON, PublicMode from antarest.core.serialization import from_json -from antarest.core.utils.utils import StopWatch, extract_zip, unzip, zip_dir +from antarest.core.utils.archives import ArchiveFormat, archive_dir, extract_archive, unzip +from antarest.core.utils.utils import StopWatch from antarest.login.model import GroupDTO from antarest.study.common.studystorage import IStudyStorageService, T from antarest.study.model import ( @@ -245,20 +248,20 @@ def import_output( stopwatch = StopWatch() try: if isinstance(output, Path): - if output != path_output and output.suffix != ".zip": + if output != path_output and output.suffix != ArchiveFormat.ZIP: shutil.copytree(output, path_output / "imported") - elif output.suffix == ".zip": + elif output.suffix == ArchiveFormat.ZIP: is_zipped = True path_output.rmdir() - path_output = Path(str(path_output) + ".zip") + path_output = Path(str(path_output) + f"{ArchiveFormat.ZIP}") 
shutil.copyfile(output, path_output) else: - extract_zip(output, path_output) + extract_archive(output, path_output) stopwatch.log_elapsed(lambda elapsed_time: logger.info(f"Copied output for {study_id} in {elapsed_time}s")) fix_study_root(path_output) output_full_name = extract_output_name(path_output, output_name) - extension = ".zip" if is_zipped else "" + extension = f"{ArchiveFormat.ZIP}" if is_zipped else "" path_output = path_output.rename(Path(path_output.parent, output_full_name + extension)) data = self.get(metadata, f"output/{output_full_name}", 1, use_cache=False) @@ -271,14 +274,14 @@ def import_output( logger.error("Failed to import output", exc_info=e) shutil.rmtree(path_output, ignore_errors=True) if is_zipped: - Path(str(path_output) + ".zip").unlink(missing_ok=True) + Path(str(path_output) + f"{ArchiveFormat.ZIP}").unlink(missing_ok=True) output_full_name = None return output_full_name def export_study(self, metadata: T, target: Path, outputs: bool = True) -> Path: """ - Export and compress the study inside a ZIP file. + Export and compress the study inside a 7zip file. Args: metadata: Study metadata object. @@ -286,7 +289,7 @@ def export_study(self, metadata: T, target: Path, outputs: bool = True) -> Path: outputs: Flag to indicate whether to include the output folder inside the exportation. Returns: - The ZIP file containing the study files compressed inside. + The 7zip file containing the study files compressed inside. 
""" path_study = Path(metadata.path) with tempfile.TemporaryDirectory(dir=self.config.storage.tmp_dir) as tmpdir: @@ -294,8 +297,10 @@ def export_study(self, metadata: T, target: Path, outputs: bool = True) -> Path: tmp_study_path = Path(tmpdir) / "tmp_copy" self.export_study_flat(metadata, tmp_study_path, outputs) stopwatch = StopWatch() - zip_dir(tmp_study_path, target) - stopwatch.log_elapsed(lambda x: logger.info(f"Study {path_study} exported (zipped mode) in {x}s")) + archive_dir(tmp_study_path, target) + stopwatch.log_elapsed( + lambda x: logger.info(f"Study {path_study} exported ({target.suffix} format) in {x}s") + ) return target def export_output(self, metadata: T, output_id: str, target: Path) -> None: @@ -309,7 +314,7 @@ def export_output(self, metadata: T, output_id: str, target: Path) -> None: logger.info(f"Exporting output {output_id} from study {metadata.id}") path_output = Path(metadata.path) / "output" / output_id - path_output_zip = Path(metadata.path) / "output" / f"{output_id}.zip" + path_output_zip = Path(metadata.path) / "output" / f"{output_id}{ArchiveFormat.ZIP}" if path_output_zip.exists(): shutil.copyfile(path_output_zip, target) @@ -319,7 +324,7 @@ def export_output(self, metadata: T, output_id: str, target: Path) -> None: raise StudyOutputNotFoundError() stopwatch = StopWatch() if not path_output_zip.exists(): - zip_dir(path_output, target) + archive_dir(path_output, target, archive_format=ArchiveFormat.ZIP) stopwatch.log_elapsed(lambda x: logger.info(f"Output {output_id} from study {metadata.path} exported in {x}s")) def _read_additional_data_from_files(self, file_study: FileStudy) -> StudyAdditionalData: @@ -332,10 +337,11 @@ def _read_additional_data_from_files(self, file_study: FileStudy) -> StudyAdditi def archive_study_output(self, study: T, output_id: str) -> bool: try: - zip_dir( + archive_dir( Path(study.path) / "output" / output_id, - Path(study.path) / "output" / f"{output_id}.zip", + Path(study.path) / "output" / 
f"{output_id}{ArchiveFormat.ZIP}", remove_source_dir=True, + archive_format=ArchiveFormat.ZIP, ) remove_from_cache(self.cache, study.id) return True @@ -347,10 +353,15 @@ def archive_study_output(self, study: T, output_id: str) -> bool: return False def unarchive_study_output(self, study: T, output_id: str, keep_src_zip: bool) -> bool: + if not (Path(study.path) / "output" / f"{output_id}{ArchiveFormat.ZIP}").exists(): + logger.warning( + f"Failed to unarchive study {study.name} output {output_id}. Maybe it's already unarchived", + ) + return False try: unzip( Path(study.path) / "output" / output_id, - Path(study.path) / "output" / f"{output_id}.zip", + Path(study.path) / "output" / f"{output_id}{ArchiveFormat.ZIP}", remove_source_zip=not keep_src_zip, ) remove_from_cache(self.cache, study.id) diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/files.py b/antarest/study/storage/rawstudy/model/filesystem/config/files.py index 73e67b1976..8f23fecd25 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/files.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/files.py @@ -20,10 +20,17 @@ from enum import Enum from pathlib import Path +import py7zr from antares.study.version import StudyVersion from antarest.core.model import JSON from antarest.core.serialization import from_json +from antarest.core.utils.archives import ( + ArchiveFormat, + extract_lines_from_archive, + is_archive_format, + read_file_from_archive, +) from antarest.study.model import STUDY_VERSION_8_1, STUDY_VERSION_8_6 from antarest.study.storage.rawstudy.ini_reader import IniReader from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( @@ -65,6 +72,31 @@ class FileType(Enum): MULTI_INI = "multi_ini" +def extract_data_from_archive( + root: Path, + posix_path: str, + reader: IniReader, +) -> t.Dict[str, t.Any]: + """ + Extract and process data from various types of files. 
+ + Args: + root: 7zip or ZIP file containing the study. + posix_path: Relative path to the file to extract. + reader: IniReader object to use for processing the file. + + Returns: + The content of the file, processed according to its type: + - SIMPLE_INI or MULTI_INI: dictionary of keys/values + """ + try: + file_text = read_file_from_archive(root, posix_path) + buffer = io.StringIO(file_text) + return reader.read(buffer) + except KeyError: # File not found in the archive + return {} + + def build(study_path: Path, study_id: str, output_path: t.Optional[Path] = None) -> "FileStudyTreeConfig": """ Extracts data from the filesystem to build a study config. @@ -78,10 +110,10 @@ def build(study_path: Path, study_id: str, output_path: t.Optional[Path] = None) Returns: An instance of `FileStudyTreeConfig` filled with the study data. """ - is_zip_file = study_path.suffix.lower() == ".zip" + is_archive = is_archive_format(study_path.suffix.lower()) # Study directory to use if the study is compressed - study_dir = study_path.with_suffix("") if is_zip_file else study_path + study_dir = study_path.with_suffix("") if is_archive else study_path (sns, asi, enr_modelling) = _parse_parameters(study_path) outputs_dir: Path = output_path or study_path / "output" @@ -98,52 +130,10 @@ def build(study_path: Path, study_id: str, output_path: t.Optional[Path] = None) store_new_set=sns, archive_input_series=asi, enr_modelling=enr_modelling, - zip_path=study_path if is_zip_file else None, + archive_path=study_path if is_archive else None, ) -def _extract_text_from_zip(root: Path, posix_path: str) -> t.Sequence[str]: - """ - Extracts text from a file inside a ZIP archive and returns it as a list of lines. - - Args: - root: The path to the ZIP archive. - posix_path: The relative path to the file inside the ZIP archive. - - Returns: - A list of lines in the file. If the file is not found, an empty list is returned. 
- """ - with zipfile.ZipFile(root) as zf: - try: - with zf.open(posix_path) as f: - text = f.read().decode("utf-8") - return text.splitlines(keepends=False) - except KeyError: - return [] - - -def _extract_ini_from_zip(root: Path, posix_path: str, multi_ini_keys: t.Sequence[str] = ()) -> t.Mapping[str, t.Any]: - """ - Extracts data from an INI file inside a ZIP archive and returns it as a dictionary. - - Args: - root: The path to the ZIP archive. - posix_path: The relative path to the file inside the ZIP archive. - multi_ini_keys: List of keys to use for multi INI files. - - Returns: - A dictionary of keys/values in the INI file. If the file is not found, an empty dictionary is returned. - """ - reader = IniReader(multi_ini_keys) - with zipfile.ZipFile(root) as zf: - try: - with zf.open(posix_path) as f: - buffer = io.StringIO(f.read().decode("utf-8")) - return reader.read(buffer) - except KeyError: - return {} - - def _extract_data_from_file( root: Path, inside_root_path: Path, @@ -165,14 +155,14 @@ def _extract_data_from_file( - SIMPLE_INI or MULTI_INI: dictionary of keys/values """ - is_zip_file: bool = root.suffix.lower() == ".zip" + is_archive: bool = is_archive_format(root.suffix.lower()) posix_path: str = inside_root_path.as_posix() output_data_path = root / inside_root_path if file_type == FileType.TXT: # Parse the file as a list of lines, return an empty list if missing. - if is_zip_file: - return _extract_text_from_zip(root, posix_path) + if is_archive: + return extract_lines_from_archive(root, posix_path) else: try: return output_data_path.read_text(encoding="utf-8").splitlines(keepends=False) @@ -181,11 +171,11 @@ def _extract_data_from_file( elif file_type in {FileType.MULTI_INI, FileType.SIMPLE_INI}: # Parse the file as a dictionary of keys/values, return an empty dictionary if missing. 
- if is_zip_file: - return _extract_ini_from_zip(root, posix_path, multi_ini_keys=multi_ini_keys) + reader = IniReader(multi_ini_keys) + if is_archive: + return extract_data_from_archive(root, posix_path, reader) else: try: - reader = IniReader(multi_ini_keys) return reader.read(output_data_path) except FileNotFoundError: return {} diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py index f707ab4027..64c5be0ee7 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py @@ -174,7 +174,7 @@ def __init__( archive_input_series: t.Optional[t.List[str]] = None, enr_modelling: str = str(EnrModelling.AGGREGATED), cache: t.Optional[t.Dict[str, t.List[str]]] = None, - zip_path: t.Optional[Path] = None, + archive_path: t.Optional[Path] = None, ): self.study_path = study_path self.path = path @@ -189,13 +189,13 @@ def __init__( self.archive_input_series = archive_input_series or [] self.enr_modelling = enr_modelling self.cache = cache or {} - self.zip_path = zip_path + self.archive_path = archive_path def next_file(self, name: str, is_output: bool = False) -> "FileStudyTreeConfig": if is_output and name in self.outputs and self.outputs[name].archived: - zip_path: t.Optional[Path] = self.path / f"{name}.zip" + archive_path: t.Optional[Path] = self.path / f"{name}.zip" else: - zip_path = self.zip_path + archive_path = self.archive_path return FileStudyTreeConfig( study_path=self.study_path, @@ -211,7 +211,7 @@ def next_file(self, name: str, is_output: bool = False) -> "FileStudyTreeConfig" archive_input_series=self.archive_input_series, enr_modelling=self.enr_modelling, cache=self.cache, - zip_path=zip_path, + archive_path=archive_path, ) def at_file(self, filepath: Path) -> "FileStudyTreeConfig": @@ -317,7 +317,7 @@ class FileStudyTreeConfigDTO(AntaresBaseModel): store_new_set: bool = False 
archive_input_series: t.List[str] = list() enr_modelling: str = str(EnrModelling.AGGREGATED) - zip_path: t.Optional[Path] = None + archive_path: t.Optional[Path] = None @field_serializer("version") def serialize_version(self, version: StudyVersion) -> int: @@ -344,7 +344,7 @@ def from_build_config( store_new_set=config.store_new_set, archive_input_series=config.archive_input_series, enr_modelling=config.enr_modelling, - zip_path=config.zip_path, + archive_path=config.archive_path, ) def to_build_config(self) -> FileStudyTreeConfig: @@ -361,5 +361,5 @@ def to_build_config(self) -> FileStudyTreeConfig: store_new_set=self.store_new_set, archive_input_series=self.archive_input_series, enr_modelling=self.enr_modelling, - zip_path=self.zip_path, + archive_path=self.archive_path, ) diff --git a/antarest/study/storage/rawstudy/model/filesystem/ini_file_node.py b/antarest/study/storage/rawstudy/model/filesystem/ini_file_node.py index bf9eecc574..55a8645ab6 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/ini_file_node.py +++ b/antarest/study/storage/rawstudy/model/filesystem/ini_file_node.py @@ -20,9 +20,11 @@ import zipfile from pathlib import Path +import py7zr import pydantic_core from filelock import FileLock +from antarest.core.exceptions import ShouldNotHappenException from antarest.core.model import JSON, SUB_JSON from antarest.core.serialization import from_json from antarest.study.storage.rawstudy.ini_reader import IniReader, IReader @@ -105,11 +107,18 @@ def _get( url = url or [] kwargs = self._get_filtering_kwargs(url) - if self.config.zip_path: - with zipfile.ZipFile(self.config.zip_path, mode="r") as zipped_folder: - inside_zip_path = self.config.path.relative_to(self.config.zip_path.with_suffix("")).as_posix() - with io.TextIOWrapper(zipped_folder.open(inside_zip_path)) as f: - data = self.reader.read(f, **kwargs) + if self.config.archive_path: + inside_archive_path = 
self.config.path.relative_to(self.config.archive_path.with_suffix("")).as_posix() + if self.config.archive_path.suffix == ".zip": + with zipfile.ZipFile(self.config.archive_path, mode="r") as zipped_folder: + with io.TextIOWrapper(zipped_folder.open(inside_archive_path)) as f: + data = self.reader.read(f, **kwargs) + elif self.config.archive_path.suffix == ".7z": + with py7zr.SevenZipFile(self.config.archive_path, mode="r") as zipped_folder: + with io.TextIOWrapper(zipped_folder.read([inside_archive_path])[inside_archive_path]) as f: + data = self.reader.read(f, **kwargs) + else: + raise ShouldNotHappenException(f"Unsupported archived study format: {self.config.archive_path.suffix}") else: data = self.reader.read(self.path, **kwargs) diff --git a/antarest/study/storage/rawstudy/model/filesystem/inode.py b/antarest/study/storage/rawstudy/model/filesystem/inode.py index e380442651..f88903c729 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/inode.py +++ b/antarest/study/storage/rawstudy/model/filesystem/inode.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar from antarest.core.exceptions import ShouldNotHappenException, WritingInsideZippedFileException -from antarest.core.utils.utils import extract_file_to_tmp_dir +from antarest.core.utils.archives import extract_file_to_tmp_dir from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig G = TypeVar("G") @@ -148,17 +148,17 @@ def _extract_file_to_tmp_dir( The actual path of the extracted file the tmp_dir object which MUST be cleared after use of the file """ - if self.config.zip_path is None: + if self.config.archive_path is None: raise ShouldNotHappenException() - inside_zip_path = str(self.config.path)[len(str(self.config.zip_path)[:-4]) + 1 :] - if self.config.zip_path: - return extract_file_to_tmp_dir(self.config.zip_path, Path(inside_zip_path)) + inside_archive_path = 
self.config.path.relative_to(self.config.archive_path.parent / self.config.study_id) + if self.config.archive_path: + return extract_file_to_tmp_dir(self.config.archive_path, inside_archive_path) else: raise ShouldNotHappenException() def _assert_not_in_zipped_file(self) -> None: """Prevents writing inside a zip file""" - if self.config.zip_path: + if self.config.archive_path: raise WritingInsideZippedFileException("Trying to save inside a zipped file") diff --git a/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py b/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py index 49b83079d6..6520802d90 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py +++ b/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py @@ -47,19 +47,19 @@ def _get_real_file_path( self, ) -> t.Tuple[Path, t.Any]: tmp_dir = None - if self.config.zip_path: + if self.config.archive_path: path, tmp_dir = self._extract_file_to_tmp_dir() else: path = self.config.path return path, tmp_dir def file_exists(self) -> bool: - if self.config.zip_path: - str_zipped_path = str(self.config.zip_path) + if self.config.archive_path: + str_zipped_path = str(self.config.archive_path) inside_zip_path = str(self.config.path)[len(str_zipped_path[:-4]) + 1 :] str_inside_zip_path = str(inside_zip_path).replace("\\", "/") if str_zipped_path not in LazyNode.ZIP_FILELIST_CACHE: - with ZipFile(file=self.config.zip_path) as zip_file: + with ZipFile(file=self.config.archive_path) as zip_file: LazyNode.ZIP_FILELIST_CACHE[str_zipped_path] = SimpleCache( value=zip_file.namelist(), expiration_date=datetime.utcnow() + timedelta(hours=2), diff --git a/antarest/study/storage/rawstudy/model/filesystem/matrix/matrix.py b/antarest/study/storage/rawstudy/model/filesystem/matrix/matrix.py index 364a5e3524..9c421f59e1 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/matrix/matrix.py +++ b/antarest/study/storage/rawstudy/model/filesystem/matrix/matrix.py @@ -84,7 +84,7 @@ def 
normalize(self) -> None: Raises: DenormalizationException: if the original matrix retrieval fails. """ - if self.get_link_path().exists() or self.config.zip_path: + if self.get_link_path().exists() or self.config.archive_path: return matrix = self.parse() diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/areas/list.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/areas/list.py index 5112bd7756..04564300c6 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/areas/list.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/areas/list.py @@ -10,14 +10,17 @@ # # This file is part of the Antares project. -from typing import List, Optional +import typing as t +from antarest.core.utils.archives import extract_lines_from_archive from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.context import ContextServer from antarest.study.storage.rawstudy.model.filesystem.inode import INode +AREAS_LIST_RELATIVE_PATH = "input/areas/list.txt" -class InputAreasList(INode[List[str], List[str], List[str]]): + +class InputAreasList(INode[t.List[str], t.List[str], t.List[str]]): def normalize(self) -> None: pass # no external store in this node @@ -30,42 +33,40 @@ def __init__(self, context: ContextServer, config: FileStudyTreeConfig): def get_node( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, formatted: bool = True, - ) -> INode[List[str], List[str], List[str]]: + ) -> INode[t.List[str], t.List[str], t.List[str]]: return self def get( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, formatted: bool = True, - ) -> List[str]: - if self.config.zip_path: - path, tmp_dir = self._extract_file_to_tmp_dir() - lines = path.read_text().split("\n") - tmp_dir.cleanup() + ) -> t.List[str]: 
+ if self.config.archive_path: + lines = extract_lines_from_archive(self.config.archive_path, AREAS_LIST_RELATIVE_PATH) else: lines = self.config.path.read_text().split("\n") return [l.strip() for l in lines if l.strip()] - def save(self, data: List[str], url: Optional[List[str]] = None) -> None: + def save(self, data: t.List[str], url: t.Optional[t.List[str]] = None) -> None: self._assert_not_in_zipped_file() self.config.path.write_text("\n".join(data)) - def delete(self, url: Optional[List[str]] = None) -> None: + def delete(self, url: t.Optional[t.List[str]] = None) -> None: if self.config.path.exists(): self.config.path.unlink() def check_errors( self, - data: List[str], - url: Optional[List[str]] = None, + data: t.List[str], + url: t.Optional[t.List[str]] = None, raising: bool = False, - ) -> List[str]: + ) -> t.List[str]: errors = [] if any(a not in data for a in [area.name for area in self.config.areas.values()]): errors.append(f"list.txt should have {self.config.area_names()} nodes but given {data}") diff --git a/antarest/study/storage/rawstudy/raw_study_service.py b/antarest/study/storage/rawstudy/raw_study_service.py index 1dca10323f..e90b9b33ed 100644 --- a/antarest/study/storage/rawstudy/raw_study_service.py +++ b/antarest/study/storage/rawstudy/raw_study_service.py @@ -18,7 +18,6 @@ from pathlib import Path from threading import Thread from uuid import uuid4 -from zipfile import ZipFile from antares.study.version import StudyVersion @@ -27,7 +26,7 @@ from antarest.core.interfaces.cache import ICache from antarest.core.model import PublicMode from antarest.core.requests import RequestParameters -from antarest.core.utils.utils import extract_zip +from antarest.core.utils.archives import ArchiveFormat, extract_archive from antarest.study.model import DEFAULT_WORKSPACE_NAME, Patch, RawStudy, Study, StudyAdditionalData from antarest.study.storage.abstract_storage_service import AbstractStorageService from antarest.study.storage.patch_service import 
PatchService @@ -75,14 +74,17 @@ def __init__( ) self.cleanup_thread.start() - def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: t.Optional[bool] = False) -> None: + def update_from_raw_meta( + self, metadata: RawStudy, fallback_on_default: t.Optional[bool] = False, study_path: t.Optional[Path] = None + ) -> None: """ Update metadata from study raw metadata Args: metadata: study fallback_on_default: use default values in case of failure + study_path: optional study path """ - path = self.get_study_path(metadata) + path = study_path or self.get_study_path(metadata) study = self.study_factory.create_from_fs(path, study_id="") try: raw_meta = study.tree.get(["study", "antares"]) @@ -142,21 +144,19 @@ def update_name_and_version_from_raw_meta(self, metadata: RawStudy) -> bool: def exists(self, study: RawStudy) -> bool: """ - Check study exist. + Check if the study exists in the filesystem. + Args: - study: study + study: The study to check. Returns: true if study presents in disk, false else. - """ - path = self.get_study_path(study) - if study.archived: - path = self.get_archive_path(study) - zf = ZipFile(path, "r") - return str("study.antares") in zf.namelist() + archive_path = self.find_archive_path(study) + return archive_path.is_file() - return (path / "study.antares").is_file() + path = self.get_study_path(study) + return path.joinpath("study.antares").is_file() def get_raw( self, @@ -323,19 +323,19 @@ def import_study(self, metadata: RawStudy, stream: t.BinaryIO) -> Study: Raises: BadArchiveContent: If the archive is corrupted or in an unknown format. 
""" - path_study = Path(metadata.path) - path_study.mkdir() + study_path = Path(metadata.path) + study_path.mkdir() try: - extract_zip(stream, path_study) - fix_study_root(path_study) - self.update_from_raw_meta(metadata) + extract_archive(stream, study_path) + fix_study_root(study_path) + self.update_from_raw_meta(metadata, study_path=study_path) except Exception: - shutil.rmtree(path_study) + shutil.rmtree(study_path) raise - metadata.path = str(path_study) + metadata.path = str(study_path) return metadata def export_study_flat( @@ -384,7 +384,7 @@ def set_reference_output(self, study: RawStudy, output_id: str, status: bool) -> remove_from_cache(self.cache, study.id) def archive(self, study: RawStudy) -> Path: - archive_path = self.get_archive_path(study) + archive_path = self.config.storage.archive_dir.joinpath(f"{study.id}{ArchiveFormat.SEVEN_ZIP}") new_study_path = self.export_study(study, archive_path) shutil.rmtree(study.path) remove_from_cache(cache=self.cache, root_id=study.id) @@ -402,11 +402,25 @@ def unarchive(self, study: RawStudy) -> None: Raises: BadArchiveContent: If the archive is corrupted or in an unknown format. """ - with open(self.get_archive_path(study), mode="rb") as fh: + with open(self.find_archive_path(study), mode="rb") as fh: self.import_study(study, fh) - def get_archive_path(self, study: RawStudy) -> Path: - return Path(self.config.storage.archive_dir / f"{study.id}.zip") + def find_archive_path(self, study: RawStudy) -> Path: + """ + Fetch for archive path of a study if it exists else raise an incorrectly archived study. + + Args: + study: The study to get the archive path for. + + Returns: + The full path of the archive file (zip or 7z). 
+ """ + archive_dir: Path = self.config.storage.archive_dir + for suffix in list(ArchiveFormat): + path = archive_dir.joinpath(f"{study.id}{suffix}") + if path.is_file(): + return path + raise FileNotFoundError(f"Study {study.id} archiving process is corrupted (no archive file found).") def get_study_path(self, metadata: Study) -> Path: """ @@ -418,7 +432,7 @@ def get_study_path(self, metadata: Study) -> Path: """ if metadata.archived: - return self.get_archive_path(metadata) + return self.find_archive_path(metadata) return Path(metadata.path) def initialize_additional_data(self, raw_study: RawStudy) -> bool: diff --git a/antarest/study/storage/utils.py b/antarest/study/storage/utils.py index d1ab8828a0..85ba04b630 100644 --- a/antarest/study/storage/utils.py +++ b/antarest/study/storage/utils.py @@ -31,6 +31,7 @@ from antarest.core.model import PermissionInfo, StudyPermissionType from antarest.core.permissions import check_permission from antarest.core.requests import UserHasNotPermissionError +from antarest.core.utils.archives import is_archive_format from antarest.core.utils.utils import StopWatch from antarest.study.model import ( DEFAULT_WORKSPACE_NAME, @@ -79,7 +80,7 @@ def fix_study_root(study_path: Path) -> None: study_path: the study initial root path """ # TODO: what if it is a zipped output ? 
- if study_path.suffix == ".zip": + if is_archive_format(study_path.suffix): return None if not study_path.is_dir(): diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index 17da76fffb..72bd1804c7 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -431,6 +431,7 @@ def walk_children( raw_study_accepted=True, ) children = self.repository.get_children(parent_id=parent_id) + # TODO : the bottom_first should always be True, otherwise we will have an infinite loop if not bottom_first: fun(study) for child in children: @@ -438,8 +439,8 @@ def walk_children( if bottom_first: fun(study) - def get_variants_parents(self, id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]: - output_list: t.List[StudyMetadataDTO] = self._get_variants_parents(id, params) + def get_variants_parents(self, study_id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]: + output_list: t.List[StudyMetadataDTO] = self._get_variants_parents(study_id, params) if output_list: output_list = output_list[1:] return output_list diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index 41398dee6a..6fe9758186 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -22,13 +22,13 @@ from pydantic import NonNegativeInt from antarest.core.config import Config -from antarest.core.exceptions import BadZipBinary +from antarest.core.exceptions import BadArchiveContent, BadZipBinary from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager from antarest.core.jwt import JWTUser from antarest.core.model import PublicMode from antarest.core.requests import RequestParameters, UserHasNotPermissionError -from antarest.core.utils.utils import BadArchiveContent, 
sanitize_string, sanitize_uuid +from antarest.core.utils.utils import sanitize_string, sanitize_uuid from antarest.core.utils.web import APITag from antarest.login.auth import Auth from antarest.study.model import ( diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index a488d42d0d..6ce01acf85 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -17,7 +17,8 @@ from antarest.core.interfaces.eventbus import IEventBus from antarest.core.serialization import AntaresBaseModel from antarest.core.tasks.model import TaskResult -from antarest.core.utils.utils import StopWatch, unzip +from antarest.core.utils.archives import unzip +from antarest.core.utils.utils import StopWatch from antarest.worker.worker import AbstractWorker, WorkerTaskCommand logger = logging.getLogger(__name__) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index d57ac68d21..e6022f139e 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -16,7 +16,8 @@ import pytest from antarest.core.exceptions import ShouldNotHappenException -from antarest.core.utils.utils import concat_files, concat_files_to_str, read_in_zip, retry, suppress_exception +from antarest.core.utils.archives import read_in_zip +from antarest.core.utils.utils import concat_files, concat_files_to_str, retry, suppress_exception def test_retry() -> None: diff --git a/tests/core/utils/test_extract_zip.py b/tests/core/utils/test_extract_zip.py index dfbde6bf6d..5d24e6c772 100644 --- a/tests/core/utils/test_extract_zip.py +++ b/tests/core/utils/test_extract_zip.py @@ -17,10 +17,11 @@ import py7zr import pytest -from antarest.core.utils.utils import BadArchiveContent, extract_zip +from antarest.core.exceptions import BadArchiveContent +from antarest.core.utils.archives import extract_archive -class TestExtractZip: +class TestExtractArchive: """ Test the `extract_zip` function. 
""" @@ -33,7 +34,7 @@ def test_extract_zip__with_zip(self, tmp_path: Path): # Then, call the function with open(zip_path, mode="rb") as stream: - extract_zip(stream, tmp_path) + extract_archive(stream, tmp_path) # Finally, check the result assert (tmp_path / "test.txt").read_text() == "Hello world!" @@ -46,7 +47,7 @@ def test_extract_zip__with_7z(self, tmp_path: Path): # Then, call the function with open(zip_path, mode="rb") as stream: - extract_zip(stream, tmp_path) + extract_archive(stream, tmp_path) # Finally, check the result assert (tmp_path / "test.txt").read_text() == "Hello world!" @@ -55,22 +56,22 @@ def test_extract_zip__empty_file(self): stream = io.BytesIO(b"") with pytest.raises(BadArchiveContent): - extract_zip(stream, Path("dummy/path")) + extract_archive(stream, Path("dummy/path")) def test_extract_zip__corrupted_zip(self): stream = io.BytesIO(b"PK\x03\x04 BLURP") with pytest.raises(BadArchiveContent): - extract_zip(stream, Path("dummy/path")) + extract_archive(stream, Path("dummy/path")) def test_extract_zip__corrupted_7z(self): stream = io.BytesIO(b"7z BLURP") with pytest.raises(BadArchiveContent): - extract_zip(stream, Path("dummy/path")) + extract_archive(stream, Path("dummy/path")) def test_extract_zip__unknown_format(self): stream = io.BytesIO(b"ZORRO") with pytest.raises(BadArchiveContent): - extract_zip(stream, Path("dummy/path")) + extract_archive(stream, Path("dummy/path")) diff --git a/tests/integration/assets/STA-mini.7z b/tests/integration/assets/STA-mini.7z index 4bae2fe52f..b861b0c5a6 100644 Binary files a/tests/integration/assets/STA-mini.7z and b/tests/integration/assets/STA-mini.7z differ diff --git a/tests/integration/raw_studies_blueprint/test_fetch_raw_data.py b/tests/integration/raw_studies_blueprint/test_fetch_raw_data.py index b67fffe9e2..d2c0a5540b 100644 --- a/tests/integration/raw_studies_blueprint/test_fetch_raw_data.py +++ b/tests/integration/raw_studies_blueprint/test_fetch_raw_data.py @@ -25,6 +25,7 @@ from 
antarest.core.utils.fastapi_sqlalchemy import db from antarest.study.model import RawStudy, Study from tests.integration.raw_studies_blueprint.assets import ASSETS_DIR +from tests.integration.utils import wait_for @pytest.mark.integration_test @@ -327,3 +328,53 @@ def test_delete_raw(client: TestClient, user_access_token: str, internal_study_i assert res.status_code == 403 assert res.json()["exception"] == "FileDeletionNotAllowed" assert "the given path doesn't exist" in res.json()["description"] + + +def test_retrieve_from_archive(client: TestClient, user_access_token: str) -> None: + # client headers + client.headers = {"Authorization": f"Bearer {user_access_token}"} + + # create a new study + res = client.post("/v1/studies?name=MyStudy") + assert res.status_code == 201 + + # get the study id + study_id = res.json() + + # add a new area to the study + res = client.post( + f"/v1/studies/{study_id}/areas", + json={ + "name": "area 1", + "type": "AREA", + "metadata": {"country": "FR", "tags": ["a"]}, + }, + ) + assert res.status_code == 200, res.json() + + # archive the study + res = client.put(f"/v1/studies/{study_id}/archive") + assert res.status_code == 200 + task_id = res.json() + wait_for( + lambda: client.get( + f"/v1/tasks/{task_id}", + ).json()["status"] + == 3 + ) + + # retrieve a `Desktop.ini` file from inside the archive + rel_path = "Desktop" + res = client.get( + f"/v1/studies/{study_id}/raw", + params={"path": rel_path, "formatted": True}, + ) + assert res.status_code == 200 + + # retrieve a `study.antares` file from inside the archive + rel_path = "study" + res = client.get( + f"/v1/studies/{study_id}/raw", + params={"path": rel_path, "formatted": True}, + ) + assert res.status_code == 200 diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json index 1e0f3ada52..c0d5635c70 100644 --- 
a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json @@ -1067,5 +1067,5 @@ "store_new_set": true, "archive_input_series": [], "enr_modelling": "aggregated", - "zip_path": null + "archive_path": null } \ No newline at end of file diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json index 7e449747e4..2d3c3baf2a 100644 --- a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json @@ -983,5 +983,5 @@ "store_new_set": true, "archive_input_series": [], "enr_modelling": "aggregated", - "zip_path": null + "archive_path": null } \ No newline at end of file diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index e3d614ea9d..da7b312e9c 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -254,7 +254,7 @@ def test_main(client: TestClient, admin_access_token: str) -> None: headers={"Authorization": f'Bearer {fred_credentials["refresh_token"]}'}, ) fred_credentials = res.json() - res = client.post( + client.post( f"/v1/studies?name=bar&groups={group_id}", headers={"Authorization": f'Bearer {george_credentials["access_token"]}'}, ) @@ -1554,7 +1554,7 @@ def test_archive(client: TestClient, admin_access_token: str, tmp_path: Path, in res = client.get(f"/v1/studies/{study_id}") assert res.json()["archived"] - assert (tmp_path / "archive_dir" / f"{study_id}.zip").exists() + assert (tmp_path / "archive_dir" / f"{study_id}.7z").exists() res = client.put(f"/v1/studies/{study_id}/unarchive") @@ -1563,12 +1563,12 @@ def test_archive(client: TestClient, admin_access_token: str, tmp_path: Path, in lambda: client.get( 
f"/v1/tasks/{task_id}", ).json()["status"] - == 3 + == 3, ) res = client.get(f"/v1/studies/{study_id}") assert not res.json()["archived"] - assert not (tmp_path / "archive_dir" / f"{study_id}.zip").exists() + assert not (tmp_path / "archive_dir" / f"{study_id}.7z").exists() def test_maintenance(client: TestClient, admin_access_token: str) -> None: diff --git a/tests/storage/business/test_export.py b/tests/storage/business/test_export.py index 667f357ab9..e9e37211e1 100644 --- a/tests/storage/business/test_export.py +++ b/tests/storage/business/test_export.py @@ -16,6 +16,7 @@ import pytest from checksumdir import dirhash +from py7zr import SevenZipFile from antarest.core.config import Config, StorageConfig from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy @@ -42,14 +43,14 @@ def test_export_file(tmp_path: Path): study_service.export_file.return_value = b"Hello" # Test good study - md = RawStudy(id=name, workspace=DEFAULT_WORKSPACE_NAME) - export_path = tmp_path / "export.zip" + md = RawStudy(id=name, workspace=DEFAULT_WORKSPACE_NAME, path=study_path) + export_path = tmp_path / "export.7z" study_service.export_study(md, export_path) @pytest.mark.unit_test @pytest.mark.parametrize("outputs", [True, False]) -def test_export_file(tmp_path: Path, outputs: bool): +def test_export_archived_study(tmp_path: Path, outputs: bool): root = tmp_path / "folder" root.mkdir() (root / "test").mkdir() @@ -58,7 +59,7 @@ def test_export_file(tmp_path: Path, outputs: bool): (root / "output/results1").mkdir(parents=True) (root / "output/results1/file.txt").write_text("42") - export_path = tmp_path / "study.zip" + export_path = tmp_path / "study.7z" study_factory = Mock() study_service = RawStudyService( @@ -74,10 +75,11 @@ def test_export_file(tmp_path: Path, outputs: bool): study_factory.create_from_fs.return_value = study_tree study_service.export_study(study, export_path, outputs=outputs) - with ZipFile(export_path) as zipf: - assert "file.txt" in zipf.namelist() - 
assert "test/file.txt" in zipf.namelist() - assert ("output/results1/file.txt" in zipf.namelist()) == outputs + with SevenZipFile(export_path) as szf: + szf_files = set(szf.getnames()) + assert "file.txt" in szf_files + assert "test/file.txt" in szf_files + assert ("output/results1/file.txt" in szf_files) == outputs @pytest.mark.unit_test diff --git a/tests/storage/business/test_import.py b/tests/storage/business/test_import.py index ddf984dc06..6cf9a7ef74 100644 --- a/tests/storage/business/test_import.py +++ b/tests/storage/business/test_import.py @@ -16,23 +16,16 @@ from pathlib import Path from unittest.mock import Mock +import py7zr import pytest -from antarest.core.exceptions import StudyValidationError -from antarest.core.utils.utils import BadArchiveContent +from antarest.core.exceptions import BadArchiveContent, StudyValidationError from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, StudyAdditionalData from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from antarest.study.storage.utils import fix_study_root -def build_storage_service(workspace: Path, uuid: str) -> RawStudyService: - service = Mock() - service.get_workspace_path.return_value = workspace - service.get_study_path.return_value = workspace / uuid - return service - - @pytest.mark.unit_test def test_import_study(tmp_path: Path) -> None: name = "my-study" @@ -57,7 +50,9 @@ def test_import_study(tmp_path: Path) -> None: study_service.get = Mock() study_service.get_study_path = Mock() study_service.get.return_value = data - study_service.get_study_path.return_value = tmp_path / "other-study" + + # first test importing a study for an archived study with `.zip` format + study_service.get_study_path.return_value = tmp_path / "other-study-zip" filepath_zip = shutil.make_archive(str(study_path.absolute()), "zip", study_path) shutil.rmtree(study_path) @@ -65,19 +60,48 @@ def 
test_import_study(tmp_path: Path) -> None: path_zip = Path(filepath_zip) md = RawStudy( - id="other-study", + id="other-study-zip", workspace=DEFAULT_WORKSPACE_NAME, - path=tmp_path / "other-study", + path=tmp_path / "other-study-zip", additional_data=StudyAdditionalData(), groups=["fake_group_1", "fake_group_2"], ) with path_zip.open("rb") as input_file: md = study_service.import_study(md, input_file) - assert md.path == f"{tmp_path}{os.sep}other-study" + assert md.path == f"{tmp_path}{os.sep}other-study-zip" # assert that importing file into a created study does not alter its group assert md.groups == ["fake_group_1", "fake_group_2"] - shutil.rmtree(tmp_path / "other-study") + shutil.rmtree(tmp_path / "other-study-zip") + + # second test for an archived study with a `.7z` format + study_service.get_study_path.return_value = tmp_path / "other-study-7zip" + + study_path.mkdir() + (study_path / "study.antares").touch() + + filepath_7zip = study_path.parent / f"{study_path.name}.7z" + with py7zr.SevenZipFile(filepath_7zip, "w") as archive: + archive.writeall(study_path, arcname="") + + shutil.rmtree(study_path) + + md = RawStudy( + id="other-study-7zip", + workspace=DEFAULT_WORKSPACE_NAME, + path=tmp_path / "other-study-7zip", + additional_data=StudyAdditionalData(), + groups=["fake_group_1", "fake_group_2"], + ) + with filepath_7zip.open("rb") as input_file: + md = study_service.import_study(md, input_file) + assert md.path == f"{tmp_path}{os.sep}other-study-7zip" + # assert that importing file into a created study does not alter its group + assert md.groups == ["fake_group_1", "fake_group_2"] + + shutil.rmtree(tmp_path / "other-study-7zip") + + # test for an unsupported archive format with pytest.raises(BadArchiveContent, match="Unsupported archive format"): study_service.import_study(md, io.BytesIO(b"")) diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index dcf7e5e830..3c7da36de9 100644 --- 
a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -11,10 +11,13 @@ # This file is part of the Antares project. import datetime +import shutil +import zipfile from pathlib import Path from unittest.mock import Mock from zipfile import ZipFile +import py7zr import pytest from sqlalchemy import create_engine @@ -43,6 +46,20 @@ def sta_mini_zip_path(project_path: Path) -> Path: return project_path / "examples/studies/STA-mini.zip" +@pytest.fixture +def sta_mini_seven_zip_path(project_path: Path, sta_mini_zip_path: Path) -> Path: + target = project_path / "examples/studies/STA-mini.7z" + if target.is_file(): + return target + with zipfile.ZipFile(sta_mini_zip_path, "r") as zf: + zf.extractall(sta_mini_zip_path.parent) + extracted_dir_path = sta_mini_zip_path.parent / "STA-mini" + with py7zr.SevenZipFile(target, "w") as szf: + szf.writeall(extracted_dir_path, arcname="") + shutil.rmtree(extracted_dir_path) + return target + + @pytest.fixture def storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) -> StudyService: engine = create_engine("sqlite:///:memory:", echo=False) diff --git a/tests/storage/integration/test_exporter.py b/tests/storage/integration/test_exporter.py index 46e077a9ca..39ad4b1b7a 100644 --- a/tests/storage/integration/test_exporter.py +++ b/tests/storage/integration/test_exporter.py @@ -17,6 +17,7 @@ from typing import List, Optional from unittest.mock import Mock +import py7zr import pytest from fastapi import FastAPI from starlette.testclient import TestClient @@ -35,11 +36,17 @@ from tests.storage.integration.conftest import UUID -def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> bytes: +def assert_url_content(url: str, tmp_dir: Path, sta_mini_archive_path: Path) -> bytes: path_studies = tmp_dir / "studies" - with zipfile.ZipFile(sta_mini_zip_path) as zip_output: - zip_output.extractall(path=path_studies) + if sta_mini_archive_path.suffix == ".zip": + with 
zipfile.ZipFile(sta_mini_archive_path) as zip_output: + zip_output.extractall(path=path_studies) + elif sta_mini_archive_path.suffix == ".7z": + with py7zr.SevenZipFile(sta_mini_archive_path, "r") as szf: + szf.extractall(path=path_studies / "STA-mini") + else: + raise ValueError(f"Unsupported archive format {sta_mini_archive_path.suffix}") config = Config( resources_path=Path(), @@ -94,20 +101,30 @@ def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> byte return fh.read() -def test_exporter_file(tmp_path: Path, sta_mini_zip_path: Path) -> None: +def test_exporter_file(tmp_path: Path, sta_mini_zip_path: Path, sta_mini_seven_zip_path: Path) -> None: + # test with zip file data = assert_url_content( - url=f"/v1/studies/{UUID}/export", - tmp_dir=tmp_path, - sta_mini_zip_path=sta_mini_zip_path, + url=f"/v1/studies/{UUID}/export", tmp_dir=tmp_path, sta_mini_archive_path=sta_mini_zip_path ) assert data and b" None: + # test with zip file + data = assert_url_content( + url=f"/v1/studies/{UUID}/export?no-output", tmp_dir=tmp_path, sta_mini_archive_path=sta_mini_zip_path + ) + assert data and b" None: + # test with 7zip file data = assert_url_content( - url=f"/v1/studies/{UUID}/export?no-output", - tmp_dir=tmp_path, - sta_mini_zip_path=sta_mini_zip_path, + url=f"/v1/studies/{UUID}/export?no-output", tmp_dir=tmp_path, sta_mini_archive_path=sta_mini_seven_zip_path ) assert data and b" INode[t.Any, t.Any, t.Any]: config = Mock() config.path.exist.return_value = True - config.zip_path = None + config.archive_path = None return TestMiddleNode( context=Mock(), config=config, diff --git a/tests/storage/repository/filesystem/test_ini_file_node.py b/tests/storage/repository/filesystem/test_ini_file_node.py index f4571d3ea9..3864fa875f 100644 --- a/tests/storage/repository/filesystem/test_ini_file_node.py +++ b/tests/storage/repository/filesystem/test_ini_file_node.py @@ -96,7 +96,7 @@ def test_get(tmp_path: Path) -> None: areas={}, outputs={}, study_id="id", 
- zip_path=zipped_path, + archive_path=zipped_path, ), types=types, ) @@ -147,7 +147,7 @@ def test_get_depth(tmp_path: Path) -> None: areas={}, outputs={}, study_id="id", - zip_path=zipped_path, + archive_path=zipped_path, ), types=types, ) diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index a1d0724c76..6490a46096 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -1569,7 +1569,7 @@ def test_get_save_logs(tmp_path: Path) -> None: output_config = Mock(get_file=Mock(return_value="output_id"), archived=False) - file_study_config = FileStudyTreeConfig(tmp_path, tmp_path, "study_id", 0, zip_path=None) + file_study_config = FileStudyTreeConfig(tmp_path, tmp_path, "study_id", 0, archive_path=None) file_study_config.outputs = {"output_id": output_config} context = Mock() diff --git a/tests/storage/web/test_studies_bp.py b/tests/storage/web/test_studies_bp.py index 6774f6127e..9729cf57bc 100644 --- a/tests/storage/web/test_studies_bp.py +++ b/tests/storage/web/test_studies_bp.py @@ -356,29 +356,6 @@ def test_edit_study() -> None: mock_storage_service.edit_study.assert_called_once_with("my-uuid", "url/to/change", {"Hello": "World"}, PARAMS) -# @pytest.mark.unit_test -# def test_edit_study_fail() -> None: -# mock_storage_service = Mock() -# -# app = FastAPI(title=__name__) -# build_study_service( -# app, -# cache=Mock(), -# task_service=Mock(), -# file_transfer_manager=Mock(), -# study_service=mock_storage_service, -# config=CONFIG, -# user_service=Mock(), -# matrix_service=Mock(spec=MatrixService), -# ) -# client = TestClient(app, raise_server_exceptions=False) -# res = client.post("/v1/studies/my-uuid/raw?path=url/to/change", json={}) -# -# assert res.status_code == 400 -# -# mock_storage_service.edit_study.assert_not_called() - - @pytest.mark.unit_test def test_validate() -> None: mock_service = Mock() diff --git a/tests/study/storage/test_abstract_storage_service.py 
b/tests/study/storage/test_abstract_storage_service.py index 97793c206d..7b2dc79c28 100644 --- a/tests/study/storage/test_abstract_storage_service.py +++ b/tests/study/storage/test_abstract_storage_service.py @@ -11,126 +11,46 @@ # This file is part of the Antares project. import datetime -import zipfile from pathlib import Path -from typing import List, Optional, Sequence -from unittest.mock import Mock, call -from antarest.core.config import Config, StorageConfig -from antarest.core.interfaces.cache import ICache +from py7zr import SevenZipFile + from antarest.core.model import PublicMode -from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, User from antarest.study.model import Study -from antarest.study.storage.abstract_storage_service import AbstractStorageService -from antarest.study.storage.patch_service import PatchService -from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO -from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy, StudyFactory +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from tests.helpers import with_db_context -class MyStorageService(AbstractStorageService): - """ - This class is only defined to test `AbstractStorageService` class PUBLIC methods. - Abstract methods are not implemented: there are not used or patched with a Mock object. 
- """ - - def create(self, metadata: Study) -> Study: - raise NotImplementedError - - def exists(self, metadata: Study) -> bool: - raise NotImplementedError - - # noinspection SpellCheckingInspection - def copy(self, src_meta: Study, dest_name: str, groups: Sequence[str], with_outputs: bool = False) -> Study: - raise NotImplementedError - - def get_raw( - self, - metadata: Study, - use_cache: bool = True, - output_dir: Optional[Path] = None, - ) -> FileStudy: - raise NotImplementedError - - def set_reference_output(self, metadata: Study, output_id: str, status: bool) -> None: - raise NotImplementedError - - def delete(self, metadata: Study) -> None: - raise NotImplementedError - - def delete_output(self, metadata: Study, output_id: str) -> None: - raise NotImplementedError - - def get_study_path(self, metadata: Study) -> Path: - raise NotImplementedError - - def export_study_flat( - self, - metadata: Study, - dst_path: Path, - outputs: bool = True, - output_list_filter: Optional[List[str]] = None, - denormalize: bool = True, - ) -> None: - raise NotImplementedError - - def get_synthesis(self, metadata: Study, params: Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO: - raise NotImplementedError - - def initialize_additional_data(self, study: Study) -> bool: - raise NotImplementedError - - -class TmpCopy(object): - """A helper object that compares equal if a folder is a "tmp_copy" folder.""" - - def __init__(self, tmp_path: Path): - self.tmp_path = tmp_path - - def __eq__(self, other: Path): - if isinstance(other, Path) and other.name == "tmp_copy": - # `is_relative_to` is not available for Python < 3.9 - try: - other.relative_to(self.tmp_path) - return True - except ValueError: - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return f"" - - class TestAbstractStorageService: @with_db_context - def test_export_study(self, tmp_path: Path) -> None: - tmp_dir = tmp_path / "tmp" - tmp_dir.mkdir() - 
study_path = tmp_path / "My Study" + def test_export_study(self, tmp_path: Path, raw_study_service: RawStudyService) -> None: + """ + Test the `export_study` method of the `AbstractStorageService` class. + Args: + tmp_path: The temporary directory where to store the 7z file. + raw_study_service: The `RawStudyService` instance to test the `export_study` method. - service = MyStorageService( - config=Config(storage=StorageConfig(tmp_dir=tmp_dir)), - study_factory=Mock(spec=StudyFactory), - patch_service=Mock(spec=PatchService), - cache=Mock(spec=ICache), - ) + Returns: - ## Prepare database objects + """ + # Prepare a dummy study with a `study.antares` file, and non-empty input and output folder + study_path = tmp_path / "My Study" + study_path.mkdir() + content_list = ["study.antares", "input/areas/foo.ini", "output/20240424-1200eco/result.log"] + for content in content_list: + study_path.joinpath(content).parent.mkdir(parents=True, exist_ok=True) + study_path.joinpath(content).touch() # noinspection PyArgumentList user = User(id=0, name="admin") db.session.add(user) db.session.commit() - # noinspection PyArgumentList group = Group(id="my-group", name="group") db.session.add(group) db.session.commit() - # noinspection PyArgumentList metadata = Study( name="My Study", @@ -146,18 +66,23 @@ def test_export_study(self, tmp_path: Path) -> None: ) db.session.add(metadata) db.session.commit() + db.session.refresh(metadata) - ## Check the `export_study` function - service.export_study_flat = Mock(return_value=None) - target_path = tmp_path / "export.zip" - actual = service.export_study(metadata, target_path, outputs=True) + # Check the `export_study` function + target_path = tmp_path / "export.7z" + actual = raw_study_service.export_study(metadata, target_path, outputs=True) assert actual == target_path - ## Check the call to export_study_flat - assert service.export_study_flat.mock_calls == [call(metadata, TmpCopy(tmp_path), True)] - - ## Check that the ZIP file exist and 
is valid - with zipfile.ZipFile(target_path) as zf: - # Actually, there is nothing is the ZIP file, - # because the Study files doesn't really exist. - assert not zf.namelist() + # Check that the 7zip file exist and is valid + with SevenZipFile(target_path) as szf: + # Check that the content of the 7z file is the same as the study folder + assert set(content_list) == ( + set(szf.getnames()) + - { + ".", + "input", + "output", + "input/areas", + "output/20240424-1200eco", + } + )