diff --git a/pyproject.toml b/pyproject.toml index 3472ea9c8..0e2df72c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,10 +32,10 @@ dependencies = [ "pyarrow>=10.0, <10.1.0", # numpy>=2.0 is not compatible with the old pyarrow v10.x. "numpy>=1.23.4, <2.0", + "paramiko==2.11.0", "defusedxml>=0.7.1", "aiohttp>=3.10.5", - "pytest-mock>=3.14.0", -] + ] requires-python = ">=3.10" readme = "README.md" @@ -84,6 +84,7 @@ dev-dependencies = [ "ruff>=0.6.6", "pytest-asyncio>=0.23.8", "moto>=5.0.13", + "pytest-mock>=3.14.0", ] [tool.pytest.ini_options] diff --git a/requirements-dev.lock b/requirements-dev.lock index e74fc29b2..d4e0a3138 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -42,6 +42,8 @@ attrs==24.2.0 # via visions babel==2.16.0 # via mkdocs-material +bcrypt==4.2.0 + # via paramiko beautifulsoup4==4.12.3 # via mkdocs-mermaid2-plugin # via nbconvert @@ -73,6 +75,7 @@ cffi==1.17.0 # via cairocffi # via cryptography # via pygit2 + # via pynacl charset-normalizer==3.3.2 # via requests click==8.1.7 @@ -97,6 +100,7 @@ croniter==2.0.7 # via prefect cryptography==43.0.0 # via moto + # via paramiko # via prefect cssselect2==0.7.0 # via cairosvg @@ -362,6 +366,8 @@ pandas==2.2.2 # via visions pandocfilters==1.5.1 # via nbconvert +paramiko==2.11.0 + # via viadot2 parso==0.8.4 # via jedi pathspec==0.12.1 @@ -426,6 +432,8 @@ pymdown-extensions==10.9 # via mkdocs-material # via mkdocs-mermaid2-plugin # via mkdocstrings +pynacl==1.5.0 + # via paramiko pyodbc==5.1.0 # via viadot2 pyparsing==3.1.2 @@ -553,6 +561,7 @@ six==1.16.0 # via bleach # via jsbeautifier # via kubernetes + # via paramiko # via python-dateutil # via rfc3339-validator smmap==5.0.1 diff --git a/requirements.lock b/requirements.lock index 1f2dc9fca..8fbaf36af 100644 --- a/requirements.lock +++ b/requirements.lock @@ -38,6 +38,8 @@ attrs==24.2.0 # via jsonschema # via referencing # via visions +bcrypt==4.2.0 + # via paramiko beautifulsoup4==4.12.3 # via o365 cachetools==5.5.0 @@ -52,6 
"""Download data from a SFTP server to Azure Data Lake Storage."""

from prefect import flow
from prefect.task_runners import ConcurrentTaskRunner

from viadot.orchestration.prefect.tasks import df_to_adls, sftp_to_df


@flow(
    name="SFTP extraction to ADLS",
    description="Extract data from a SFTP server and "
    "load it into Azure Data Lake Storage.",
    retries=1,
    retry_delay_seconds=60,
    task_runner=ConcurrentTaskRunner,
)
def sftp_to_adls(
    config_key: str | None = None,
    azure_key_vault_secret: str | None = None,
    file_name: str | None = None,
    sep: str = "\t",
    columns: list[str] | None = None,
    adls_config_key: str | None = None,
    adls_azure_key_vault_secret: str | None = None,
    adls_path: str | None = None,
    adls_path_overwrite: bool = False,
) -> None:
    r"""Flow to download data from a SFTP server to Azure Data Lake.

    Args:
        config_key (str, optional): The key in the viadot config holding relevant
            credentials. Defaults to None.
        azure_key_vault_secret (str, optional): The name of the Azure Key Vault
            secret where credentials are stored. Defaults to None.
        file_name (str, optional): Path to the file on the SFTP server.
            Defaults to None.
        sep (str, optional): The separator to use to read the CSV file.
            Defaults to "\t".
        columns (list[str], optional): Columns to read from the file.
            Defaults to None.
        adls_config_key (str, optional): The key in the viadot config holding
            relevant ADLS credentials. Defaults to None.
        adls_azure_key_vault_secret (str, optional): The name of the Azure Key
            Vault secret containing a dictionary with ACCOUNT_NAME and Service
            Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the
            Azure Data Lake. Defaults to None.
        adls_path (str, optional): Azure Data Lake destination file path
            (with file name). Defaults to None.
        adls_path_overwrite (bool, optional): Whether to overwrite the file in
            ADLS. Defaults to False.
    """
    # Extract: pull the remote file into a DataFrame via the SFTP task.
    data_frame = sftp_to_df(
        config_key=config_key,
        azure_key_vault_secret=azure_key_vault_secret,
        file_name=file_name,
        sep=sep,
        columns=columns,
    )

    # Load: write the DataFrame to ADLS; the task's return value is passed
    # through so Prefect records it as the flow result.
    return df_to_adls(
        df=data_frame,
        path=adls_path,
        credentials_secret=adls_azure_key_vault_secret,
        config_key=adls_config_key,
        overwrite=adls_path_overwrite,
    )
"""Tasks from SFTP API."""

import pandas as pd
from prefect import task

from viadot.orchestration.prefect.exceptions import MissingSourceCredentialsError
from viadot.orchestration.prefect.utils import get_credentials
from viadot.sources import Sftp


@task(retries=3, log_prints=True, retry_delay_seconds=10, timeout_seconds=60 * 60)
def sftp_to_df(
    config_key: str | None = None,
    azure_key_vault_secret: str | None = None,
    file_name: str | None = None,
    sep: str = "\t",
    columns: list[str] | None = None,
) -> pd.DataFrame:
    r"""Query a SFTP server and return the file contents as a DataFrame.

    Args:
        config_key (str, optional): The key in the viadot config holding relevant
            credentials. Defaults to None.
        azure_key_vault_secret (str, optional): The name of the Azure Key Vault
            secret where credentials are stored. Defaults to None.
        file_name (str, optional): Path to the file on the SFTP server.
            Defaults to None.
        sep (str, optional): The separator to use to read the CSV file.
            Defaults to "\t".
        columns (list[str], optional): Columns to read from the file.
            Defaults to None.

    Raises:
        MissingSourceCredentialsError: If neither a config key nor a Key Vault
            secret is provided.

    Returns:
        pd.DataFrame: The response data as a pandas DataFrame.
    """
    if not (azure_key_vault_secret or config_key):
        raise MissingSourceCredentialsError

    # Bug fix: the original only assigned `credentials` when config_key was
    # absent, so passing a config_key raised NameError on the line below.
    # With a config_key, pass None and let Sftp read the viadot config.
    credentials = get_credentials(azure_key_vault_secret) if not config_key else None

    sftp = Sftp(
        credentials=credentials,
        config_key=config_key,
    )
    sftp.get_connection()

    return sftp.to_df(file_name=file_name, sep=sep, columns=columns)


@task(retries=3, log_prints=True, retry_delay_seconds=10, timeout_seconds=60 * 60)
def sftp_list(
    config_key: str | None = None,
    azure_key_vault_secret: str | None = None,
    path: str | None = None,
    recursive: bool = False,
    matching_path: str | None = None,
) -> list[str]:
    """List files on the SFTP server.

    Args:
        config_key (str, optional): The key in the viadot config holding relevant
            credentials. Defaults to None.
        azure_key_vault_secret (str, optional): The name of the Azure Key Vault
            secret where credentials are stored. Defaults to None.
        path (str, optional): Destination path from where to get the structure.
            Defaults to None.
        recursive (bool, optional): Get the structure in deeper folders.
            Defaults to False.
        matching_path (str, optional): Filter the returned paths by a regex
            pattern. Defaults to None.

    Raises:
        MissingSourceCredentialsError: If neither a config key nor a Key Vault
            secret is provided.

    Returns:
        list[str]: List of files on the SFTP server.
    """
    if not (azure_key_vault_secret or config_key):
        raise MissingSourceCredentialsError

    # Same NameError fix as in sftp_to_df: default to None when config_key
    # is supplied so the Sftp source resolves credentials from the config.
    credentials = get_credentials(azure_key_vault_secret) if not config_key else None

    sftp = Sftp(
        credentials=credentials,
        config_key=config_key,
    )
    sftp.get_connection()

    return sftp.get_files_list(
        path=path, recursive=recursive, matching_path=matching_path
    )
"""SFTP connector."""

from io import BytesIO, StringIO
from pathlib import Path
import re
from stat import S_ISDIR, S_ISREG
import time
from typing import Literal

import pandas as pd
import paramiko
from paramiko.sftp import SFTPError
from pydantic import BaseModel

from viadot.config import get_source_credentials
from viadot.exceptions import CredentialError
from viadot.sources.base import Source
from viadot.utils import add_viadot_metadata_columns


class SftpCredentials(BaseModel):
    """Validate the SFTP credentials dictionary.

    Keys held by the SFTP connector:
        - hostname: Host name or IP address of the SFTP server.
        - username: The user name for the SFTP connection.
        - password: The password for the SFTP connection.
        - port: The port to use for the connection.
        - rsa_key: The SSH key to use for the connection. Only RSA is
            currently supported.
    """

    hostname: str
    username: str
    password: str
    port: int
    # Explicit None default so the key is genuinely optional (password-based
    # auth is used by get_connection() when no key is provided).
    rsa_key: str | None = None


class Sftp(Source):
    """Class implementing a SFTP server connection."""

    def __init__(
        self,
        *args,
        credentials: SftpCredentials | None = None,
        config_key: str = "sftp",
        **kwargs,
    ):
        """Create an instance of SFTP.

        Args:
            credentials (SftpCredentials, optional): SFTP credentials. Defaults to None.
            config_key (str, optional): The key in the viadot config holding relevant
                credentials. Defaults to "sftp".

        Notes:
            self.conn is built with paramiko.SFTPClient.from_transport and exposes
            additional methods like get, put, open etc. Some of them are not
            wrapped by this class. For more check the documentation
            (https://docs.paramiko.org/en/stable/api/sftp.html).

            sftp = Sftp()
            sftp.conn.open(filename='folder_a/my_file.zip', mode='r')

        Raises:
            CredentialError: If credentials are not provided in viadot config or
                directly as a parameter.
        """
        credentials = credentials or get_source_credentials(config_key)

        if credentials is None:
            message = "Missing credentials."
            raise CredentialError(message)

        validated_creds = dict(SftpCredentials(**credentials))
        super().__init__(*args, credentials=validated_creds, **kwargs)

        self.conn = None
        self.hostname = validated_creds.get("hostname")
        self.username = validated_creds.get("username")
        self.password = validated_creds.get("password")
        self.port = validated_creds.get("port")
        self.rsa_key = validated_creds.get("rsa_key")

    def _get_file_object(self, file_name: str) -> BytesIO:
        """Copy a remote file from the SFTP server into a file-like object.

        Args:
            file_name (str): Remote file path to copy.

        Raises:
            SFTPError: If the remote file does not exist.

        Returns:
            BytesIO: In-memory buffer with the file contents.
        """
        file_object = BytesIO()
        try:
            self.conn.getfo(file_name, file_object)
        except FileNotFoundError as error:
            # Re-raise with context so callers see which file was missing.
            raise SFTPError(f"File not found: {file_name}") from error
        else:
            return file_object

    def get_connection(self) -> paramiko.SFTPClient:
        """Open a SFTP connection and store it on `self.conn`.

        Uses password authentication over a raw Transport when no RSA key is
        configured; otherwise authenticates via SSHClient with the key.

        Returns:
            paramiko.SFTPClient: The open SFTP client (also kept in self.conn).
        """
        ssh = paramiko.SSHClient()

        if not self.rsa_key:
            transport = paramiko.Transport((self.hostname, self.port))
            transport.connect(None, self.username, self.password)

            self.conn = paramiko.SFTPClient.from_transport(transport)
        else:
            mykey = paramiko.RSAKey.from_private_key(StringIO(self.rsa_key))
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # noqa: S507
            ssh.connect(self.hostname, username=self.username, pkey=mykey)
            time.sleep(1)
            self.conn = ssh.open_sftp()

        self.logger.info("Connected to the SFTP server.")

        # Fix: the original returned None despite the annotation and docstring.
        return self.conn

    @add_viadot_metadata_columns
    def to_df(
        self,
        if_empty: Literal["warn", "skip", "fail"] = "warn",
        file_name: str | None = None,
        sep: str = "\t",
        columns: list[str] | None = None,
    ) -> pd.DataFrame:
        r"""Download a remote file and load it into a pandas DataFrame.

        Args:
            if_empty (Literal["warn", "skip", "fail"], optional): What to do if
                the fetch produces no data. Defaults to "warn".
            file_name (str, optional): The name of the file to download.
            sep (str, optional): The delimiter for the source file. Defaults to "\t".
            columns (list[str], optional): List of columns to select from file.
                Defaults to None.

        Raises:
            ValueError: If the file extension is not supported.

        Returns:
            pd.DataFrame: The response data as a pandas DataFrame plus viadot
                metadata columns (added by the decorator).
        """
        byte_file = self._get_file_object(file_name=file_name)
        byte_file.seek(0)

        # The connection is no longer needed once the bytes are in memory.
        self._close_conn()

        suffix = Path(file_name).suffix
        if suffix in (".csv", ".tsv"):
            # Both are delimiter-separated text; the caller controls `sep`.
            df = pd.read_csv(byte_file, sep=sep, usecols=columns)
        elif suffix == ".parquet":
            # Fix: pd.read_parquet has no `usecols` kwarg; column selection
            # is done via `columns`.
            df = pd.read_parquet(byte_file, columns=columns)
        elif suffix in (".xls", ".xlsx", ".xlsm"):
            df = pd.read_excel(byte_file, usecols=columns)
        elif suffix == ".json":
            df = pd.read_json(byte_file)
        elif suffix == ".pkl":
            df = pd.read_pickle(byte_file)  # noqa: S301
        elif suffix == ".sql":
            # NOTE(review): pd.read_sql requires a SQL connection, not a byte
            # buffer — this branch cannot work as written; confirm intent.
            df = pd.read_sql(byte_file)
        elif suffix == ".hdf":
            # NOTE(review): pd.read_hdf expects a file path, not a BytesIO —
            # confirm this branch is reachable/needed.
            df = pd.read_hdf(byte_file)
        else:
            message = (
                f"Unable to read file '{Path(file_name).name}', "
                f"unsupported filetype: {suffix}"
            )
            raise ValueError(message)

        if df.empty:
            self._handle_if_empty(
                if_empty=if_empty,
                message="The response does not contain any data.",
            )
        else:
            self.logger.info("Successfully downloaded data from the SFTP server.")

        return df

    def _ls(self, path: str | None = ".", recursive: bool = False) -> list[str]:
        """List files in specified directory, with optional recursion.

        Args:
            path (str | None): Full path to the remote directory to list.
                Defaults to ".".
            recursive (bool): Whether to list files recursively. Defaults to False.

        Returns:
            list[str]: List of files in the specified directory.
        """
        files_list = []

        path = "." if path is None else path
        try:
            if not recursive:
                # Regular files only; directories are skipped in flat mode.
                return [
                    str(Path(path) / attr.filename)
                    for attr in self.conn.listdir_attr(path)
                    if S_ISREG(attr.st_mode)
                ]

            for attr in self.conn.listdir_attr(path):
                full_path = str(Path(path) / attr.filename)
                if S_ISDIR(attr.st_mode):
                    files_list.extend(self._ls(full_path, recursive=True))
                else:
                    files_list.append(full_path)
        except FileNotFoundError as e:
            self.logger.info(f"Directory not found: {path}. Error: {e}")
        except Exception as e:
            # Best-effort listing: unreadable subtrees are logged and skipped
            # rather than aborting the whole walk.
            self.logger.info(f"Error accessing {path}: {e}")

        return files_list

    def get_files_list(
        self,
        path: str | None = None,
        recursive: bool = False,
        matching_path: str | None = None,
    ) -> list[str]:
        """List files in `path`.

        Args:
            path (str | None): Destination path from where to get the structure.
                Defaults to None.
            recursive (bool): Get the structure in deeper folders.
                Defaults to False.
            matching_path (str | None): Filter the returned paths by a regex
                pattern. Defaults to None.

        Returns:
            list[str]: List of files in the specified path.
        """
        files_list = self._ls(path=path, recursive=recursive)

        if matching_path is not None:
            # re.match anchors at the start of each path string.
            files_list = [f for f in files_list if re.match(matching_path, f)]

        self._close_conn()

        self.logger.info("Successfully loaded file list from SFTP server.")

        return files_list

    def _close_conn(self) -> None:
        """Close the SFTP server connection, if one is open (idempotent)."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None
class TestSftpCredentials:
    """Test SFTP Credentials Class."""

    @pytest.mark.basic
    def test_sftp_credentials(self):
        """Validation succeeds for a complete credentials dict."""
        SftpCredentials(
            hostname=variables["credentials"]["hostname"],
            username=variables["credentials"]["username"],
            password=variables["credentials"]["password"],
            port=variables["credentials"]["port"],
            rsa_key=variables["credentials"]["rsa_key"],
        )


@pytest.mark.basic
def test_sftp_connector_initialization_without_credentials():
    """Test SFTP server without credentials."""
    with pytest.raises(CredentialError, match="Missing credentials."):
        Sftp(credentials=None)


@pytest.mark.connect
def test_get_connection_without_rsa_key(mocker):
    """Test `get_connection()` method without specifying the RSA key."""
    mock_transport = mocker.patch("viadot.sources.sftp.paramiko.Transport")
    mock_sftp_client = mocker.patch(
        "viadot.sources.sftp.paramiko.SFTPClient.from_transport"
    )

    connector = Sftp(credentials=variables["credentials"])
    connector.get_connection()

    mock_transport.assert_called_once_with((variables["credentials"]["hostname"], 999))
    mock_transport().connect.assert_called_once_with(
        None, variables["credentials"]["username"], variables["credentials"]["password"]
    )
    mock_sftp_client.assert_called_once()


@pytest.mark.connect
def test_get_connection_with_rsa_key(mocker):
    """Test SFTP `get_connection` method with rsa_key."""
    mock_ssh_client = mocker.patch("viadot.sources.sftp.paramiko.SSHClient")
    mock_ssh_instance = mock_ssh_client.return_value
    mock_ssh_connect = mocker.patch.object(mock_ssh_instance, "connect")
    mock_rsa_key = mocker.patch(
        "viadot.sources.sftp.paramiko.RSAKey.from_private_key",
        return_value=mocker.Mock(),
    )
    mock_transport = mocker.Mock()

    # Ensure the SSHClient's transport attribute isn't None.
    mock_ssh_instance._transport = mock_transport

    dummy_rsa_key = "test_rsa_key"
    keyfile = StringIO(dummy_rsa_key)
    # Fix: use a copy so the shared module-level `variables` dict is not
    # mutated, which previously leaked rsa_key into every later test.
    credentials = {**variables["credentials"], "rsa_key": keyfile.getvalue()}

    connector = Sftp(credentials=credentials)
    connector.get_connection()

    mock_rsa_key.assert_called_once()
    mock_ssh_connect.assert_called_once_with(
        "", username="test_user", pkey=mock_rsa_key.return_value
    )
    assert connector.conn is not None


@pytest.mark.functions
def test_to_df_with_csv(mocker):
    """Test SFTP `to_df` method with csv."""
    mock_get_file_object = mocker.patch.object(Sftp, "_get_file_object", autospec=True)
    mock_get_file_object.return_value = BytesIO(b"col1,col2\n1,2\n3,4\n")

    connector = Sftp(credentials=variables["credentials"])
    df = connector.to_df(file_name="test.csv", sep=",")

    assert isinstance(df, pd.DataFrame)
    assert df.shape == (2, 4)
    assert list(df.columns) == [
        "col1",
        "col2",
        "_viadot_source",
        "_viadot_downloaded_at_utc",
    ]


@pytest.mark.functions
def test_to_df_with_json(mocker):
    """Test SFTP `to_df` method with json."""
    mock_get_file_object = mocker.patch.object(Sftp, "_get_file_object", autospec=True)
    json_data = json.dumps({"col1": [1, 3], "col2": [2, 4]})
    mock_get_file_object.return_value = BytesIO(json_data.encode("utf-8"))

    connector = Sftp(credentials=variables["credentials"])
    df = connector.to_df(file_name="test.json")

    assert isinstance(df, pd.DataFrame)
    assert list(df.columns) == [
        "col1",
        "col2",
        "_viadot_source",
        "_viadot_downloaded_at_utc",
    ]
    # Fix: the original built an expected frame (with a nondeterministic
    # timestamp column) but never asserted against it; compare the data
    # columns for real.
    expected_df = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]})
    pd.testing.assert_frame_equal(df[["col1", "col2"]], expected_df)


@pytest.mark.functions
def test_to_df_unsupported_file_type(mocker):
    """Test raising ValueError for unsupported file types."""
    mocker.patch.object(Sftp, "_get_file_object", return_value=BytesIO(b"dummy data"))
    sftp = Sftp(credentials=variables["credentials"])

    with pytest.raises(ValueError, match="Unable to read file"):
        sftp.to_df(file_name="test.txt")


@pytest.mark.functions
def test_to_df_empty_dataframe_warn(mocker, caplog):
    """Test handling of empty DataFrame with 'warn' option."""
    mocker.patch.object(Sftp, "_get_file_object", return_value=BytesIO(b"column"))
    sftp = Sftp(credentials=variables["credentials"])

    with caplog.at_level("INFO"):
        sftp.to_df(file_name="test.csv", if_empty="warn")
        assert "The response does not contain any" in caplog.text


@pytest.mark.functions
def test_ls(mocker):
    """Test SFTP `_ls` method."""
    mock_sftp = mocker.MagicMock()

    mock_sftp.listdir_attr.side_effect = [
        [
            mocker.MagicMock(st_mode=0o40755, filename="folder_a"),
            mocker.MagicMock(st_mode=0o100644, filename="file1.txt"),
            mocker.MagicMock(st_mode=0o100644, filename="file2.txt"),
        ],
        [
            mocker.MagicMock(st_mode=0o40755, filename="folder_b"),
            mocker.MagicMock(st_mode=0o100644, filename="file3.txt"),
        ],
    ]

    sftp = Sftp(credentials=variables["credentials"])
    sftp.conn = mock_sftp

    files_list = sftp._ls(path=".", recursive=False)

    assert len(files_list) == 2
    assert files_list == ["file1.txt", "file2.txt"]


@pytest.mark.functions
def test_recursive_ls(mocker):
    """Test SFTP recursive `_ls` method."""
    mock_sftp = mocker.MagicMock()
    # NOTE(review): the walk descends into folder_a/folder_b, exhausting the
    # two-element side_effect; the resulting StopIteration is swallowed by
    # `_ls`'s broad except, which this expectation relies on.
    mock_sftp.listdir_attr.side_effect = [
        [
            mocker.MagicMock(st_mode=0o40755, filename="folder_a"),
            mocker.MagicMock(st_mode=0o100644, filename="file1.txt"),
            mocker.MagicMock(st_mode=0o100644, filename="file2.txt"),
        ],
        [
            mocker.MagicMock(st_mode=0o40755, filename="folder_b"),
            mocker.MagicMock(st_mode=0o100644, filename="file3.txt"),
        ],
    ]

    sftp = Sftp(credentials=variables["credentials"])
    sftp.conn = mock_sftp

    files_list = sftp._ls(path=".", recursive=True)

    assert len(files_list) == 3
    assert files_list == ["folder_a/file3.txt", "file1.txt", "file2.txt"]


@pytest.mark.functions
def test_get_files_list(mocker):
    """Test SFTP `get_files_list` method."""
    mock_list_directory = mocker.patch.object(
        Sftp, "_ls", return_value=["file1.txt", "file2.txt"]
    )

    connector = Sftp(credentials=variables["credentials"])
    files = connector.get_files_list(path="test_path", recursive=False)

    assert files == ["file1.txt", "file2.txt"]
    mock_list_directory.assert_called_once_with(path="test_path", recursive=False)


@pytest.mark.functions
def test_close_conn(mocker):
    """Test SFTP `_close_conn` method."""
    mock_conn = mocker.MagicMock()
    connector = Sftp(credentials=variables["credentials"])
    connector.conn = mock_conn
    connector._close_conn()

    mock_conn.close.assert_called_once()
    assert connector.conn is None