diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 364535c5e..5e0298724 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -14,7 +14,7 @@ on: - "v*.*.*" env: - CACHE_PREFIX: v3 # Change this to invalidate existing cache. + CACHE_PREFIX: v5 # Change this to invalidate existing cache. PYTHON_PATH: ./ DEFAULT_PYTHON: 3.9 WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 9093a4ce0..1b5642588 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added - + - Added the `Workspace.remove_step()` method to safely remove steps. - The `GSWorkspace()` can now be initialized with google cloud bucket subfolders. ### Fixed diff --git a/tango/integrations/beaker/workspace.py b/tango/integrations/beaker/workspace.py index 69355ae28..1915fc719 100644 --- a/tango/integrations/beaker/workspace.py +++ b/tango/integrations/beaker/workspace.py @@ -420,3 +420,10 @@ def _update_step_info(self, step_info: StepInfo): self.Constants.STEP_INFO_FNAME, # step info filename quiet=True, ) + + def _remove_step_info(self, step_info: StepInfo) -> None: + # remove dir from beaker workspace + dataset_name = self.Constants.step_artifact_name(step_info) + step_dataset = self.beaker.dataset.get(dataset_name) + if step_dataset is not None: + self.beaker.dataset.delete(step_dataset) diff --git a/tango/integrations/gs/workspace.py b/tango/integrations/gs/workspace.py index cd4ef31e6..fd827f02a 100644 --- a/tango/integrations/gs/workspace.py +++ b/tango/integrations/gs/workspace.py @@ -400,6 +400,15 @@ def _update_step_info(self, step_info: StepInfo): self._ds.put(step_info_entity) + def _remove_step_info(self, step_info: StepInfo) -> None: + # remove dir from bucket + step_artifact = self.client.get(self.Constants.step_artifact_name(step_info)) + if step_artifact is not None: + self.client.delete(step_artifact) + + # remove datastore entities + self._ds.delete(key=self._ds.key("stepinfo", step_info.unique_id)) + def _save_run_log(self, name: str, log_file: Path): """ The logs are stored in the bucket. The Run object details are stored in diff --git a/tango/integrations/transformers/__init__.py b/tango/integrations/transformers/__init__.py index 31f6a7049..386e48ecb 100644 --- a/tango/integrations/transformers/__init__.py +++ b/tango/integrations/transformers/__init__.py @@ -45,6 +45,7 @@ from tango.integrations.transformers import * available_models = [] + for name in sorted(Model.list_available()): if name.startswith("transformers::AutoModel"): available_models.append(name) diff --git a/tango/integrations/wandb/workspace.py b/tango/integrations/wandb/workspace.py index 36e663a0f..97b008484 100644 --- a/tango/integrations/wandb/workspace.py +++ b/tango/integrations/wandb/workspace.py @@ -292,6 +292,13 @@ def step_failed(self, step: Step, e: BaseException) -> None: if step.unique_id in self._running_step_info: del self._running_step_info[step.unique_id] + def remove_step(self, step_unique_id: str): + """ + Removes cached step using the given unique step id + :raises KeyError: If there is no step with the given name. + """ + raise NotImplementedError() + def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run: all_steps = set(targets) for step in targets: diff --git a/tango/step_cache.py b/tango/step_cache.py index 29ef70b35..99dea91a4 100644 --- a/tango/step_cache.py +++ b/tango/step_cache.py @@ -48,6 +48,11 @@ def __setitem__(self, step: Step, value: Any) -> None: """Writes the results for the given step. Throws an exception if the step is already cached.""" raise NotImplementedError() + @abstractmethod + def __delitem__(self, step_unique_id: Union[Step, StepInfo]) -> None: + """Removes a step from step cache""" + raise NotImplementedError() + @abstractmethod def __len__(self) -> int: """Returns the number of results saved in this cache.""" diff --git a/tango/step_caches/local_step_cache.py b/tango/step_caches/local_step_cache.py index dcc519ae1..1ed9dcd12 100644 --- a/tango/step_caches/local_step_cache.py +++ b/tango/step_caches/local_step_cache.py @@ -1,5 +1,7 @@ import collections import logging +import os +import shutil import warnings import weakref from pathlib import Path @@ -89,6 +91,17 @@ def _get_from_cache(self, key: str) -> Optional[Any]: except KeyError: return None + def _remove_from_cache(self, key: str) -> None: + # check and remove from strong cache + if key in self.strong_cache: + del self.strong_cache[key] + assert key not in self.strong_cache + + # check and remove from weak cache + if key in self.weak_cache: + del self.weak_cache[key] + assert key not in self.weak_cache + def _metadata_path(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path: return self.step_dir(step_or_unique_id) / self.METADATA_FILE_NAME @@ -147,6 +160,14 @@ def __setitem__(self, step: Step, value: Any) -> None: pass raise + def __delitem__(self, step: Union[Step, StepInfo]) -> None: + location = str(self.dir) + "/" + str(step.unique_id) + try: + shutil.rmtree(location) + self._remove_from_cache(step.unique_id) + except OSError: + raise OSError(f"Step cache folder for '{step.unique_id}' not found. Cannot be deleted.") + def __len__(self) -> int: return sum(1 for _ in self.dir.glob(f"*/{self.METADATA_FILE_NAME}")) diff --git a/tango/step_caches/memory_step_cache.py b/tango/step_caches/memory_step_cache.py index e57751c70..184b7a4ce 100644 --- a/tango/step_caches/memory_step_cache.py +++ b/tango/step_caches/memory_step_cache.py @@ -35,6 +35,12 @@ def __setitem__(self, step: Step, value: Any) -> None: UserWarning, ) + def __delitem__(self, step: Union[Step, StepInfo]) -> None: + if step.unique_id in self.cache: + del self.cache[step.unique_id] + else: + raise KeyError(f"{step.unique_id} not present in the memory cache. Cannot be deleted.") + def __contains__(self, step: object) -> bool: if isinstance(step, (Step, StepInfo)): return step.unique_id in self.cache diff --git a/tango/workspace.py b/tango/workspace.py index e9bb72815..1261b267b 100644 --- a/tango/workspace.py +++ b/tango/workspace.py @@ -419,6 +419,14 @@ def step_result(self, step_name: str) -> Any: return self.step_cache[run.steps[step_name]] raise KeyError(f"No step named '{step_name}' found in previous runs") + @abstractmethod + def remove_step(self, step_unique_id: str): + """ + Removes cached step using the given unique step id + :raises KeyError: If there is no step with the given name. + """ + raise NotImplementedError() + def capture_logs_for_run(self, name: str) -> ContextManager[None]: """ Should return a context manager that can be used to capture the logs for a run. diff --git a/tango/workspaces/local_workspace.py b/tango/workspaces/local_workspace.py index 905ca5d79..61b4113e4 100644 --- a/tango/workspaces/local_workspace.py +++ b/tango/workspaces/local_workspace.py @@ -322,6 +322,20 @@ def step_failed(self, step: Step, e: BaseException) -> None: lock.release() del self.locks[step] + def remove_step(self, step_unique_id: str) -> None: + """ + Get Step unique id from the user and remove the step information from cache + :raises KeyError: If no step with the unique name found in the cache dir + """ + with SqliteDict(self.step_info_file) as d: + try: + step_info = self.step_info(step_unique_id) + del d[step_unique_id] + d.commit() + del self.cache[step_info] + except KeyError: + raise KeyError(f"No step named '{step_unique_id}' found") + def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run: # sanity check targets targets = list(targets) diff --git a/tango/workspaces/memory_workspace.py b/tango/workspaces/memory_workspace.py index bcbb89498..47b29c077 100644 --- a/tango/workspaces/memory_workspace.py +++ b/tango/workspaces/memory_workspace.py @@ -98,6 +98,18 @@ def step_failed(self, step: Step, e: BaseException) -> None: existing_step_info.end_time = utc_now_datetime() existing_step_info.error = exception_to_string(e) + def remove_step(self, step_unique_id: str) -> None: + """ + Get Step unique id from the user and remove the step information from memory cache + :raises KeyError: If no step with the unique name found in the cache dir + """ + try: + step_info = self.step_info(step_unique_id) + del self.unique_id_to_info[step_unique_id] + del self.step_cache[step_info] + except KeyError: + raise KeyError(f"{step_unique_id} step info not found, step cache cannot be deleted") + def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run: if name is None: name = petname.generate() diff --git a/tango/workspaces/remote_workspace.py b/tango/workspaces/remote_workspace.py index 3460ca4a9..49d351a9f 100644 --- a/tango/workspaces/remote_workspace.py +++ b/tango/workspaces/remote_workspace.py @@ -174,6 +174,22 @@ def step_failed(self, step: Step, e: BaseException) -> None: finally: self.locks.pop(step).release() + def remove_step(self, step_unique_id: str) -> None: + """ + Get Step unique id from the user and remove the step information from cache + :raises KeyError: If no step with the unique name found in the cache dir + """ + try: + step_info = self.step_info(step_unique_id) + # remove remote objects + self._remove_step_info(step_info) + + # remove cache info + del self.cache[step_info] + except KeyError: + raise KeyError(f"No step named '{step_unique_id}' found.") + return None + def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]: import concurrent.futures @@ -229,3 +245,7 @@ def capture_logs_for_run(self, name: str) -> Generator[None, None, None]: @abstractmethod def _update_step_info(self, step_info: StepInfo): raise NotImplementedError() + + @abstractmethod + def _remove_step_info(self, step_info: StepInfo): + raise NotImplementedError() diff --git a/tests/integrations/beaker/workspace_test.py b/tests/integrations/beaker/workspace_test.py index 4c68bcef0..d42a1064b 100644 --- a/tests/integrations/beaker/workspace_test.py +++ b/tests/integrations/beaker/workspace_test.py @@ -1,3 +1,6 @@ +import pytest +from beaker import DatasetNotFound + from tango.common.testing.steps import FloatStep from tango.integrations.beaker.workspace import BeakerWorkspace from tango.step_info import StepState @@ -5,6 +8,7 @@ def test_from_url(beaker_workspace: str): + print(beaker_workspace) workspace = Workspace.from_url(f"beaker://{beaker_workspace}") assert isinstance(workspace, BeakerWorkspace) @@ -22,3 +26,27 @@ def test_direct_usage(beaker_workspace: str): workspace.step_finished(step, 1.0) assert workspace.step_info(step).state == StepState.COMPLETED assert workspace.step_result_for_run(run.name, "float") == 1.0 + + +def test_remove_step(beaker_workspace: str): + beaker_workspace = "ai2/tango_remove_cache_test" + workspace = BeakerWorkspace(beaker_workspace) + step = FloatStep(step_name="float", result=1.0) + + workspace.step_starting(step) + workspace.step_finished(step, 1.0) + + step_info = workspace.step_info(step) + dataset_name = workspace.Constants.step_artifact_name(step_info) + cache = workspace.step_cache + + assert workspace.beaker.dataset.get(dataset_name) is not None + assert step in cache + + workspace.remove_step(step.unique_id) + cache = workspace.step_cache + dataset_name = workspace.Constants.step_artifact_name(step_info) + + with pytest.raises(DatasetNotFound): + workspace.beaker.dataset.get(dataset_name) + assert step not in cache diff --git a/tests/integrations/gs/workspace_test.py b/tests/integrations/gs/workspace_test.py index d34f82e02..ac30e36ab 100644 --- a/tests/integrations/gs/workspace_test.py +++ b/tests/integrations/gs/workspace_test.py @@ -48,3 +48,30 @@ def test_direct_usage(self, gs_path: str): workspace.step_finished(step, 1.0) assert workspace.step_info(step).state == StepState.COMPLETED assert workspace.step_result_for_run(run.name, "float") == 1.0 + + def test_remove_step(self): + workspace = GSWorkspace(GS_BUCKET_NAME) + step = FloatStep(step_name="float", result=1.0) + step_info = workspace.step_info(step) + + workspace.step_starting(step) + workspace.step_finished(step, 1.0) + bucket_artifact = workspace.Constants.step_artifact_name(step_info) + ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id)) + cache = workspace.step_cache + + assert workspace.client.artifacts(prefix=bucket_artifact) is not None + assert ds_entity is not None + assert step in cache + + workspace.remove_step(step.unique_id) + cache = workspace.step_cache + + ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id)) + + with pytest.raises(Exception) as excinfo: + workspace.client.artifacts(prefix=bucket_artifact) + + assert "KeyError" in str(excinfo) + assert ds_entity is None + assert step not in cache diff --git a/tests/workspaces/local_workspace_test.py b/tests/workspaces/local_workspace_test.py index b987a87c3..58c4dacf5 100644 --- a/tests/workspaces/local_workspace_test.py +++ b/tests/workspaces/local_workspace_test.py @@ -1,6 +1,7 @@ from shutil import copytree import pytest +from sqlitedict import SqliteDict from tango import Step from tango.common.testing import TangoTestCase @@ -73,3 +74,23 @@ def test_local_workspace_upgrade_v1_to_v2(self): while len(dependencies) > 0: step_info = workspace.step_info(dependencies.pop()) dependencies.extend(step_info.dependencies) + + def test_remove_step(self): + workspace = LocalWorkspace(self.TEST_DIR) + step = AdditionStep(a=1, b=2) + workspace.step_starting(step) + workspace.step_finished(step, 1.0) + + with SqliteDict(workspace.step_info_file) as d: + assert step.unique_id in d + + cache = workspace.step_cache + assert step in cache + + workspace.remove_step(step.unique_id) + + with SqliteDict(workspace.step_info_file) as d: + assert step.unique_id not in d + + cache = workspace.step_cache + assert step not in cache diff --git a/tests/workspaces/memory_workspace_test.py b/tests/workspaces/memory_workspace_test.py new file mode 100644 index 000000000..41529ee07 --- /dev/null +++ b/tests/workspaces/memory_workspace_test.py @@ -0,0 +1,20 @@ +from tango.common.testing.steps import FloatStep +from tango.workspaces import MemoryWorkspace + + +def test_remove_step(): + workspace = MemoryWorkspace() + step = FloatStep(step_name="float", result=1.0) + + workspace.step_starting(step) + workspace.step_finished(step, 1.0) + cache = workspace.step_cache + + assert step.unique_id in workspace.unique_id_to_info + assert step in cache + + workspace.remove_step(step.unique_id) + cache = workspace.step_cache + + assert step.unique_id not in workspace.unique_id_to_info + assert step not in cache