diff --git a/docs/changelog/next_release/326.breaking.rst b/docs/changelog/next_release/326.breaking.rst new file mode 100644 index 00000000..ee4962a3 --- /dev/null +++ b/docs/changelog/next_release/326.breaking.rst @@ -0,0 +1 @@ +Change the logic of ``FileConnection.walk`` to exclude file what returned ``True`` from ``limit.stops_at(path)``. diff --git a/onetl/base/base_file_connection.py b/onetl/base/base_file_connection.py index 28949499..70e96d34 100644 --- a/onetl/base/base_file_connection.py +++ b/onetl/base/base_file_connection.py @@ -435,11 +435,11 @@ def walk( If ``True``, walk in top-down order, otherwise walk in bottom-up order. filters : list of :obj:`BaseFileFilter `, optional - Return only files/directories matching these filters. See :ref:`file-filters` + Return only files/directories matching these filters. See :ref:`file-filters`. limits : list of :obj:`BaseFileLimit `, optional - Apply limits to the list of files/directories, and stop if one of the limits is reached. - See :ref:`file-limits` + Apply limits to the list of files/directories, and immediately stop if any of these limits is reached. + See :ref:`file-limits`. Returns ------- diff --git a/onetl/connection/file_connection/file_connection.py b/onetl/connection/file_connection/file_connection.py index c85eead4..180b9b44 100644 --- a/onetl/connection/file_connection/file_connection.py +++ b/onetl/connection/file_connection/file_connection.py @@ -415,6 +415,9 @@ def list_dir( limits = reset_limits(limits or []) for entry in self._scan_entries(remote_dir): + if limits_reached(limits): + break + name = self._extract_name_from_entry(entry) stat = self._extract_stat_from_entry(remote_dir, entry) @@ -423,12 +426,9 @@ def list_dir( else: path = RemoteFile(path=name, stats=stat) - if match_all_filters(path, filters): + if match_all_filters(path, filters) and not limits_stop_at(path, limits): result.append(path) - if limits_stop_at(path, limits): - break - return result @slot @@ -491,6 +491,9 @@ def _walk( # noqa: WPS231 dirs, files = [], [] for entry in self._scan_entries(root): + if limits_reached(limits): + break + name = self._extract_name_from_entry(entry) stat = self._extract_stat_from_entry(root, entry) @@ -499,21 +502,15 @@ def _walk( # noqa: WPS231 yield from self._walk(root=root / name, topdown=topdown, filters=filters, limits=limits) path = RemoteDirectory(path=root / name, stats=stat) - if match_all_filters(path, filters): + if match_all_filters(path, filters) and not limits_stop_at(path, limits): dirs.append(RemoteDirectory(path=name, stats=stat)) - - if limits_stop_at(path, limits): - break else: path = RemoteFile(path=root / name, stats=stat) - if match_all_filters(path, filters): + if match_all_filters(path, filters) and not limits_stop_at(path, limits): files.append(RemoteFile(path=name, stats=stat)) - if limits_stop_at(path, limits): - break - - if topdown: + if topdown and not limits_reached(limits): for name in dirs: yield from self._walk(root=root / name, topdown=topdown, filters=filters, limits=limits) diff --git a/onetl/core/file_limit/file_limit.py b/onetl/core/file_limit/file_limit.py index de82dafe..c235547e 100644 --- a/onetl/core/file_limit/file_limit.py +++ b/onetl/core/file_limit/file_limit.py @@ -67,7 +67,7 @@ def stops_at(self, path: PathProtocol) -> bool: @property def is_reached(self) -> bool: - return self._counter >= self.count_limit + return self._counter > self.count_limit @validator("count_limit") def _deprecated(cls, value): diff --git a/onetl/file/limit/max_files_count.py b/onetl/file/limit/max_files_count.py index 4e930933..8b18e841 100644 --- a/onetl/file/limit/max_files_count.py +++ b/onetl/file/limit/max_files_count.py @@ -76,4 +76,4 @@ def stops_at(self, path: PathProtocol) -> bool: @property def is_reached(self) -> bool: - return self._handled >= self.limit + return self._handled > self.limit diff --git a/onetl/file/limit/total_files_size.py b/onetl/file/limit/total_files_size.py index 5c923747..980001f4 100644 --- a/onetl/file/limit/total_files_size.py +++ b/onetl/file/limit/total_files_size.py @@ -85,4 +85,4 @@ def stops_at(self, path: PathProtocol) -> bool: @property def is_reached(self) -> bool: - return self._handled >= self.limit + return self._handled > self.limit diff --git a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py index f3f594e1..a7684196 100644 --- a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py @@ -790,7 +790,7 @@ def finalizer(): downloader.run([not_a_file]) -def test_file_downloader_with_file_limit(file_connection_with_path_and_files, tmp_path_factory, caplog): +def test_file_downloader_with_limit(file_connection_with_path_and_files, tmp_path_factory, caplog): file_connection, remote_path, _ = file_connection_with_path_and_files limit = 2 local_path = tmp_path_factory.mktemp("local_path") @@ -814,7 +814,7 @@ def test_file_downloader_with_file_limit(file_connection_with_path_and_files, tm assert len(download_result.successful) == limit -def test_file_downloader_file_limit_is_ignored_by_user_input( +def test_file_downloader_limit_is_ignored_by_user_input( file_connection_with_path_and_files, tmp_path_factory, ): diff --git a/tests/tests_unit/test_file/test_limit/test_max_files_count.py b/tests/tests_unit/test_file/test_limit/test_max_files_count.py index 19e6ae1b..71444ee6 100644 --- a/tests/tests_unit/test_file/test_limit/test_max_files_count.py +++ b/tests/tests_unit/test_file/test_limit/test_max_files_count.py @@ -36,10 +36,10 @@ def test_max_files_count(): assert not limit.stops_at(directory) assert not limit.is_reached - # limit is reached - all check are True, input does not matter - assert limit.stops_at(file3) - assert limit.is_reached + assert not limit.stops_at(file3) + assert not limit.is_reached + # limit is reached - all check are True, input does not matter assert limit.stops_at(file4) assert limit.is_reached @@ -56,5 +56,8 @@ def test_max_files_count(): assert not limit.stops_at(file1) assert not limit.is_reached - assert limit.stops_at(file1) + assert not limit.stops_at(file3) + assert not limit.is_reached + + assert limit.stops_at(file4) assert limit.is_reached diff --git a/tests/tests_unit/test_file/test_limit/test_total_files_size.py b/tests/tests_unit/test_file/test_limit/test_total_files_size.py index 26b6798a..75f92d60 100644 --- a/tests/tests_unit/test_file/test_limit/test_total_files_size.py +++ b/tests/tests_unit/test_file/test_limit/test_total_files_size.py @@ -77,5 +77,5 @@ def test_total_files_size(): assert not limit.stops_at(file1) assert not limit.is_reached - assert limit.stops_at(file1) + assert limit.stops_at(file3) assert limit.is_reached