Skip to content

Commit

Permalink
feat(archive-apis)!: use .7z format to archive studies (#2013)
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte authored Oct 24, 2024
2 parents d44a242 + b86b528 commit 843bd3a
Show file tree
Hide file tree
Showing 37 changed files with 580 additions and 440 deletions.
9 changes: 9 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,12 @@ def __init__(self) -> None:
class ChildNotFoundError(HTTPException):
    """
    HTTP error carrying the 404 (NOT FOUND) status and a caller-supplied message.
    """

    def __init__(self, message: str) -> None:
        # The status is fixed; only the message varies between call sites.
        not_found = HTTPStatus.NOT_FOUND
        super().__init__(not_found, message)


class BadArchiveContent(Exception):
    """
    Error signalling that an archive file is corrupted or of an unknown format.

    Args:
        message: Explanation of the failure. When omitted, a generic
            "Unsupported archive format" message is used.
    """

    def __init__(self, message: str = "Unsupported archive format") -> None:
        # Forward the (possibly default) message so that `str(exc)` returns it.
        super().__init__(message)
188 changes: 188 additions & 0 deletions antarest/core/utils/archives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
#
# See AUTHORS.txt
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
import logging
import os
import shutil
import tempfile
import typing as t
import zipfile
from enum import Enum
from pathlib import Path

import py7zr

from antarest.core.exceptions import BadArchiveContent, ShouldNotHappenException

logger = logging.getLogger(__name__)


class ArchiveFormat(str, Enum):
    """
    Archive formats known to this module.

    Each member's value is the file suffix, leading dot included. Members are
    also `str` instances, so they compare equal to a `Path.suffix` value.
    """

    # Standard ZIP archive.
    ZIP = ".zip"
    # 7-Zip archive.
    SEVEN_ZIP = ".7z"


def is_archive_format(suffix: str) -> bool:
    """
    Tell whether a file suffix denotes one of the supported archive formats.

    Args:
        suffix: File suffix including the leading dot (e.g. ".zip").
            The comparison is case-sensitive.

    Returns:
        `True` if the suffix equals the ZIP or the 7z suffix, `False` otherwise.
    """
    supported_suffixes = (ArchiveFormat.ZIP, ArchiveFormat.SEVEN_ZIP)
    return suffix in supported_suffixes


def archive_dir(
    src_dir_path: Path,
    target_archive_path: Path,
    remove_source_dir: bool = False,
    archive_format: t.Optional[ArchiveFormat] = None,
) -> None:
    """
    Compress a directory into an archive whose format is given by the target suffix.

    Args:
        src_dir_path: Directory to archive.
        target_archive_path: Archive file to create; its suffix (".7z" or ".zip")
            selects the writer.
        remove_source_dir: If `True`, delete the source directory once the
            archive has been written.
        archive_format: Optional expected format; when given, it must match the
            suffix of `target_archive_path`.

    Raises:
        ShouldNotHappenException: If `archive_format` does not match the target
            suffix, or if the target suffix is not a supported format.
    """
    target_suffix = target_archive_path.suffix

    # Guard: an explicitly requested format must agree with the file name.
    if archive_format is not None and target_suffix != archive_format:
        raise ShouldNotHappenException(
            f"Non matching archive format {archive_format} and target archive suffix {target_archive_path.suffix}"
        )

    if target_suffix == ArchiveFormat.SEVEN_ZIP:
        with py7zr.SevenZipFile(target_archive_path, mode="w") as seven_zip:
            seven_zip.writeall(src_dir_path, arcname="")
    elif target_suffix == ArchiveFormat.ZIP:
        with zipfile.ZipFile(
            target_archive_path,
            mode="w",
            compression=zipfile.ZIP_DEFLATED,
            compresslevel=2,
        ) as archive:
            # Entry names are the on-disk paths with the source directory prefix cut off.
            prefix_length = len(str(src_dir_path))
            # Only the files reported by `os.walk` are written.
            for current_dir, _subdirs, filenames in os.walk(src_dir_path):
                for filename in filenames:
                    full_path = os.path.join(current_dir, filename)
                    archive.write(full_path, full_path[prefix_length:])
    else:
        raise ShouldNotHappenException(f"Unsupported archive format {target_archive_path.suffix}")

    if remove_source_dir:
        shutil.rmtree(src_dir_path)


def unzip(dir_path: Path, zip_path: Path, remove_source_zip: bool = False) -> None:
    """
    Extract every member of a ZIP file into a directory.

    Args:
        dir_path: Destination directory of the extraction.
        zip_path: ZIP file to read.
        remove_source_zip: If `True`, delete the ZIP file after extraction.
    """
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(dir_path)

    if not remove_source_zip:
        return
    zip_path.unlink()


def is_zip(path: Path) -> bool:
    """
    Tell whether the file name of a path ends with ".zip" (case-sensitive).

    Args:
        path: Path to check; only its final component is examined.

    Returns:
        `True` if the name ends with ".zip", `False` otherwise.
    """
    filename = path.name
    return filename.endswith(".zip")


def read_in_zip(
    zip_path: Path,
    inside_zip_path: Path,
    read: t.Callable[[t.Optional[Path]], None],
) -> None:
    """
    Extract one file from an archive into a temporary directory and pass it to a callback.

    The callback is invoked exactly once: with the path of the extracted file,
    or with `None` when the extraction raises `KeyError` (file not found in the
    archive). The temporary directory is removed once the callback has returned
    or raised.

    Args:
        zip_path: Path to the archive file.
        inside_zip_path: Path of the wanted file inside the archive.
        read: Callback receiving the extracted file path, or `None` if the file
            is missing.
    """
    try:
        path, tmp_dir = extract_file_to_tmp_dir(zip_path, inside_zip_path)
    except KeyError:
        logger.warning(f"{inside_zip_path} not found in {zip_path}")
        read(None)
        return

    # The callback runs outside the `except KeyError` above: a KeyError raised by
    # the callback itself must not be reported as a missing file, nor trigger a
    # second call with `None`.
    try:
        read(path)
    finally:
        tmp_dir.cleanup()


def extract_archive(stream: t.BinaryIO, target_dir: Path) -> None:
    """
    Extract a ZIP or 7z archive, read from a binary stream, to a given destination.

    The format is detected from the first bytes of the stream, which is then
    rewound to position 0 before extraction.

    Args:
        stream: The stream containing the archive.
        target_dir: The directory where to extract the archive.

    Raises:
        BadArchiveContent: If the archive is corrupted or in an unknown format.
    """

    # Sniff the leading bytes, then rewind so the extractor sees the whole stream.
    magic = stream.read(4)
    stream.seek(0)

    if magic.startswith(b"PK\x03\x04"):
        try:
            with zipfile.ZipFile(stream) as zip_archive:
                zip_archive.extractall(path=target_dir)
        except zipfile.BadZipFile as error:
            raise BadArchiveContent("Unsupported ZIP format") from error
        return

    if magic.startswith(b"7z"):
        try:
            with py7zr.SevenZipFile(stream, "r") as seven_zip_archive:
                seven_zip_archive.extractall(target_dir)
        except py7zr.exceptions.Bad7zFile as error:
            raise BadArchiveContent("Unsupported 7z format") from error
        return

    # Neither signature matched.
    raise BadArchiveContent


def extract_file_to_tmp_dir(archive_path: Path, inside_archive_path: Path) -> t.Tuple[Path, t.Any]:
    """
    Extract a single file from a ZIP or 7z archive into a new temporary directory.

    Args:
        archive_path: Path to the archive; its suffix (".zip" or ".7z") selects
            the reader.
        inside_archive_path: Path of the wanted file inside the archive.
            Backslashes are converted to forward slashes before lookup.

    Returns:
        A tuple `(path, tmp_dir)`: the path of the extracted file and the
        `tempfile.TemporaryDirectory` holding it. The caller is responsible for
        calling `tmp_dir.cleanup()`.

    Raises:
        KeyError: If the file is not found in the archive.
        ValueError: If the archive suffix is not supported.

    Any failure is logged and the temporary directory is cleaned up before the
    exception is re-raised.
    """
    str_inside_archive_path = str(inside_archive_path).replace("\\", "/")
    tmp_dir = tempfile.TemporaryDirectory()
    path = Path(tmp_dir.name) / str_inside_archive_path
    try:
        if archive_path.suffix == ArchiveFormat.ZIP:
            with zipfile.ZipFile(archive_path) as zip_obj:
                zip_obj.extract(str_inside_archive_path, tmp_dir.name)
        elif archive_path.suffix == ArchiveFormat.SEVEN_ZIP:
            with py7zr.SevenZipFile(archive_path, mode="r") as szf:
                szf.extract(path=tmp_dir.name, targets=[str_inside_archive_path])
            # Unlike `ZipFile.extract`, `SevenZipFile.extract` may return without
            # error when the target is absent from the archive. Check the result
            # so that both formats report a missing file the same way.
            if not path.exists():
                raise KeyError(f"There is no item named {str_inside_archive_path!r} in the archive")
        else:
            raise ValueError(f"Unsupported archive format for {archive_path}")
    except Exception as e:
        logger.warning(
            f"Failed to extract {str_inside_archive_path} in archive {archive_path}",
            exc_info=e,
        )
        # Nothing usable was produced: do not leak the temporary directory.
        tmp_dir.cleanup()
        raise
    return path, tmp_dir


def read_file_from_archive(archive_path: Path, posix_path: str) -> str:
    """
    Read a text file stored in a ZIP or 7z archive and decode it as UTF-8.

    Args:
        archive_path: the path to the archive file; its suffix selects the reader.
        posix_path: path to the file inside the archive.

    Returns:
        The content of the file as a string.

    Raises:
        ValueError: If the archive suffix is neither ".zip" nor ".7z".
    """
    suffix = archive_path.suffix

    if suffix == ArchiveFormat.ZIP:
        with zipfile.ZipFile(archive_path) as zip_archive:
            with zip_archive.open(posix_path) as member:
                raw_bytes = member.read()
                return raw_bytes.decode("utf-8")

    if suffix == ArchiveFormat.SEVEN_ZIP:
        with py7zr.SevenZipFile(archive_path, mode="r") as seven_zip_archive:
            # `read` returns a mapping from member name to an in-memory file object.
            members = seven_zip_archive.read([posix_path])
            file_text: str = members[posix_path].read().decode("utf-8")
            return file_text

    raise ValueError(f"Unsupported {archive_path.suffix} archive format for {archive_path}")


def extract_lines_from_archive(root: Path, posix_path: str) -> t.List[str]:
    """
    Read a text file from an archive and split it into lines.

    Args:
        root: 7zip or ZIP file containing the study.
        posix_path: Relative path to the file to extract.

    Returns:
        The lines of the file without their line endings, or an empty list if
        reading the file raised `KeyError` (file not found in the archive).
    """
    try:
        content = read_file_from_archive(root, posix_path)
    except KeyError:
        # File not found in the archive
        return []
    return content.splitlines()
104 changes: 0 additions & 104 deletions antarest/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
import glob
import http
import logging
import os
import re
import shutil
import tempfile
import time
import typing as t
import zipfile
from pathlib import Path

import py7zr
from fastapi import HTTPException

from antarest.core.exceptions import ShouldNotHappenException
Expand Down Expand Up @@ -65,49 +60,6 @@ def sanitize_string(string: str) -> str:
return str(glob.escape(string))


class BadArchiveContent(Exception):
    """
    Error signalling that an archive file is corrupted or of an unknown format.

    Args:
        message: Explanation of the failure. When omitted, a generic
            "Unsupported archive format" message is used.
    """

    def __init__(self, message: str = "Unsupported archive format") -> None:
        # Forward the (possibly default) message so that `str(exc)` returns it.
        super().__init__(message)


def extract_zip(stream: t.BinaryIO, target_dir: Path) -> None:
    """
    Extract a ZIP or 7z archive, read from a binary stream, to a given destination.

    The format is detected from the first bytes of the stream, which is then
    rewound to position 0 before extraction.

    Args:
        stream: The stream containing the archive.
        target_dir: The directory where to extract the archive.

    Raises:
        BadArchiveContent: If the archive is corrupted or in an unknown format.
    """

    # Sniff the leading bytes, then rewind so the extractor sees the whole stream.
    magic = stream.read(4)
    stream.seek(0)

    if magic.startswith(b"PK\x03\x04"):
        try:
            with zipfile.ZipFile(stream) as zip_archive:
                zip_archive.extractall(path=target_dir)
        except zipfile.BadZipFile as error:
            raise BadArchiveContent("Unsupported ZIP format") from error
        return

    if magic.startswith(b"7z"):
        try:
            with py7zr.SevenZipFile(stream, "r") as seven_zip_archive:
                seven_zip_archive.extractall(target_dir)
        except py7zr.exceptions.Bad7zFile as error:
            raise BadArchiveContent("Unsupported 7z format") from error
        return

    # Neither signature matched.
    raise BadArchiveContent


def get_default_config_path() -> t.Optional[Path]:
config = Path("config.yaml")
if config.exists():
Expand Down Expand Up @@ -181,62 +133,6 @@ def concat_files_to_str(files: t.List[Path]) -> str:
return concat_str


def zip_dir(dir_path: Path, zip_path: Path, remove_source_dir: bool = False) -> None:
    """
    Compress the files of a directory tree into a ZIP archive (deflate, level 2).

    Args:
        dir_path: Directory to archive.
        zip_path: ZIP file to create.
        remove_source_dir: If `True`, delete the source directory once the
            archive has been written.
    """
    with zipfile.ZipFile(
        zip_path,
        mode="w",
        compression=zipfile.ZIP_DEFLATED,
        compresslevel=2,
    ) as archive:
        # Entry names are the on-disk paths with the source directory prefix cut off.
        prefix_length = len(str(dir_path))
        # Only the files reported by `os.walk` are written.
        for current_dir, _subdirs, filenames in os.walk(dir_path):
            for filename in filenames:
                full_path = os.path.join(current_dir, filename)
                archive.write(full_path, full_path[prefix_length:])

    if not remove_source_dir:
        return
    shutil.rmtree(dir_path)


def unzip(dir_path: Path, zip_path: Path, remove_source_zip: bool = False) -> None:
    """
    Extract every member of a ZIP file into a directory.

    Args:
        dir_path: Destination directory of the extraction.
        zip_path: ZIP file to read.
        remove_source_zip: If `True`, delete the ZIP file after extraction.
    """
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(dir_path)

    if not remove_source_zip:
        return
    zip_path.unlink()


def is_zip(path: Path) -> bool:
    """
    Tell whether the file name of a path ends with ".zip" (case-sensitive).

    Args:
        path: Path to check; only its final component is examined.

    Returns:
        `True` if the name ends with ".zip", `False` otherwise.
    """
    filename = path.name
    return filename.endswith(".zip")


def extract_file_to_tmp_dir(zip_path: Path, inside_zip_path: Path) -> t.Tuple[Path, t.Any]:
    """
    Extract a single file from a ZIP archive into a new temporary directory.

    Args:
        zip_path: Path to the ZIP file.
        inside_zip_path: Path of the wanted file inside the archive.
            Backslashes are converted to forward slashes before lookup.

    Returns:
        A tuple `(path, tmp_dir)`: the path of the extracted file and the
        `tempfile.TemporaryDirectory` holding it. The caller is responsible for
        calling `tmp_dir.cleanup()`.

    Any failure is logged and the temporary directory is cleaned up before the
    exception is re-raised.
    """
    workdir = tempfile.TemporaryDirectory()
    # ZIP member names always use forward slashes.
    str_inside_zip_path = str(inside_zip_path).replace("\\", "/")
    try:
        with zipfile.ZipFile(zip_path) as archive:
            archive.extract(str_inside_zip_path, workdir.name)
    except Exception as error:
        logger.warning(
            f"Failed to extract {str_inside_zip_path} in zip {zip_path}",
            exc_info=error,
        )
        # Nothing usable was produced: do not leak the temporary directory.
        workdir.cleanup()
        raise
    else:
        extracted_path = Path(workdir.name) / inside_zip_path
        return extracted_path, workdir


def read_in_zip(
    zip_path: Path,
    inside_zip_path: Path,
    read: t.Callable[[t.Optional[Path]], None],
) -> None:
    """
    Extract one file from a ZIP archive into a temporary directory and pass it to a callback.

    The callback is invoked exactly once: with the path of the extracted file,
    or with `None` when the extraction raises `KeyError` (file not found in the
    archive). The temporary directory is removed once the callback has returned
    or raised.

    Args:
        zip_path: Path to the ZIP file.
        inside_zip_path: Path of the wanted file inside the archive.
        read: Callback receiving the extracted file path, or `None` if the file
            is missing.
    """
    try:
        path, tmp_dir = extract_file_to_tmp_dir(zip_path, inside_zip_path)
    except KeyError:
        logger.warning(f"{inside_zip_path} not found in {zip_path}")
        read(None)
        return

    # The callback runs outside the `except KeyError` above: a KeyError raised by
    # the callback itself must not be reported as a missing file, nor trigger a
    # second call with `None`.
    try:
        read(path)
    finally:
        tmp_dir.cleanup()


def suppress_exception(
callback: t.Callable[[], T],
logger_: t.Callable[[Exception], None],
Expand Down
3 changes: 2 additions & 1 deletion antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from antarest.core.interfaces.eventbus import Event, EventType, IEventBus
from antarest.core.model import PermissionInfo, PublicMode
from antarest.core.requests import RequestParameters
from antarest.core.utils.utils import assert_this, unzip
from antarest.core.utils.archives import unzip
from antarest.core.utils.utils import assert_this
from antarest.launcher.adapters.abstractlauncher import AbstractLauncher, LauncherCallbacks, LauncherInitException
from antarest.launcher.adapters.log_manager import LogTailManager
from antarest.launcher.model import JobStatus, LauncherParametersDTO, LogType, XpansionParametersDTO
Expand Down
7 changes: 4 additions & 3 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.tasks.model import TaskResult, TaskType
from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier
from antarest.core.utils.archives import ArchiveFormat, archive_dir, is_zip, read_in_zip
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import StopWatch, concat_files, concat_files_to_str, is_zip, read_in_zip, zip_dir
from antarest.core.utils.utils import StopWatch, concat_files, concat_files_to_str
from antarest.launcher.adapters.abstractlauncher import LauncherCallbacks
from antarest.launcher.adapters.factory_launcher import FactoryLauncher
from antarest.launcher.extensions.adequacy_patch.extension import AdequacyPatchExtension
Expand Down Expand Up @@ -548,7 +549,7 @@ def _import_output(
if not output_is_zipped and job_launch_params.archive_output:
logger.info("Re zipping output for transfer")
zip_path = output_true_path.parent / f"{output_true_path.name}.zip"
zip_dir(output_true_path, zip_path=zip_path)
archive_dir(output_true_path, target_archive_path=zip_path, archive_format=ArchiveFormat.ZIP)
stopwatch.log_elapsed(lambda x: logger.info(f"Zipped output for job {job_id} in {x}s"))

final_output_path = zip_path or output_true_path
Expand Down Expand Up @@ -600,7 +601,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F
def export_task(_: TaskUpdateNotifier) -> TaskResult:
try:
#
zip_dir(output_path, export_path)
archive_dir(output_path, export_path, archive_format=ArchiveFormat.ZIP)
self.file_transfer_manager.set_ready(export_id)
return TaskResult(success=True, message="")
except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
from antarest.core.serialization import from_json
from antarest.core.tasks.model import TaskResult, TaskType
from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier
from antarest.core.utils.archives import ArchiveFormat, archive_dir
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import StopWatch, zip_dir
from antarest.core.utils.utils import StopWatch
from antarest.login.service import LoginService
from antarest.matrixstore.exceptions import MatrixDataSetNotFound
from antarest.matrixstore.model import (
Expand Down Expand Up @@ -469,7 +470,7 @@ def create_matrix_files(self, matrix_ids: t.Sequence[str], export_path: Path) ->
else:
# noinspection PyTypeChecker
np.savetxt(filepath, array, delimiter="\t", fmt="%.18f")
zip_dir(Path(tmpdir), export_path)
archive_dir(Path(tmpdir), export_path, archive_format=ArchiveFormat.ZIP)
stopwatch.log_elapsed(lambda x: logger.info(f"Matrix dataset exported (zipped mode) in {x}s"))
return str(export_path)

Expand Down
Loading

0 comments on commit 843bd3a

Please sign in to comment.