Skip to content

Commit

Permalink
Fixing get_file_signals for custom types (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
dtulga authored Aug 28, 2024
1 parent 477d7d5 commit 13f6982
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
11 changes: 1 addition & 10 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,17 +1560,8 @@ def get_file_signals(
version = self.get_dataset(dataset_name).get_version(dataset_version)

file_signals_values = {}
file_schemas = {}
# TODO: To remove after we properly fix deserialization
for signal, type_name in version.feature_schema.items():
from datachain.lib.model_store import ModelStore

type_name_parsed, v = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(type_name_parsed, v)
if fr and issubclass(fr, File):
file_schemas[signal] = type_name

schema = SignalSchema.deserialize(file_schemas)
schema = SignalSchema.deserialize(version.feature_schema)
for file_signals in schema.get_signals(File):
prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
file_signals_values[file_signals] = {
Expand Down
30 changes: 30 additions & 0 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,36 @@ def test_get_file_signals(cloud_test_catalog, dogs_dataset):
}


def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
dogs_dataset,
1,
feature_schema={
"name": "str",
"age": "str",
"f1": "File@v1",
"f2": "File@v1",
"_custom_types": {
"File@v1": {"source": "str", "name": "str"},
},
},
)
row = {
"name": "Jon",
"age": 25,
"f1__source": "s3://first_bucket",
"f1__name": "image1.jpg",
"f2__source": "s3://second_bucket",
"f2__name": "image2.jpg",
}

assert catalog.get_file_signals(dogs_dataset.name, 1, row) == {
"source": "s3://first_bucket",
"name": "image1.jpg",
}


def test_get_file_signals_no_signals(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
Expand Down

0 comments on commit 13f6982

Please sign in to comment.