From c95326a900a92d4aebe9d04bc9f7156763157942 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 12 Jan 2024 10:20:10 +0100 Subject: [PATCH] feature interpret metadata entries with none/nan as non-existent instead of returning an error (#4477) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR resolves problems working with `None` and `nan` values in metadata property values. - The `nan` values are ignored when detected on the server side - The `None` values reset the metadata property value. So, if users want to reset the `prop` metadata property, they can set it up with a `None` value. ## UPDATE New changes have been introduced to improve API behaviour: - `Nan` values will raise a 422 error, which makes sense since it is not a valid value. - `None` values are accepted as valid values. But the behaviour is equivalent to not passing the metadata value for that property, since the whole metadata payload is replaced/overwritten on a record update. That said, the client will parse metadata values to detect `nan` and raise a proper error. The `None` ones are fully accepted. cc @davidberenstein1957 @sdiazlor @jfcalvo Closes #4300 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] New feature (non-breaking change which adds functionality) - [X] Refactor (change restructuring the codebase without changing functionality) - [X] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) This feature has been tested locally using the Python SDK. 
**Checklist** - [ ] I added relevant documentation - [ ] I followed the style guidelines of this project - [X] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [X] My changes generate no new warnings - [X] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [X] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Francisco Calvo --- CHANGELOG.md | 12 +- src/argilla/client/feedback/constants.py | 13 -- .../client/feedback/schemas/metadata.py | 119 +++++++++++++----- src/argilla/server/contexts/datasets.py | 3 +- src/argilla/server/schemas/v1/datasets.py | 12 ++ src/argilla/server/schemas/v1/records.py | 14 ++- .../feedback/dataset/remote/test_dataset.py | 48 +++++++ .../feedback/dataset/local/test_dataset.py | 6 +- .../client/feedback/dataset/test_helpers.py | 12 ++ .../unit/client/sdk/models/test_workspaces.py | 2 +- tests/unit/server/api/v1/test_datasets.py | 117 +++++++++++++---- 11 files changed, 278 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9978f5609e..ec064ab68b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,23 +19,24 @@ These are the section headers that we use: ### Added - Restore filters from feedback dataset settings ([#4461])(https://github.com/argilla-io/argilla/pull/4461) -- Warning on feedback dataset settings when leaving page with unsaved changes ([#4461])(https://github.com/argilla-io/argilla/pull/4461) -- Added pydantic v2 support using the python SDK ([#4459](https://github.com/argilla-io/argilla/pull/4459)) +- Warning on feedback dataset settings when leaving page with unsaved changes. ([#4461](https://github.com/argilla-io/argilla/pull/4461)) +- Added pydantic v2 support using the python SDK. 
([#4459](https://github.com/argilla-io/argilla/pull/4459)) -## Changed +### Changed - Module `argilla.cli.server` definitions have been moved to `argilla.server.cli` module. ([#4472](https://github.com/argilla-io/argilla/pull/4472)) - The constant definition `ES_INDEX_REGEX_PATTERN` in module `argilla._constants` is now private. ([#4472](https://github.com/argilla-io/argilla/pull/4474)) +- `nan` values in metadata properties will raise a 422 error when creating/updating records. ([#4300](https://github.com/argilla-io/argilla/issues/4300)) +- `None` values are now allowed in metadata properties. ([#4300](https://github.com/argilla-io/argilla/issues/4300)) ### Deprecated - The `missing` response status for filtering records is deprecated and will be removed in the release v1.24.0. Use `pending` instead. ([#4433](https://github.com/argilla-io/argilla/pull/4433)) -## Removed +### Removed - The deprecated `python -m argilla database` command has been removed. ([#4472](https://github.com/argilla-io/argilla/pull/4472)) - ## [1.21.0](https://github.com/argilla-io/argilla/compare/v1.20.0...v1.21.0) ### Added @@ -50,7 +51,6 @@ These are the section headers that we use: - Added `httpx_extra_kwargs` argument to `rg.init` and `Argilla` to allow passing extra arguments to `httpx.Client` used by `Argilla`. ([#4440](https://github.com/argilla-io/argilla/pull/4441)) - Added `ResponseStatusFilter` enum in `__init__` imports of Argilla ([#4118](https://github.com/argilla-io/argilla/pull/4463)). Contributed by @Piyush-Kumar-Ghosh. 
- ### Changed - More productive and simpler shortcuts system ([#4215](https://github.com/argilla-io/argilla/pull/4215)) diff --git a/src/argilla/client/feedback/constants.py b/src/argilla/client/feedback/constants.py index 79c2a33faa..1bae9abc30 100644 --- a/src/argilla/client/feedback/constants.py +++ b/src/argilla/client/feedback/constants.py @@ -21,16 +21,3 @@ DELETE_DATASET_RECORDS_MAX_NUMBER = 100 FIELD_TYPE_TO_PYTHON_TYPE = {FieldTypes.text: str} -# We are using `pydantic`'s strict types to avoid implicit type conversions -METADATA_PROPERTY_TYPE_TO_PYDANTIC_TYPE = { - MetadataPropertyTypes.terms: Union[StrictStr, List[StrictStr]], - MetadataPropertyTypes.integer: StrictInt, - MetadataPropertyTypes.float: StrictFloat, -} - -PYDANTIC_STRICT_TO_PYTHON_TYPE = { - StrictInt: int, - StrictFloat: float, - StrictStr: str, - Union[StrictStr, List[StrictStr]]: (str, list), -} diff --git a/src/argilla/client/feedback/schemas/metadata.py b/src/argilla/client/feedback/schemas/metadata.py index 24bfd45c41..d39f46094c 100644 --- a/src/argilla/client/feedback/schemas/metadata.py +++ b/src/argilla/client/feedback/schemas/metadata.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import math from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union -from argilla.client.feedback.constants import METADATA_PROPERTY_TYPE_TO_PYDANTIC_TYPE, PYDANTIC_STRICT_TO_PYTHON_TYPE from argilla.client.feedback.schemas.enums import MetadataPropertyTypes from argilla.client.feedback.schemas.validators import ( validate_numeric_metadata_filter_bounds, @@ -92,14 +91,18 @@ def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Di def _validate_filter(self, metadata_filter: "MetadataFilters") -> None: pass + @abstractmethod def _check_allowed_value_type(self, value: Any) -> Any: - expected_type = PYDANTIC_STRICT_TO_PYTHON_TYPE[METADATA_PROPERTY_TYPE_TO_PYDANTIC_TYPE[self.type]] - if not isinstance(value, expected_type): - raise ValueError( - f"Provided '{self.name}={value}' of type {type(value)} is not valid, " - f"only values of type {expected_type} are allowed." - ) - return value + ... + + @abstractmethod + def _validator(self, value: Any) -> Any: + ... + + +def _validator_definition(schema: MetadataPropertySchema) -> Dict[str, Any]: + """Returns a dictionary with the pydantic validator definition for the provided schema.""" + return {f"{schema.name}_validator": validator(schema.name, allow_reuse=True, pre=True)(schema._validator)} class TermsMetadataProperty(MetadataPropertySchema): @@ -161,21 +164,33 @@ def _all_values_exist(self, introduced_value: Optional[Union[str, List[str]]] = return introduced_value + def _check_allowed_value_type(self, value: Any) -> Any: + if value is None or isinstance(value, str): + return value + + if isinstance(value, list): + return [self._check_allowed_value_type(v) for v in value] + + raise TypeError( + f"Provided '{self.name}={value}' of type {type(value)} is not valid, " + "only values of type `str` or `list` are allowed." 
+ ) + def _validator(self, value: Any) -> Any: return self._all_values_exist(self._check_allowed_value_type(value)) @property - def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[StrictStr, None]], Dict[str, Callable]]: + def _pydantic_field_with_validator(self) -> Tuple[Dict[str, tuple], Dict[str, Callable]]: # TODO: Simplify the validation logic and do not base on dynamic pydantic models - return ( - {self.name: (METADATA_PROPERTY_TYPE_TO_PYDANTIC_TYPE[self.type], None)}, - {f"{self.name}_validator": validator(self.name, allow_reuse=True, pre=True)(self._validator)}, - ) + field_type, default_value = Optional[Union[StrictStr, List[StrictStr]]], None + + return {self.name: (field_type, default_value)}, _validator_definition(self) def _validate_filter(self, metadata_filter: "TermsMetadataFilter") -> None: if self.values is not None and not all(value in self.values for value in metadata_filter.values): raise ValidationError( - f"Provided 'values={metadata_filter.values}' is not valid, only values in {self.values} are allowed." 
+ f"Provided 'values={metadata_filter.values}' is not valid, only values in {self.values} are allowed.", + model=type(metadata_filter), ) @@ -233,37 +248,40 @@ def _value_in_bounds(self, provided_value: Optional[Union[int, float]]) -> Union ) return provided_value - def _validator(self, value: Any) -> Any: - return self._value_in_bounds(self._check_allowed_value_type(value)) + def _check_nan(self, provided_value: Optional[Union[int, float]]) -> Optional[Union[int, float]]: + if provided_value != provided_value: + raise ValueError(f"Provided '{self.name}={provided_value}' is not valid, NaN values are not allowed.") - @property - def _pydantic_field_with_validator( - self, - ) -> Tuple[Dict[str, Tuple[Union[StrictInt, StrictFloat], None]], Dict[str, Callable]]: - return ( - {self.name: (METADATA_PROPERTY_TYPE_TO_PYDANTIC_TYPE[self.type], None)}, - {f"{self.name}_validator": validator(self.name, allow_reuse=True, pre=True)(self._validator)}, - ) + return provided_value + + def _validator(self, value: Any) -> Any: + return self._value_in_bounds(self._check_allowed_value_type(self._check_nan(value))) def _validate_filter(self, metadata_filter: Union["IntegerMetadataFilter", "FloatMetadataFilter"]) -> None: - metadata_filter = metadata_filter.dict() + metadata_filter_ = metadata_filter.dict() for allowed_arg in ["ge", "le"]: - if metadata_filter[allowed_arg] is not None: + if metadata_filter_[allowed_arg] is not None: if ( self.max is not None and self.min is not None - and not (self.max >= metadata_filter[allowed_arg] >= self.min) + and not (self.max >= metadata_filter_[allowed_arg] >= self.min) ): raise ValidationError( - f"Provided '{allowed_arg}={metadata_filter[allowed_arg]}' is not valid, only values between {self.min} and {self.max} are allowed." 
+ f"Provided '{allowed_arg}={metadata_filter_[allowed_arg]}' is not valid, " + f"only values between {self.min} and {self.max} are allowed.", + model=type(metadata_filter), ) - if self.max is not None and not (self.max >= metadata_filter[allowed_arg]): + if self.max is not None and not (self.max >= metadata_filter_[allowed_arg]): raise ValidationError( - f"Provided '{allowed_arg}={metadata_filter[allowed_arg]}' is not valid, only values under {self.max} are allowed." + f"Provided '{allowed_arg}={metadata_filter_[allowed_arg]}' is not valid, " + f"only values under {self.max} are allowed.", + model=type(metadata_filter), ) - if self.min is not None and not (self.min <= metadata_filter[allowed_arg]): + if self.min is not None and not (self.min <= metadata_filter_[allowed_arg]): raise ValidationError( - f"Provided '{allowed_arg}={metadata_filter[allowed_arg]}' is not valid, only values over {self.min} are allowed." + f"Provided '{allowed_arg}={metadata_filter_[allowed_arg]}' is not valid, " + f"only values over {self.min} are allowed.", + model=type(metadata_filter), ) @@ -292,6 +310,23 @@ class IntegerMetadataProperty(_NumericMetadataPropertySchema): min: Optional[int] = None max: Optional[int] = None + @property + def _pydantic_field_with_validator( + self, + ) -> Tuple[Dict[str, tuple], Dict[str, Callable]]: + field_type, default_value = Optional[StrictInt], None + return {self.name: (field_type, default_value)}, _validator_definition(self) + + def _check_allowed_value_type(self, value: Any) -> Any: + if value is not None: + if isinstance(value, int): + return value + raise TypeError( + f"Provided '{self.name}={value}' of type {type(value)} is not valid, " + "only values of type `int` are allowed." + ) + return value + class FloatMetadataProperty(_NumericMetadataPropertySchema): """Schema for the `FeedbackDataset` metadata properties of type `float`. 
This kind @@ -318,6 +353,22 @@ class FloatMetadataProperty(_NumericMetadataPropertySchema): min: Optional[float] = None max: Optional[float] = None + def _check_allowed_value_type(self, value: Any) -> Any: + if value is None or isinstance(value, (int, float)): + return value + + raise TypeError( + f"Provided '{self.name}={value}' of type {type(value)} is not valid, " + "only values of type `int` or `float` are allowed." + ) + + @property + def _pydantic_field_with_validator( + self, + ) -> Tuple[Dict[str, tuple], Dict[str, Any]]: + field_type, default_value = Optional[StrictFloat], None + return {self.name: (field_type, default_value)}, _validator_definition(self) + class MetadataFilterSchema(BaseModel, ABC): """Base schema for the `FeedbackDataset` metadata filters. diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 44aef78ac9..854216a4b1 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -512,7 +512,8 @@ async def _validate_metadata( continue try: - metadata_property.parsed_settings.check_metadata(value) + if value is not None: + metadata_property.parsed_settings.check_metadata(value) except ValueError as e: raise ValueError(f"'{name}' metadata property validation failed because {e}") from e diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index 4b39a88fe7..cba05e711b 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -524,6 +524,18 @@ def check_user_id_is_unique(cls, values: Optional[List[UserResponseCreate]]) -> return values + @validator("metadata", pre=True) + @classmethod + def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if metadata is None: + return metadata + + for k, v in metadata.items(): + if v != v: + raise ValueError(f"NaN is not allowed as metadata value, found NaN for key {k!r}") + + return metadata + 
class RecordsCreate(BaseModel): items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS) diff --git a/src/argilla/server/schemas/v1/records.py b/src/argilla/server/schemas/v1/records.py index d0fc87937b..979b4b86a9 100644 --- a/src/argilla/server/schemas/v1/records.py +++ b/src/argilla/server/schemas/v1/records.py @@ -17,7 +17,7 @@ from uuid import UUID from argilla.server.models import ResponseStatus -from argilla.server.pydantic_v1 import BaseModel, Field +from argilla.server.pydantic_v1 import BaseModel, Field, validator from argilla.server.schemas.base import UpdateSchema from argilla.server.schemas.v1.suggestions import SuggestionCreate @@ -51,3 +51,15 @@ class RecordUpdate(UpdateSchema): metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata") suggestions: Optional[List[SuggestionCreate]] = None vectors: Optional[Dict[str, List[float]]] + + @validator("metadata_", pre=True) + @classmethod + def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if metadata is None: + return metadata + + for k, v in metadata.items(): + if v != v: + raise ValueError(f"NaN is not allowed as metadata value, found NaN for key {k!r}") + + return {k: v for k, v in metadata.items() if v == v} # By definition, NaN != NaN diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index c697825ecd..43bb2cde9b 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -18,6 +18,7 @@ import argilla as rg import argilla.client.singleton +import numpy import pytest from argilla import ( FeedbackRecord, @@ -841,3 +842,50 @@ async def test_warning_local_methods(self, role: UserRole) -> None: match="A local `FeedbackDataset` returned because `prepare_for_training` is not supported for `RemoteFeedbackDataset`. 
", ): ds.prepare_for_training(framework=None, task=None) + + async def test_add_records_with_metadata_including_nan_values( + self, owner: "User", feedback_dataset: FeedbackDataset + ): + argilla.client.singleton.init(api_key=owner.api_key) + workspace = Workspace.create(name="test-workspace") + + remote_dataset = feedback_dataset.push_to_argilla(name="test_dataset", workspace=workspace) + + records = [ + FeedbackRecord( + external_id=str(i), + fields={"text": "Hello world!", "text-2": "Hello world!"}, + metadata={"float-metadata": float("nan")}, + ) + for i in range(1, 20) + ] + + with pytest.raises(ValueError, match="NaN values are not allowed"): + remote_dataset.add_records(records) + + async def test_add_records_with_metadata_including_none_values( + self, owner: "User", feedback_dataset: FeedbackDataset + ): + argilla.client.singleton.init(api_key=owner.api_key) + workspace = Workspace.create(name="test-workspace") + + feedback_dataset.add_records( + [ + FeedbackRecord( + fields={"text": "Hello world!"}, + metadata={ + "terms-metadata": None, + "integer-metadata": None, + "float-metadata": None, + }, + ) + ] + ) + + remote_dataset = feedback_dataset.push_to_argilla(name="test_dataset", workspace=workspace) + assert len(remote_dataset.records) == 1 + assert remote_dataset.records[0].metadata == { + "terms-metadata": None, + "integer-metadata": None, + "float-metadata": None, + } diff --git a/tests/unit/client/feedback/dataset/local/test_dataset.py b/tests/unit/client/feedback/dataset/local/test_dataset.py index 948cb27a8a..5ad494336e 100644 --- a/tests/unit/client/feedback/dataset/local/test_dataset.py +++ b/tests/unit/client/feedback/dataset/local/test_dataset.py @@ -226,8 +226,8 @@ def test_add_record_with_numpy_values(property_class: Type["AllowedMetadataPrope metadata_property = property_class(name="numeric_property") dataset.add_metadata_property(metadata_property) - property_to_primitive_type = {IntegerMetadataProperty: int, FloatMetadataProperty: float} 
- expected_type = property_to_primitive_type[property_class] + property_to_expected_type_msg = {IntegerMetadataProperty: "`int`", FloatMetadataProperty: "`int` or `float`"} + expected_type_msg = property_to_expected_type_msg[property_class] value = numpy_type(10.0) record = FeedbackRecord(fields={"required-field": "text"}, metadata={"numeric_property": value}) @@ -235,7 +235,7 @@ def test_add_record_with_numpy_values(property_class: Type["AllowedMetadataPrope with pytest.raises( ValueError, match=f"Provided 'numeric_property={value}' of type {str(numpy_type)} is not valid, " - f"only values of type {expected_type} are allowed.", + f"only values of type {expected_type_msg} are allowed.", ): dataset.add_records(record) diff --git a/tests/unit/client/feedback/dataset/test_helpers.py b/tests/unit/client/feedback/dataset/test_helpers.py index 938af459c2..7a95913499 100644 --- a/tests/unit/client/feedback/dataset/test_helpers.py +++ b/tests/unit/client/feedback/dataset/test_helpers.py @@ -94,6 +94,12 @@ def test_generate_pydantic_schema_for_metadata( ValidationError, "Provided 'int-metadata=wrong' of type is not valid", ), + ( + [IntegerMetadataProperty(name="int-metadata", min=0, max=10)], + {"int-metadata": float("nan")}, + ValidationError, + "Provided 'int-metadata=nan' is not valid, NaN values are not allowed.", + ), ( [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], {"float-metadata": 100.0}, @@ -106,6 +112,12 @@ def test_generate_pydantic_schema_for_metadata( ValidationError, "Provided 'float-metadata=wrong' of type is not valid", ), + ( + [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], + {"float-metadata": float("nan")}, + ValidationError, + "Provided 'float-metadata=nan' is not valid, NaN values are not allowed.", + ), ], ) def test_generate_pydantic_schema_for_metadata_errors( diff --git a/tests/unit/client/sdk/models/test_workspaces.py b/tests/unit/client/sdk/models/test_workspaces.py index f94064fc67..a1cf8065d4 100644 --- 
a/tests/unit/client/sdk/models/test_workspaces.py +++ b/tests/unit/client/sdk/models/test_workspaces.py @@ -13,7 +13,7 @@ # limitations under the License. from argilla.client.sdk.workspaces.models import WorkspaceModel as ClientSchema -from argilla.server.apis.v1.handlers.workspaces import Workspace as ServerSchema +from argilla.server.schemas.v1.workspaces import Workspace as ServerSchema def test_users_schema(helpers): diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index c836cec73d..695b36c698 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import uuid from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union @@ -99,7 +100,6 @@ VectorSettingsFactory, WorkspaceFactory, ) -from tests.unit.server.api.v1.test_list_dataset_records import TestSuiteListDatasetRecords if TYPE_CHECKING: from httpx import AsyncClient @@ -2541,15 +2541,20 @@ async def test_create_dataset_records_with_wrong_optional_fields( [ (TermsMetadataPropertyFactory, {"values": ["a", "b", "c"]}, "c"), (TermsMetadataPropertyFactory, {"values": None}, "c"), + (TermsMetadataPropertyFactory, {"values": ["a", "b", "c"]}, None), + (TermsMetadataPropertyFactory, {"values": None}, None), (IntegerMetadataPropertyFactory, {"min": 0, "max": 10}, 5), + (IntegerMetadataPropertyFactory, {"min": 0, "max": 10}, None), (FloatMetadataPropertyFactory, {"min": 0.0, "max": 1}, 0.5), (FloatMetadataPropertyFactory, {"min": 0.3, "max": 0.5}, 0.35), (FloatMetadataPropertyFactory, {"min": 0.3, "max": 0.9}, 0.89), + (FloatMetadataPropertyFactory, {"min": 0.3, "max": 0.9}, None), ], ) async def test_create_dataset_records_with_metadata_values( self, async_client: "AsyncClient", 
+ db: "AsyncSession", owner_auth_header: dict, MetadataPropertyFactoryType: Type[MetadataPropertyFactory], settings: Dict[str, Any], @@ -2575,6 +2580,45 @@ async def test_create_dataset_records_with_metadata_values( assert response.status_code == 204 + record = (await db.execute(select(Record))).scalar() + assert record.metadata_ == {"metadata-property": value} + + @pytest.mark.parametrize( + "MetadataPropertyFactoryType, settings", + [ + (TermsMetadataPropertyFactory, {"values": ["a", "b", "c"]}), + (IntegerMetadataPropertyFactory, {"min": 0, "max": 10}), + (FloatMetadataPropertyFactory, {"min": 0.3, "max": 0.9}), + ], + ) + async def test_create_dataset_records_with_metadata_nan_values( + self, + async_client: "AsyncClient", + db: "AsyncSession", + owner_auth_header: dict, + MetadataPropertyFactoryType: Type[MetadataPropertyFactory], + settings: Dict[str, Any], + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="completion", dataset=dataset) + await TextQuestionFactory.create(name="corrected", dataset=dataset) + await MetadataPropertyFactoryType.create(name="metadata-property", settings=settings, dataset=dataset) + + records_json = { + "items": [ + { + "fields": {"completion": "text-input"}, + "metadata": {"metadata-property": math.nan}, + } + ] + } + + response = await async_client.post( + f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + ) + + assert response.status_code == 422 + @pytest.mark.parametrize( "MetadataPropertyFactoryType, settings, value", [ @@ -2835,7 +2879,7 @@ async def test_create_dataset_records_without_authentication(self, async_client: records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "external_id": "1", "response": { "values": {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, @@ -2937,7 +2981,7 @@ async def 
test_create_dataset_records_as_annotator(self, async_client: "AsyncCli records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "external_id": "1", "response": { "values": { @@ -2966,7 +3010,7 @@ async def test_create_dataset_records_as_admin_from_another_workspace(self, asyn records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "external_id": "1", "response": { "values": { @@ -3030,7 +3074,7 @@ async def test_create_dataset_records_with_submitted_response_without_values( records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "responses": [ { "user_id": str(owner.id), @@ -3138,7 +3182,7 @@ async def test_create_dataset_records_with_invalid_response_status( records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "responses": [ { "values": {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, @@ -3200,7 +3244,7 @@ async def test_create_dataset_records_with_non_published_dataset( dataset = await DatasetFactory.create(status=DatasetStatus.draft) records_json = { "items": [ - {"fields": {"input": "Say Hello", "ouput": "Hello"}, "external_id": "1"}, + {"fields": {"input": "Say Hello", "output": "Hello"}, "external_id": "1"}, ], } @@ -3220,7 +3264,7 @@ async def test_create_dataset_records_with_less_items_than_allowed( records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": {"input": "Say Hello", "output": "Hello"}, "external_id": str(external_id), } for external_id in range(0, RECORDS_CREATE_MIN_ITEMS - 1) @@ -3242,7 +3286,7 @@ async def test_create_dataset_records_with_more_items_than_allowed( records_json = { "items": [ { - "fields": {"input": "Say Hello", "ouput": "Hello"}, + "fields": 
{"input": "Say Hello", "output": "Hello"}, "external_id": str(external_id), } for external_id in range(0, RECORDS_CREATE_MAX_ITEMS + 1) @@ -3263,9 +3307,9 @@ async def test_create_dataset_records_with_invalid_records( dataset = await DatasetFactory.create(status=DatasetStatus.ready) records_json = { "items": [ - {"fields": {"input": "Say Hello", "ouput": "Hello"}, "external_id": 1}, + {"fields": {"input": "Say Hello", "output": "Hello"}, "external_id": 1}, {"fields": "invalid", "external_id": 2}, - {"fields": {"input": "Say Hello", "ouput": "Hello"}, "external_id": 3}, + {"fields": {"input": "Say Hello", "output": "Hello"}, "external_id": 3}, ] } @@ -3283,8 +3327,8 @@ async def test_create_dataset_records_with_nonexistent_dataset_id( await DatasetFactory.create() records_json = { "items": [ - {"fields": {"input": "Say Hello", "ouput": "Hello"}, "external_id": 1}, - {"fields": {"input": "Say Hello", "ouput": "Hello"}, "external_id": 2}, + {"fields": {"input": "Say Hello", "output": "Hello"}, "external_id": 1}, + {"fields": {"input": "Say Hello", "output": "Hello"}, "external_id": 2}, ] } @@ -3319,17 +3363,17 @@ async def test_update_dataset_records( { "id": str(records[0].id), "metadata": { - "terms-metadata-property": "a", + "terms-metadata-property": None, "integer-metadata-property": 0, "float-metadata-property": 0.0, - "extra-metadata": "yes", + "extra-metadata": None, }, }, { "id": str(records[1].id), "metadata": { "terms-metadata-property": "b", - "integer-metadata-property": 1, + "integer-metadata-property": None, "float-metadata-property": 1.0, "extra-metadata": "yes", }, @@ -3339,7 +3383,7 @@ async def test_update_dataset_records( "metadata": { "terms-metadata-property": "c", "integer-metadata-property": 2, - "float-metadata-property": 2.0, + "float-metadata-property": None, "extra-metadata": "yes", }, }, @@ -3354,16 +3398,16 @@ async def test_update_dataset_records( # Record 0 assert records[0].metadata_ == { - "terms-metadata-property": "a", + 
"terms-metadata-property": None, "integer-metadata-property": 0, "float-metadata-property": 0.0, - "extra-metadata": "yes", + "extra-metadata": None, } # Record 1 assert records[1].metadata_ == { "terms-metadata-property": "b", - "integer-metadata-property": 1, + "integer-metadata-property": None, "float-metadata-property": 1.0, "extra-metadata": "yes", } @@ -3372,7 +3416,7 @@ async def test_update_dataset_records( assert records[2].metadata_ == { "terms-metadata-property": "c", "integer-metadata-property": 2, - "float-metadata-property": 2.0, + "float-metadata-property": None, "extra-metadata": "yes", } @@ -3626,6 +3670,37 @@ async def test_update_dataset_records_with_invalid_metadata( "validation failed because 'i was not declared' is not an allowed term." } + async def test_update_dataset_records_with_metadata_nan_value( + self, async_client: "AsyncClient", owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + await TermsMetadataPropertyFactory.create(dataset=dataset, name="terms") + await FloatMetadataPropertyFactory.create(dataset=dataset, name="float") + records = await RecordFactory.create_batch(3, dataset=dataset) + + response = await async_client.patch( + f"/api/v1/datasets/{dataset.id}/records", + headers=owner_auth_header, + json={ + "items": [ + { + "id": str(records[0].id), + "metadata": {"terms": math.nan}, + }, + { + "id": str(records[1].id), + "metadata": {"float": math.nan}, + }, + { + "id": str(records[2].id), + "metadata": {"terms": "a"}, + }, + ] + }, + ) + + assert response.status_code == 422 + async def test_update_dataset_records_with_invalid_suggestions( self, async_client: "AsyncClient", owner_auth_header: dict ):