diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index 5b730b1552..08fce0cb9d 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -10,7 +10,6 @@ import py7zr import redis -from py7zr import SevenZipFile from antarest.core.config import RedisConfig from antarest.core.exceptions import ShouldNotHappenException @@ -181,17 +180,6 @@ def zip_dir(dir_path: Path, zip_path: Path, remove_source_dir: bool = False) -> shutil.rmtree(dir_path) -def seven_zip_dir(dir_path: Path, seven_zip_path: Path, remove_source_dir: bool = False) -> None: - len_dir_path = len(str(dir_path)) - with SevenZipFile(seven_zip_path, "w") as szf: - for root, _, files in os.walk(dir_path): - for file in files: - file_path = os.path.join(root, file) - szf.write(file_path, arcname=file_path[len_dir_path:]) - if remove_source_dir: - shutil.rmtree(dir_path) - - def unzip(dir_path: Path, zip_path: Path, remove_source_zip: bool = False) -> None: with zipfile.ZipFile(zip_path, mode="r") as zipf: zipf.extractall(dir_path) diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 4394e22bdc..85e79815df 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -9,12 +9,13 @@ import numpy as np import pandas as pd +import py7zr from antarest.core.config import Config from antarest.core.exceptions import BadOutputError, StudyOutputNotFoundError from antarest.core.interfaces.cache import CacheConstants, ICache from antarest.core.model import JSON, PublicMode -from antarest.core.utils.utils import StopWatch, extract_archive, seven_zip_dir, unzip, zip_dir +from antarest.core.utils.utils import StopWatch, extract_archive, unzip, zip_dir from antarest.login.model import GroupDTO from antarest.study.common.default_values import AreasQueryFile, LinksQueryFile from antarest.study.common.studystorage import IStudyStorageService, T @@ -523,7 +524,8 @@ def export_study(self, metadata: T, target: Path, outputs: bool = True) -> Path: tmp_study_path = Path(tmpdir) / "tmp_copy" self.export_study_flat(metadata, tmp_study_path, outputs) stopwatch = StopWatch() - seven_zip_dir(tmp_study_path, target) + with py7zr.SevenZipFile(target, "w") as szf: + szf.writeall(tmp_study_path, arcname=".") stopwatch.log_elapsed(lambda x: logger.info(f"Study {path_study} exported (7zip mode) in {x}s")) return target diff --git a/tests/study/storage/test_abstract_storage_service.py b/tests/study/storage/test_abstract_storage_service.py index 317bb97458..1941a19977 100644 --- a/tests/study/storage/test_abstract_storage_service.py +++ b/tests/study/storage/test_abstract_storage_service.py @@ -1,125 +1,44 @@ import datetime from pathlib import Path -from typing import List, Optional, Sequence -from unittest.mock import Mock, call from py7zr import SevenZipFile -from antarest.core.config import Config, StorageConfig -from antarest.core.interfaces.cache import ICache from antarest.core.model import PublicMode -from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, User from antarest.study.model import Study -from antarest.study.storage.abstract_storage_service import AbstractStorageService -from antarest.study.storage.patch_service import PatchService -from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO -from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy, StudyFactory +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from tests.helpers import with_db_context -class MyStorageService(AbstractStorageService): - """ - This class is only defined to test `AbstractStorageService` class PUBLIC methods. - Abstract methods are not implemented: there are not used or patched with a Mock object. - """ - - def create(self, metadata: Study) -> Study: - raise NotImplementedError - - def exists(self, metadata: Study) -> bool: - raise NotImplementedError - - # noinspection SpellCheckingInspection - def copy(self, src_meta: Study, dest_name: str, groups: Sequence[str], with_outputs: bool = False) -> Study: - raise NotImplementedError - - def get_raw( - self, - metadata: Study, - use_cache: bool = True, - output_dir: Optional[Path] = None, - ) -> FileStudy: - raise NotImplementedError - - def set_reference_output(self, metadata: Study, output_id: str, status: bool) -> None: - raise NotImplementedError - - def delete(self, metadata: Study) -> None: - raise NotImplementedError - - def delete_output(self, metadata: Study, output_id: str) -> None: - raise NotImplementedError - - def get_study_path(self, metadata: Study) -> Path: - raise NotImplementedError - - def export_study_flat( - self, - metadata: Study, - dst_path: Path, - outputs: bool = True, - output_list_filter: Optional[List[str]] = None, - denormalize: bool = True, - ) -> None: - raise NotImplementedError - - def get_synthesis(self, metadata: Study, params: Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO: - raise NotImplementedError - - def initialize_additional_data(self, study: Study) -> bool: - raise NotImplementedError - - -class TmpCopy(object): - """A helper object that compares equal if a folder is a "tmp_copy" folder.""" - - def __init__(self, tmp_path: Path): - self.tmp_path = tmp_path - - def __eq__(self, other: Path): - if isinstance(other, Path) and other.name == "tmp_copy": - # `is_relative_to` is not available for Python < 3.9 - try: - other.relative_to(self.tmp_path) - return True - except ValueError: - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return f"" - - class TestAbstractStorageService: @with_db_context - def test_export_study(self, tmp_path: Path) -> None: - tmp_dir = tmp_path / "tmp" - tmp_dir.mkdir() - study_path = tmp_path / "My Study" + def test_export_study(self, tmp_path: Path, raw_study_service: RawStudyService) -> None: + """ + Test the `export_study` method of the `AbstractStorageService` class. + Args: + tmp_path: The temporary directory where to store the 7z file. + raw_study_service: The `RawStudyService` instance to test the `export_study` method. - service = MyStorageService( - config=Config(storage=StorageConfig(tmp_dir=tmp_dir)), - study_factory=Mock(spec=StudyFactory), - patch_service=Mock(spec=PatchService), - cache=Mock(spec=ICache), - ) + Returns: - ## Prepare database objects + """ + # Prepare a dummy study with a `study.antarest` file, and non-empty input and output folder + study_path = tmp_path / "My Study" + study_path.mkdir() + content_list = ["study.antares", "input/areas/foo.ini", "output/20240424-1200eco/result.log"] + for content in content_list: + study_path.joinpath(content).parent.mkdir(parents=True, exist_ok=True) + study_path.joinpath(content).touch() # noinspection PyArgumentList user = User(id=0, name="admin") db.session.add(user) db.session.commit() - # noinspection PyArgumentList group = Group(id="my-group", name="group") db.session.add(group) db.session.commit() - # noinspection PyArgumentList metadata = Study( name="My Study", @@ -135,18 +54,23 @@ def test_export_study(self, tmp_path: Path) -> None: ) db.session.add(metadata) db.session.commit() + db.session.refresh(metadata) - ## Check the `export_study` function - service.export_study_flat = Mock(return_value=None) + # Check the `export_study` function target_path = tmp_path / "export.7z" - actual = service.export_study(metadata, target_path, outputs=True) + actual = raw_study_service.export_study(metadata, target_path, outputs=True) assert actual == target_path - ## Check the call to export_study_flat - assert service.export_study_flat.mock_calls == [call(metadata, TmpCopy(tmp_path), True)] - - ## Check that the 7zip file exist and is valid + # Check that the 7zip file exist and is valid with SevenZipFile(target_path) as szf: - # Actually, there is nothing is the .7z file, - # because the Study files doesn't really exist. - assert not szf.getnames() + # Check that the content of the 7z file is the same as the study folder + assert set(content_list) == ( + set(szf.getnames()) + - { + ".", + "input", + "output", + "input/areas", + "output/20240424-1200eco", + } + )