From 601fa2dcfc7e698719eb379ad35bd8de372db9a5 Mon Sep 17 00:00:00 2001 From: Zachary Lentz Date: Wed, 20 Mar 2024 14:36:30 -0700 Subject: [PATCH] FIX: learn how to handle cwd/cd for put/get connections --- pmpsdb_client/ssh_data.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pmpsdb_client/ssh_data.py b/pmpsdb_client/ssh_data.py index ef0c1e2..3335b4f 100644 --- a/pmpsdb_client/ssh_data.py +++ b/pmpsdb_client/ssh_data.py @@ -11,6 +11,7 @@ from contextlib import contextmanager from dataclasses import dataclass from io import StringIO +from pathlib import Path from typing import Iterator, TypeVar from fabric import Connection @@ -68,6 +69,8 @@ def ssh( result = conn.run(f"mkdir -p {directory}") if result.exited != 0: raise RuntimeError(f"Failed to create directory {directory}") + # Note: conn.cd only affects calls to conn.run, not conn.get or conn.put + # Use conn.cwd property to check this live with conn.cd(directory): yield conn if not connected: @@ -168,7 +171,9 @@ def upload_filename( if dest_filename is None: dest_filename = filename with ssh(hostname=hostname, directory=directory) as conn: - conn.put(local=filename, remote=dest_filename) + if directory is None: + directory = conn.cwd + conn.put(local=filename, remote=str(Path(directory) / dest_filename)) def download_file_text( @@ -200,5 +205,7 @@ def download_file_text( logger.debug("download_file_text(%s, %s, %s)", hostname, filename, directory) stringio = StringIO() with ssh(hostname=hostname, directory=directory) as conn: - conn.get(remote=filename, local=stringio) + if directory is None: + directory = conn.cwd + conn.get(remote=str(Path(directory) / filename), local=stringio) return stringio.getvalue()