diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..ef57a8b 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -46,26 +46,35 @@ def copy_to_shm(file: str): yield file return - tmp_dir = "/dev/shm/" - fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) - try: - shutil.copyfile(file, tmp_path) - yield tmp_path - finally: - os.remove(tmp_path) - os.close(fd) + with tempfile.NamedTemporaryFile(dir="/dev/shm", delete=False) as tmp_file: + tmp_path = tmp_file.name + try: + shutil.copyfile(file, tmp_path) + yield tmp_path + finally: + try: + os.remove(tmp_path) + except OSError as e: + # Handle file deletion error gracefully + logger.error(f"Error deleting temporary file: {e}") + raise @contextlib.contextmanager def copy_from_shm(file: str): tmp_dir = "/dev/shm/" - fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) - try: - yield tmp_path - shutil.copyfile(tmp_path, file) - finally: - os.remove(tmp_path) - os.close(fd) + with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp_file: + tmp_path = tmp_file.name + try: + yield tmp_path + shutil.copyfile(tmp_path, file) + finally: + try: + os.remove(tmp_path) + except OSError as e: + # Handle file deletion error gracefully + logger.error(f"Error deleting temporary file: {e}") + raise def fast_unpickle(path: str) -> Any: