Skip to content

Commit

Permalink
Merge pull request #4501 from neutrinoceros/cleanup_after_load_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros authored Jun 18, 2023
2 parents f3a7ea3 + 8264ef6 commit b754a80
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
8 changes: 8 additions & 0 deletions yt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,7 @@ def load_sample(
_download_sample_data_file,
_get_test_data_dir_path,
get_data_registry_table,
get_download_cache_dir,
)

pooch_logger = pooch.utils.get_logger()
Expand Down Expand Up @@ -1624,6 +1625,13 @@ def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
if load_name not in str(loadable_path):
loadable_path = loadable_path.joinpath(load_name, specific_file)

try:
# clean cache dir
get_download_cache_dir().rmdir()
except OSError:
# cache dir isn't empty
pass

return load(loadable_path, **kwargs)


Expand Down
33 changes: 21 additions & 12 deletions yt/sample_data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
import re
import sys
from functools import lru_cache
from itertools import chain
from pathlib import Path
from typing import Optional, Union
from typing import Optional
from warnings import warn

from yt.config import ytcfg
from yt.funcs import mylog
from yt.utilities.on_demand_imports import (
_pandas as pd,
_pooch as pooch,
Expand Down Expand Up @@ -190,15 +188,26 @@ def lookup_on_disk_data(fn) -> Path:
raise FileNotFoundError(err_msg)


@lru_cache(maxsize=128)
def _get_pooch_instance():
data_registry = get_data_registry_table()
cache_storage = _get_test_data_dir_path() / "yt_download_cache"
def get_download_cache_dir():
return _get_test_data_dir_path() / "yt_download_cache"

registry = {k: v["hash"] for k, v in _get_sample_data_registry().items()}
return pooch.create(
path=cache_storage, base_url="https://yt-project.org/data/", registry=registry
)

_POOCHIE = None


def _get_pooch_instance():
global _POOCHIE
if _POOCHIE is None:
data_registry = get_data_registry_table()
cache_storage = get_download_cache_dir()

registry = {k: v["hash"] for k, v in _get_sample_data_registry().items()}
_POOCHIE = pooch.create(
path=cache_storage,
base_url="https://yt-project.org/data/",
registry=registry,
)
return _POOCHIE


def _download_sample_data_file(
Expand All @@ -216,4 +225,4 @@ def _download_sample_data_file(

poochie = _get_pooch_instance()
poochie.fetch(filename, downloader=downloader)
return Path.joinpath(poochie.path, filename)
return poochie.path / filename
9 changes: 8 additions & 1 deletion yt/tests/test_load_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from yt.config import ytcfg
from yt.loaders import load_sample
from yt.sample_data.api import get_data_registry_table
from yt.sample_data.api import (
get_data_registry_table,
get_download_cache_dir,
)
from yt.testing import requires_module_pytest
from yt.utilities.logger import ytLogger

Expand Down Expand Up @@ -77,8 +80,12 @@ def capturable_logger(caplog):
def test_load_sample_small_dataset(
fn, archive, exact_loc, class_: str, tmp_data_dir, caplog
):
cache_path = get_download_cache_dir()
assert not cache_path.exists()

ds = load_sample(fn, progressbar=False, timeout=30)
assert type(ds).__name__ == class_
assert not cache_path.exists()

text = textwrap.dedent(
f"""
Expand Down

0 comments on commit b754a80

Please sign in to comment.