diff --git a/pyiron_workflow/snippets/files.py b/pyiron_workflow/snippets/files.py index bb2957a6..b9d6b099 100644 --- a/pyiron_workflow/snippets/files.py +++ b/pyiron_workflow/snippets/files.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pathlib import Path @@ -35,9 +36,48 @@ def categorize_folder_items(folder_path): return results +def _resolve_directory_and_path( + file_name: str, + directory: DirectoryObject | str | None = None, + default_directory: str = ".", +): + """ + Internal routine to separate the file name and the directory in case + file name is given in absolute path etc. + """ + path = Path(file_name) + file_name = path.name + if path.is_absolute(): + if directory is not None: + raise ValueError( + "You cannot set `directory` when `file_name` is an absolute path" + ) + # If absolute path, take that of new_file_name regardless of the + # name of directory + directory = str(path.parent) + else: + if directory is None: + # If directory is not given, take default directory + directory = default_directory + else: + # If the directory is given, use it as the main path and append + # additional path if given in new_file_name + if isinstance(directory, DirectoryObject): + directory = directory.path + directory = directory / path.parent + if not isinstance(directory, DirectoryObject): + directory = DirectoryObject(directory) + return file_name, directory + + class DirectoryObject: - def __init__(self, directory): - self.path = Path(directory) + def __init__(self, directory: str | Path | DirectoryObject): + if isinstance(directory, str): + self.path = Path(directory) + elif isinstance(directory, Path): + self.path = directory + elif isinstance(directory, DirectoryObject): + self.path = directory.path self.create() def create(self): @@ -83,9 +123,10 @@ def remove_files(self, *files: str): class FileObject: - def __init__(self, file_name: str, directory: DirectoryObject): - self.directory = directory - self._file_name = file_name + def __init__(self, file_name: str, directory: DirectoryObject = None): + self._file_name, self.directory = _resolve_directory_and_path( + file_name=file_name, directory=directory, default_directory="." + ) @property def file_name(self): diff --git a/tests/unit/snippets/test_files.py b/tests/unit/snippets/test_files.py index 93d8e90e..32b28ac1 100644 --- a/tests/unit/snippets/test_files.py +++ b/tests/unit/snippets/test_files.py @@ -1,6 +1,7 @@ import unittest from pyiron_workflow.snippets.files import DirectoryObject, FileObject from pathlib import Path +import platform class TestFiles(unittest.TestCase): @@ -10,6 +11,30 @@ def setUp(cls): def tearDown(cls): cls.directory.delete() + def test_directory_instantiation(self): + directory = DirectoryObject(Path("test")) + self.assertEqual(directory.path, self.directory.path) + directory = DirectoryObject(self.directory) + self.assertEqual(directory.path, self.directory.path) + + def test_file_instantiation(self): + self.assertEqual( + FileObject("test.txt", self.directory).path, + FileObject("test.txt", "test").path, + msg="DirectoryObject and str must give the same object" + ) + self.assertEqual( + FileObject("test/test.txt").path, + FileObject("test.txt", "test").path, + msg="File path not same as directory path" + ) + + if platform.system() == "Windows": + self.assertRaises(ValueError, FileObject, "C:\\test.txt", "test") + else: + self.assertRaises(ValueError, FileObject, "/test.txt", "test") + + def test_directory_exists(self): self.assertTrue(Path("test").exists() and Path("test").is_dir())