Skip to content

Commit ef6b006

Browse files
authored
Merge pull request #320 from Labelbox/ms/metadata-validation-updates
fix data type coercion issues
2 parents 4c54ed6 + b174036 commit ef6b006

File tree

1 file changed

+41
-18
lines changed

1 file changed

+41
-18
lines changed

labelbox/schema/data_row_metadata.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from labelbox.schema.ontology import SchemaId
1212
from labelbox.utils import camel_case
1313

14-
_MAX_METADATA_FIELDS = 5
15-
1614

1715
class DataRowMetadataKind(Enum):
1816
number = "CustomMetadataNumber"
@@ -40,16 +38,8 @@ def id(self):
4038

4139
DataRowMetadataSchema.update_forward_refs()
4240

43-
# Constraints for metadata values
4441
Embedding: Type[List[float]] = conlist(float, min_items=128, max_items=128)
45-
DateTime: Type[datetime] = datetime # must be in UTC
4642
String: Type[str] = constr(max_length=500)
47-
OptionId: Type[SchemaId] = SchemaId # enum option
48-
Number: Type[float] = float
49-
50-
DataRowMetadataValue = Union[Embedding, Number, DateTime, String, OptionId]
51-
# primitives used in uploads
52-
_DataRowMetadataValuePrimitives = Union[str, List, dict, float]
5343

5444

5545
class _CamelCaseMixin(BaseModel):
@@ -62,7 +52,9 @@ class Config:
6252
# Metadata base class
6353
class DataRowMetadataField(_CamelCaseMixin):
6454
schema_id: SchemaId
65-
value: Union[DataRowMetadataValue, _DataRowMetadataValuePrimitives]
55+
# value is of type `Any` so that we do not improperly coerce the value to the wrong tpye
56+
# Additional validation is performed before upload using the schema information
57+
value: Any
6658

6759

6860
class DataRowMetadata(_CamelCaseMixin):
@@ -241,10 +233,14 @@ def parse_metadata(
241233
elif schema.kind == DataRowMetadataKind.option:
242234
field = DataRowMetadataField(schema_id=schema.parent,
243235
value=schema.uid)
236+
elif schema.kind == DataRowMetadataKind.datetime:
237+
field = DataRowMetadataField(
238+
schema_id=schema.uid,
239+
value=datetime.fromisoformat(f["value"][:-1] +
240+
"+00:00"))
244241
else:
245242
field = DataRowMetadataField(schema_id=schema.uid,
246243
value=f["value"])
247-
248244
fields.append(field)
249245
parsed.append(
250246
DataRowMetadata(data_row_id=dr["dataRowId"], fields=fields))
@@ -300,10 +296,6 @@ def _batch_upsert(
300296

301297
items = []
302298
for m in metadata:
303-
if len(m.fields) > _MAX_METADATA_FIELDS:
304-
raise ValueError(
305-
f"Cannot upload {len(m.fields)}, the max number is {_MAX_METADATA_FIELDS}"
306-
)
307299
items.append(
308300
_UpsertBatchDataRowMetadata(
309301
data_row_id=m.data_row_id,
@@ -478,17 +470,39 @@ def _batch_operations(
478470
def _validate_parse_embedding(
479471
field: DataRowMetadataField
480472
) -> List[Dict[str, Union[SchemaId, Embedding]]]:
473+
474+
if isinstance(field.value, list):
475+
if not (Embedding.min_items <= len(field.value) <= Embedding.max_items):
476+
raise ValueError(
477+
"Embedding length invalid. "
478+
"Must have length within the interval "
479+
f"[{Embedding.min_items},{Embedding.max_items}]. Found {len(field.value)}"
480+
)
481+
field.value = [float(x) for x in field.value]
482+
else:
483+
raise ValueError(
484+
f"Expected a list for embedding. Found {type(field.value)}")
481485
return [field.dict(by_alias=True)]
482486

483487

484488
def _validate_parse_number(
485-
field: DataRowMetadataField
486-
) -> List[Dict[str, Union[SchemaId, Number]]]:
489+
field: DataRowMetadataField
490+
) -> List[Dict[str, Union[SchemaId, str, float, int]]]:
491+
field.value = float(field.value)
487492
return [field.dict(by_alias=True)]
488493

489494

490495
def _validate_parse_datetime(
491496
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
497+
if isinstance(field.value, str):
498+
if field.value.endswith("Z"):
499+
field.value = field.value[:-1]
500+
field.value = datetime.fromisoformat(field.value)
501+
elif not isinstance(field.value, datetime):
502+
raise TypeError(
503+
f"value for datetime fields must be either a string or datetime object. Found {type(field.value)}"
504+
)
505+
492506
return [{
493507
"schemaId": field.schema_id,
494508
"value": field.value.isoformat() + "Z", # needs to be UTC
@@ -497,6 +511,15 @@ def _validate_parse_datetime(
497511

498512
def _validate_parse_text(
499513
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
514+
if not isinstance(field.value, str):
515+
raise ValueError(
516+
f"Expected a string type for the text field. Found {type(field.value)}"
517+
)
518+
519+
if len(field.value) > String.max_length:
520+
raise ValueError(
521+
f"string fields cannot exceed {String.max_length} characters.")
522+
500523
return [field.dict(by_alias=True)]
501524

502525

0 commit comments

Comments
 (0)