Commit cade3f4

refactor(ingest/s3): enhance readability
- Refactor S3Source().get_folder_info() to enhance readability
- Add a test to ensure that get_folder_info() returns the expected result.
1 parent d180544 commit cade3f4

2 files changed: +72 −42 lines changed


metadata-ingestion/src/datahub/ingestion/source/s3/source.py

+19 −40
@@ -847,7 +847,7 @@ def get_folder_info(
         path_spec: PathSpec,
         bucket: "Bucket",
         prefix: str,
-    ) -> List[Folder]:
+    ) -> Iterable[Folder]:
         """
         Retrieves all the folders in a path by listing all the files in the prefix.
         If the prefix is a full path then only that folder will be extracted.
@@ -877,51 +877,30 @@ def _is_allowed_path(path_spec_: PathSpec, s3_uri: str) -> bool:
         s3_objects = (
             obj
             for obj in bucket.objects.filter(Prefix=prefix).page_size(PAGE_SIZE)
-            if _is_allowed_path(path_spec, f"s3://{obj.bucket_name}/{obj.key}")
+            if _is_allowed_path(
+                path_spec, self.create_s3_path(obj.bucket_name, obj.key)
+            )
         )
-
-        partitions: List[Folder] = []
         grouped_s3_objects_by_dirname = groupby_unsorted(
             s3_objects,
             key=lambda obj: obj.key.rsplit("/", 1)[0],
         )
-        for key, group in grouped_s3_objects_by_dirname:
-            file_size = 0
-            creation_time = None
-            modification_time = None
-
-            for item in group:
-                file_size += item.size
-                if creation_time is None or item.last_modified < creation_time:
-                    creation_time = item.last_modified
-                if modification_time is None or item.last_modified > modification_time:
-                    modification_time = item.last_modified
-                max_file = item
-
-            if modification_time is None:
-                logger.warning(
-                    f"Unable to find any files in the folder {key}. Skipping..."
-                )
-                continue
-
-            id = path_spec.get_partition_from_path(
-                self.create_s3_path(max_file.bucket_name, max_file.key)
+        for _, group in grouped_s3_objects_by_dirname:
+            max_file = max(group, key=lambda x: x.last_modified)
+            max_file_s3_path = self.create_s3_path(max_file.bucket_name, max_file.key)
+
+            # If partition_id is None, it means the folder is not a partition
+            partition_id = path_spec.get_partition_from_path(max_file_s3_path)
+
+            yield Folder(
+                partition_id=partition_id,
+                is_partition=bool(partition_id),
+                creation_time=min(obj.last_modified for obj in group),
+                modification_time=max_file.last_modified,
+                sample_file=max_file_s3_path,
+                size=sum(obj.size for obj in group),
             )

-            # If id is None, it means the folder is not a partition
-            partitions.append(
-                Folder(
-                    partition_id=id,
-                    is_partition=bool(id),
-                    creation_time=creation_time if creation_time else None,  # type: ignore[arg-type]
-                    modification_time=modification_time,
-                    sample_file=self.create_s3_path(max_file.bucket_name, max_file.key),
-                    size=file_size,
-                )
-            )
-
-        return partitions
-
     def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePath]:
         if self.source_config.aws_config is None:
             raise ValueError("aws_config not set. Cannot browse s3")
@@ -1000,7 +979,7 @@ def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePath]:
                         min=True,
                     )
                     dirs_to_process.append(dirs_to_process_min[0])
-        folders = []
+        folders: List[Folder] = []
         for dir in dirs_to_process:
             logger.info(f"Getting files from folder: {dir}")
             prefix_to_process = urlparse(dir).path.lstrip("/")
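Aside on the pattern used above (not part of the commit): the rewritten loop derives every Folder field with min/max/sum over each group instead of tracking running totals, and yields folders lazily. A minimal, self-contained sketch of that grouping-and-aggregation pattern, using itertools.groupby on sorted input as a stand-in for DataHub's groupby_unsorted; FileEntry and folder_summaries are hypothetical names for illustration only:

from dataclasses import dataclass
from datetime import datetime
from itertools import groupby
from typing import Iterator, List, Tuple


@dataclass
class FileEntry:
    # Hypothetical stand-in for a boto3 ObjectSummary.
    key: str
    last_modified: datetime
    size: int


def dirname(obj: FileEntry) -> str:
    # Parent "directory" of an object key: everything before the last "/".
    return obj.key.rsplit("/", 1)[0]


def folder_summaries(
    objects: List[FileEntry],
) -> Iterator[Tuple[str, datetime, datetime, int, str]]:
    # itertools.groupby needs sorted input; groupby_unsorted in the real code
    # avoids that requirement.
    for folder, group_iter in groupby(sorted(objects, key=dirname), key=dirname):
        group = list(group_iter)  # materialize so the group can be iterated repeatedly
        newest = max(group, key=lambda o: o.last_modified)
        yield (
            folder,
            min(o.last_modified for o in group),  # earliest modification -> creation_time
            newest.last_modified,                 # latest modification -> modification_time
            sum(o.size for o in group),           # total size of the folder
            newest.key,                           # sample file: the newest object
        )


objects = [
    FileEntry("my-folder/dir1/0001.csv", datetime(2025, 1, 1, 1), 100),
    FileEntry("my-folder/dir1/0002.csv", datetime(2025, 1, 1, 2), 50),
]
print(list(folder_summaries(objects)))
# One summary for my-folder/dir1: created 01:00, modified 02:00, 150 bytes, sample 0002.csv

Note that the new code calls max() on a group and then iterates it again for min() and sum(), which suggests groupby_unsorted hands back reusable groups rather than one-shot iterators; the explicit list() above plays the same role in this sketch.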

metadata-ingestion/tests/unit/s3/test_s3_source.py

+53 −2
@@ -9,7 +9,11 @@
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.data_lake_common.data_lake_utils import ContainerWUCreator
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.s3.source import S3Source, partitioned_folder_comparator
+from datahub.ingestion.source.s3.source import (
+    Folder,
+    S3Source,
+    partitioned_folder_comparator,
+)


 def _get_s3_source(path_spec_: PathSpec) -> S3Source:
@@ -257,7 +261,7 @@ def container_properties_filter(x: MetadataWorkUnit) -> bool:
     }


-def test_get_folder_info():
+def test_get_folder_info_returns_latest_file_in_each_folder() -> None:
     """
     Test S3Source.get_folder_info returns the latest file in each folder
     """
@@ -298,6 +302,7 @@ def test_get_folder_info():
     res = _get_s3_source(path_spec).get_folder_info(
         path_spec, bucket, prefix="/my-folder"
     )
+    res = list(res)

     # assert
     assert len(res) == 2
@@ -336,6 +341,7 @@ def test_get_folder_info_ignores_disallowed_path(

     # act
     res = s3_source.get_folder_info(path_spec, bucket, prefix="/my-folder")
+    res = list(res)

     # assert
     expected_called_s3_uri = "s3://my-bucket/my-folder/ignore/this/path/0001.csv"
@@ -350,3 +356,48 @@
         "Dropped file should be in the report.filtered"
     )
     assert res == [], "Dropped file should not be in the result"
+
+
+def test_get_folder_info_returns_expected_folder() -> None:
+    # arrange
+    path_spec = PathSpec(
+        include="s3://my-bucket/{table}/{partition0}/*.csv",
+        table_name="{table}",
+    )
+
+    bucket = Mock()
+    bucket.objects.filter().page_size = Mock(
+        return_value=[
+            Mock(
+                bucket_name="my-bucket",
+                key="my-folder/dir1/0001.csv",
+                creation_time=datetime(2025, 1, 1, 1),
+                last_modified=datetime(2025, 1, 1, 1),
+                size=100,
+            ),
+            Mock(
+                bucket_name="my-bucket",
+                key="my-folder/dir1/0002.csv",
+                creation_time=datetime(2025, 1, 1, 2),
+                last_modified=datetime(2025, 1, 1, 2),
+                size=50,
+            ),
+        ]
+    )
+
+    # act
+    res = _get_s3_source(path_spec).get_folder_info(
+        path_spec, bucket, prefix="/my-folder"
+    )
+    res = list(res)
+
+    # assert
+    assert len(res) == 1
+    assert res[0] == Folder(
+        partition_id=[("partition0", "dir1")],
+        is_partition=True,
+        creation_time=datetime(2025, 1, 1, 1),
+        modification_time=datetime(2025, 1, 1, 2),
+        size=150,
+        sample_file="s3://my-bucket/my-folder/dir1/0002.csv",
+    )
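One consequence of the Iterable return type, visible in the two `res = list(res)` lines added above: the refactored get_folder_info is a generator, so no S3 listing happens until the result is iterated, and assertions such as len(res) need a materialized list. A tiny sketch of that behaviour with a hypothetical stub generator (not the real method):

from typing import Iterator


def fake_get_folder_info() -> Iterator[str]:
    # Hypothetical stand-in that yields folder names instead of Folder objects.
    print("listing objects...")  # runs only once iteration starts
    yield "my-folder/dir1"


res = fake_get_folder_info()  # nothing printed yet: generators are lazy
folders = list(res)           # "listing objects..." is printed here
assert len(folders) == 1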
