diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 89ad6e86d83..c7c9d193afc 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -1778,11 +1778,13 @@ class Features(dict):
         - [`Translation`] or [`TranslationVariableLanguages`] feature specific to Machine Translation.
     """
 
-    def __init__(*args, **kwargs):
+    def __init__(*args, non_nullable_flds: "set[str] | None" = None, **kwargs):
         # self not in the signature to allow passing self as a kwarg
         if not args:
             raise TypeError("descriptor '__init__' of 'Features' object needs an argument")
         self, *args = args
+
+        self.non_nullable_flds: set[str] = set(non_nullable_flds or ())
         super(Features, self).__init__(*args, **kwargs)
         self._column_requires_decoding: dict[str, bool] = {
             col: require_decoding(feature) for col, feature in self.items()
@@ -1818,14 +1820,17 @@ def arrow_schema(self):
             :obj:`pyarrow.Schema`
         """
         hf_metadata = {"info": {"features": self.to_dict()}}
-        return pa.schema(self.type).with_metadata({"huggingface": json.dumps(hf_metadata)})
+        schema = pa.schema(self.type, metadata={"huggingface": json.dumps(hf_metadata)})
+        schema = restore_non_nullable_fields(schema, self.non_nullable_flds)
+        return schema
 
     @classmethod
     def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
         """
         Construct [`Features`] from Arrow Schema.
         It also checks the schema metadata for Hugging Face Datasets features.
-        Non-nullable fields are not supported and set to nullable.
+        Non-nullable fields are supported and are stored in the non_nullable_flds attribute.
+        Calling `arrow_schema` will attempt to restore the non-nullable fields.
 
         Also, pa.dictionary is not supported and it uses its underlying type instead.
         Therefore datasets convert DictionaryArray objects to their actual values.
@@ -1845,14 +1850,19 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
                 metadata_features = Features.from_dict(metadata["info"]["features"])
-        metadata_features_schema = metadata_features.arrow_schema
+        # Mirror the incoming nullability on the metadata schema before comparing fields: otherwise
+        # a non-nullable field never equals its metadata counterpart and loses its feature type
+        non_nullable = find_non_nullable_fields(pa_schema)
+        metadata_features_schema = restore_non_nullable_fields(metadata_features.arrow_schema, non_nullable)
         obj = {
-            field.name: (
-                metadata_features[field.name]
-                if field.name in metadata_features and metadata_features_schema.field(field.name) == field
-                else generate_from_arrow_type(field.type)
+            schema_field.name: (
+                metadata_features[schema_field.name]
+                if schema_field.name in metadata_features
+                and metadata_features_schema.field(schema_field.name) == schema_field
+                else generate_from_arrow_type(schema_field.type)
             )
-            for field in pa_schema
+            for schema_field in pa_schema
         }
-        return cls(**obj)
+
+        return cls(**obj, non_nullable_flds=non_nullable)
 
     @classmethod
     def from_dict(cls, dic) -> "Features":
@@ -2325,3 +2335,84 @@ def _check_if_features_can_be_aligned(features_list: list[Features]):
                 raise ValueError(
                     f'The features can\'t be aligned because the key {k} of features {features} has unexpected type - {v} (expected either {name2feature[k]} or Value("null").'
                 )
+
+
+def find_non_nullable_fields(schema: pa.Schema, parent_path: str = "") -> set[str]:
+    """Recursively find non-nullable fields in a PyArrow schema and return them
+    as a set of period-separated paths, useful for deeper structures.
+
+    Args:
+        schema (pa.Schema): PyArrow schema to inspect (fields and data types are passed in on recursion)
+        parent_path (str, optional): Path to the current field for nested types (recursion)
+
+    Returns:
+        set[str]: Set of non-nullable field paths, where embedded paths are separated by a period
+    """
+    non_nullable_fields = set()
+
+    if hasattr(schema, "name"):
+        parent_path = f"{parent_path}.{schema.name}".lstrip(".")
+
+    # Full Schema
+    if isinstance(schema, pa.Schema):
+        for schema_field in schema:
+            non_nullable_fields.update(find_non_nullable_fields(schema_field, parent_path))
+    # Regular Fields
+    elif hasattr(schema, "type"):
+        # Check for non-nullable top-level Field
+        if not schema.nullable:
+            non_nullable_fields.add(parent_path)
+
+        # Recursively inspect nested types
+        non_nullable_fields.update(find_non_nullable_fields(schema.type, parent_path))
+
+    elif pa.types.is_struct(schema):
+        for schema_field in schema:
+            non_nullable_fields.update(find_non_nullable_fields(schema_field, parent_path))
+    elif pa.types.is_list(schema):
+        non_nullable_fields.update(find_non_nullable_fields(schema.value_field, parent_path))
+
+    return non_nullable_fields
+
+
+def restore_non_nullable_fields(schema: pa.Schema, non_nullable: set[str]) -> pa.Schema:
+    """Recover non-nullable fields in a PyArrow schema based on a set of period-separated paths.
+    See `find_non_nullable_fields` for more information.
+
+    Schema-level metadata (e.g. the "huggingface" entry) is carried over to the returned schema.
+
+    Args:
+        schema (pa.Schema): PyArrow schema to update
+        non_nullable (set[str]): Set of non-nullable field paths, where embedded paths are separated by a period
+
+    Returns:
+        pa.Schema: Updated PyArrow schema
+    """
+    # Nothing to restore: hand back the schema untouched (metadata included)
+    if not non_nullable:
+        return schema
+
+    # Recursively update the schema
+    def update_field(schema_field: pa.Field, parent_path: str = "") -> pa.Field:
+        # Check if the current field is non-nullable
+        current_path = f"{parent_path}.{schema_field.name}".lstrip(".")
+        if current_path in non_nullable:
+            schema_field = schema_field.with_nullable(False)
+
+        # Recursively update nested fields
+        if pa.types.is_struct(schema_field.type):
+            new_fields = [update_field(nested_field, current_path) for nested_field in schema_field.type]
+            schema_field = schema_field.with_type(pa.struct(new_fields))
+
+        # Recursively update the list's value field: it is a pa.Field, so it carries the
+        # name and nullability that `find_non_nullable_fields` recorded in the path
+        elif pa.types.is_list(schema_field.type):
+            value_field = update_field(schema_field.type.value_field, current_path)
+            schema_field = schema_field.with_type(pa.list_(value_field))
+
+        return schema_field
+
+    # Update all fields in the schema
+    new_fields = [update_field(schema_field) for schema_field in schema]
+
+    return pa.schema(new_fields, metadata=schema.metadata)
diff --git a/tests/features/test_features.py b/tests/features/test_features.py
index 6234d7ede62..c8524612263 100644
--- a/tests/features/test_features.py
+++ b/tests/features/test_features.py
@@ -21,12 +21,14 @@
     cast_to_python_objects,
     decode_nested_example,
     encode_nested_example,
+    find_non_nullable_fields,
     generate_from_arrow_type,
     generate_from_dict,
     get_nested_type,
     require_decoding,
     require_storage_cast,
     require_storage_embed,
+    restore_non_nullable_fields,
     string_to_arrow,
 )
 from datasets.features.translation import Translation, TranslationVariableLanguages
@@ -996,3 +998,108 @@ def func(x):
 
     result = _visit(feature, func)
     assert result == expected
+
+
+def test_non_nullable_fields_in_schema():
+    """Test that non-nullable fields are correctly identified in a schema."""
+    schema = pa.schema(
+        [
+            pa.field("nullable", pa.int32(), nullable=True),
+            pa.field("non_nullable", pa.int32(), nullable=False),
+        ]
+    )
+
+    non_nullable = find_non_nullable_fields(schema)
+    assert non_nullable == {"non_nullable"}
+
+    # Test restoring non-nullable fields
+    restored_schema = restore_non_nullable_fields(
+        pa.schema(
+            [
+                pa.field("nullable", pa.int32()),
+                pa.field("non_nullable", pa.int32()),
+            ]
+        ),
+        non_nullable,
+    )
+
+    assert restored_schema.field("nullable").nullable is True
+    assert restored_schema.field("non_nullable").nullable is False
+
+
+def test_nested_non_nullable_fields_in_schema():
+    """Test that non-nullable fields are correctly identified in deeply nested structures."""
+    schema = pa.schema(
+        [
+            pa.field(
+                "top_level",
+                pa.struct(
+                    [
+                        pa.field("nested_nullable", pa.int32(), nullable=True),
+                        pa.field("nested_non_nullable", pa.int32(), nullable=False),
+                    ]
+                ),
+            ),
+            pa.field(
+                "list_field",
+                pa.list_(
+                    pa.field(
+                        "item",
+                        pa.struct(
+                            [
+                                pa.field("deeply_nested_nullable", pa.int32(), nullable=True),
+                                pa.field("deeply_nested_non_nullable", pa.int32(), nullable=False),
+                            ]
+                        ),
+                    )
+                ),
+            ),
+        ]
+    )
+
+    non_nullable = find_non_nullable_fields(schema)
+    expected = {"top_level.nested_non_nullable", "list_field.item.deeply_nested_non_nullable"}
+    assert non_nullable == expected
+
+
+def test_restore_nested_non_nullable_fields_and_metadata():
+    """Test that nested (struct and list) non-nullable fields and schema metadata survive a restore."""
+    item_type = pa.struct([pa.field("deeply_nested_non_nullable", pa.int32())])
+    nullable_schema = pa.schema(
+        [
+            pa.field("top_level", pa.struct([pa.field("nested_non_nullable", pa.int32())])),
+            pa.field("list_field", pa.list_(pa.field("item", item_type))),
+        ],
+        metadata={"huggingface": "{}"},
+    )
+    non_nullable = {"top_level.nested_non_nullable", "list_field.item.deeply_nested_non_nullable"}
+
+    restored_schema = restore_non_nullable_fields(nullable_schema, non_nullable)
+
+    assert restored_schema.field("top_level").type.field("nested_non_nullable").nullable is False
+    list_item = restored_schema.field("list_field").type.value_field
+    assert list_item.name == "item"
+    assert list_item.type.field("deeply_nested_non_nullable").nullable is False
+    assert restored_schema.metadata == {b"huggingface": b"{}"}
+
+
+def test_from_arrow_schema_preserves_non_nullable():
+    """Test that from_arrow_schema correctly preserves non-nullable information."""
+    schema = pa.schema(
+        [
+            pa.field("nullable", pa.int32(), nullable=True),
+            pa.field("non_nullable", pa.int32(), nullable=False),
+        ]
+    )
+
+    # Convert to Features and back to schema
+    features = Features.from_arrow_schema(schema)
+    assert "non_nullable" in features.non_nullable_flds
+    assert "nullable" not in features.non_nullable_flds
+
+    # Ensure the schema is preserved when converted back
+    new_schema = features.arrow_schema
+    assert new_schema.field("nullable").nullable is True
+    assert new_schema.field("non_nullable").nullable is False
+    # The Hugging Face metadata must not be lost while restoring nullability
+    assert b"huggingface" in new_schema.metadata