11
11
from labelbox .schema .ontology import SchemaId
12
12
from labelbox .utils import camel_case
13
13
14
- _MAX_METADATA_FIELDS = 5
15
-
16
14
17
15
class DataRowMetadataKind (Enum ):
18
16
number = "CustomMetadataNumber"
@@ -40,16 +38,8 @@ def id(self):
40
38
41
39
DataRowMetadataSchema .update_forward_refs ()
42
40
43
- # Constraints for metadata values
44
41
Embedding : Type [List [float ]] = conlist (float , min_items = 128 , max_items = 128 )
45
- DateTime : Type [datetime ] = datetime # must be in UTC
46
42
String : Type [str ] = constr (max_length = 500 )
47
- OptionId : Type [SchemaId ] = SchemaId # enum option
48
- Number : Type [float ] = float
49
-
50
- DataRowMetadataValue = Union [Embedding , Number , DateTime , String , OptionId ]
51
- # primitives used in uploads
52
- _DataRowMetadataValuePrimitives = Union [str , List , dict , float ]
53
43
54
44
55
45
class _CamelCaseMixin (BaseModel ):
@@ -62,7 +52,9 @@ class Config:
62
52
# Metadata base class
63
53
class DataRowMetadataField (_CamelCaseMixin ):
64
54
schema_id : SchemaId
65
- value : Union [DataRowMetadataValue , _DataRowMetadataValuePrimitives ]
55
+ # value is of type `Any` so that we do not improperly coerce the value to the wrong tpye
56
+ # Additional validation is performed before upload using the schema information
57
+ value : Any
66
58
67
59
68
60
class DataRowMetadata (_CamelCaseMixin ):
@@ -241,10 +233,14 @@ def parse_metadata(
241
233
elif schema .kind == DataRowMetadataKind .option :
242
234
field = DataRowMetadataField (schema_id = schema .parent ,
243
235
value = schema .uid )
236
+ elif schema .kind == DataRowMetadataKind .datetime :
237
+ field = DataRowMetadataField (
238
+ schema_id = schema .uid ,
239
+ value = datetime .fromisoformat (f ["value" ][:- 1 ] +
240
+ "+00:00" ))
244
241
else :
245
242
field = DataRowMetadataField (schema_id = schema .uid ,
246
243
value = f ["value" ])
247
-
248
244
fields .append (field )
249
245
parsed .append (
250
246
DataRowMetadata (data_row_id = dr ["dataRowId" ], fields = fields ))
@@ -300,10 +296,6 @@ def _batch_upsert(
300
296
301
297
items = []
302
298
for m in metadata :
303
- if len (m .fields ) > _MAX_METADATA_FIELDS :
304
- raise ValueError (
305
- f"Cannot upload { len (m .fields )} , the max number is { _MAX_METADATA_FIELDS } "
306
- )
307
299
items .append (
308
300
_UpsertBatchDataRowMetadata (
309
301
data_row_id = m .data_row_id ,
@@ -478,17 +470,39 @@ def _batch_operations(
478
470
def _validate_parse_embedding (
479
471
field : DataRowMetadataField
480
472
) -> List [Dict [str , Union [SchemaId , Embedding ]]]:
473
+
474
+ if isinstance (field .value , list ):
475
+ if not (Embedding .min_items <= len (field .value ) <= Embedding .max_items ):
476
+ raise ValueError (
477
+ "Embedding length invalid. "
478
+ "Must have length within the interval "
479
+ f"[{ Embedding .min_items } ,{ Embedding .max_items } ]. Found { len (field .value )} "
480
+ )
481
+ field .value = [float (x ) for x in field .value ]
482
+ else :
483
+ raise ValueError (
484
+ f"Expected a list for embedding. Found { type (field .value )} " )
481
485
return [field .dict (by_alias = True )]
482
486
483
487
484
488
def _validate_parse_number (
485
- field : DataRowMetadataField
486
- ) -> List [Dict [str , Union [SchemaId , Number ]]]:
489
+ field : DataRowMetadataField
490
+ ) -> List [Dict [str , Union [SchemaId , str , float , int ]]]:
491
+ field .value = float (field .value )
487
492
return [field .dict (by_alias = True )]
488
493
489
494
490
495
def _validate_parse_datetime (
491
496
field : DataRowMetadataField ) -> List [Dict [str , Union [SchemaId , str ]]]:
497
+ if isinstance (field .value , str ):
498
+ if field .value .endswith ("Z" ):
499
+ field .value = field .value [:- 1 ]
500
+ field .value = datetime .fromisoformat (field .value )
501
+ elif not isinstance (field .value , datetime ):
502
+ raise TypeError (
503
+ f"value for datetime fields must be either a string or datetime object. Found { type (field .value )} "
504
+ )
505
+
492
506
return [{
493
507
"schemaId" : field .schema_id ,
494
508
"value" : field .value .isoformat () + "Z" , # needs to be UTC
@@ -497,6 +511,15 @@ def _validate_parse_datetime(
497
511
498
512
def _validate_parse_text (
499
513
field : DataRowMetadataField ) -> List [Dict [str , Union [SchemaId , str ]]]:
514
+ if not isinstance (field .value , str ):
515
+ raise ValueError (
516
+ f"Expected a string type for the text field. Found { type (field .value )} "
517
+ )
518
+
519
+ if len (field .value ) > String .max_length :
520
+ raise ValueError (
521
+ f"string fields cannot exceed { String .max_length } characters." )
522
+
500
523
return [field .dict (by_alias = True )]
501
524
502
525
0 commit comments