Commit 7c985db

[PLT-999] Vb/chunk by size plt 999 (#1648)
2 parents 01151e9 + 31ba730

File tree: 9 files changed, +429 -318 lines

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 8 additions & 10 deletions
@@ -33,9 +33,10 @@
 from labelbox.schema.user import User
 from labelbox.schema.iam_integration import IAMIntegration
 from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem)
-from labelbox.schema.internal.data_row_uploader import DataRowUploader
+import labelbox.schema.internal.data_row_uploader as data_row_uploader
+from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator
 from labelbox.schema.internal.datarow_upload_constants import (
-    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE)
+    FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES)

 logger = logging.getLogger(__name__)

@@ -53,7 +54,6 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
-    __upsert_chunk_size: Final = UPSERT_CHUNK_SIZE

     name = Field.String("name")
     description = Field.String("description")
@@ -240,10 +240,8 @@ def create_data_rows_sync(self, items) -> None:
                 f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
                 " For larger imports use the async function Dataset.create_data_rows()"
             )
-        descriptor_url = DataRowUploader.create_descriptor_file(
-            self.client,
-            items,
-            max_attachments_per_data_row=max_attachments_per_data_row)
+        descriptor_url = DescriptorFileCreator(self.client).create_one(
+            items, max_attachments_per_data_row=max_attachments_per_data_row)
         dataset_param = "datasetId"
         url_param = "jsonUrl"
         query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
@@ -264,7 +262,7 @@ def create_data_rows(self,
         Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows.

         Args:
-            items (iterable of (dict or str)): See the docstring for `DataRowUploader.create_descriptor_file` for more information
+            items (iterable of (dict or str))

         Returns:
             Task representing the data import on the server side. The Task
@@ -619,11 +617,11 @@ def _exec_upsert_data_rows(
             file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT
     ) -> "DataUpsertTask":

-        manifest = DataRowUploader.upload_in_chunks(
+        manifest = data_row_uploader.upload_in_chunks(
             client=self.client,
             specs=specs,
             file_upload_thread_count=file_upload_thread_count,
-            upsert_chunk_size=UPSERT_CHUNK_SIZE)
+            max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES)

         data = json.dumps(manifest.dict()).encode("utf-8")
         manifest_uri = self.client.upload_data(data,
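
Note: the `create_data_rows` docstring hunk above now lists only the accepted types for `items`; the longer description it used to point at lived in the `DataRowUploader.create_descriptor_file` docstring deleted in the next file. For reference, a usage sketch of the item shapes that docstring described (illustrative only; `my_dataset` stands in for an existing `labelbox.Dataset`, and the URLs/paths are placeholders):

# Sketch reconstructed from the docstring removed in this commit, not part of the diff.
task = my_dataset.create_data_rows([
    {"row_data": "http://my_site.com/photos/img_01.jpg"},  # hosted asset URL
    {"row_data": "/path/to/file1.jpg"},                    # local file, uploaded to Labelbox first
    "path/to/file2.jpg",                                   # bare string is treated as a local file path
])  # returns a Task representing the server-side import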
libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py

Lines changed: 17 additions & 271 deletions
@@ -1,287 +1,33 @@
-import json
-import os
 from concurrent.futures import ThreadPoolExecutor, as_completed

-from typing import Iterable, List
+from typing import List

-from labelbox.exceptions import InvalidQueryError
-from labelbox.exceptions import InvalidAttributeError
-from labelbox.exceptions import MalformedQueryException
-from labelbox.orm.model import Entity
-from labelbox.orm.model import Field
-from labelbox.schema.embedding import EmbeddingVector
-from labelbox.pydantic_compat import BaseModel
-from labelbox.schema.internal.datarow_upload_constants import (
-    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT)
+from labelbox import pydantic_compat
 from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
+from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator


-class UploadManifest(BaseModel):
+class UploadManifest(pydantic_compat.BaseModel):
     source: str
     item_count: int
     chunk_uris: List[str]


-class DataRowUploader:
+SOURCE_SDK = "SDK"

-    @staticmethod
-    def create_descriptor_file(client,
-                               items,
-                               max_attachments_per_data_row=None,
-                               is_upsert=False):
-        """
-        This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`.
-        It is used to prepare the input file. The user defined input is validated, processed, and json stringified.
-        Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows

-        Each element in `items` can be either a `str` or a `dict`. If
-        it is a `str`, then it is interpreted as a local file path. The file
-        is uploaded to Labelbox and a DataRow referencing it is created.
+def upload_in_chunks(client, specs: List[DataRowUpsertItem],
+                     file_upload_thread_count: int,
+                     max_chunk_size_bytes: int) -> UploadManifest:
+    empty_specs = list(filter(lambda spec: spec.is_empty(), specs))

-        If an item is a `dict`, then it could support one of the two following structures
-            1. For static imagery, video, and text it should map `DataRow` field names to values.
-               At the minimum an `items` passed as a `dict` must contain a `row_data` key and value.
-               If the value for row_data is a local file path and the path exists,
-               then the local file will be uploaded to labelbox.
+    if empty_specs:
+        ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
+        raise ValueError(f"The following items have an empty payload: {ids}")

-            2. For tiled imagery the dict must match the import structure specified in the link below
-               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import
+    chunk_uris = DescriptorFileCreator(client).create(
+        specs, max_chunk_size_bytes=max_chunk_size_bytes)

-        >>> dataset.create_data_rows([
-        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
-        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
-        >>>     "path/to/file2.jpg",
-        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
-        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
-        >>>     ])
-
-        For an example showing how to upload tiled data_rows see the following notebook:
-            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb
-
-        Args:
-            items (iterable of (dict or str)): See above for details.
-            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
-                if the user has provided too many attachments.
-
-        Returns:
-            uri (string): A reference to the uploaded json data.
-
-        Raises:
-            InvalidQueryError: If the `items` parameter does not conform to
-                the specification above or if the server did not accept the
-                DataRow creation request (unknown reason).
-            InvalidAttributeError: If there are fields in `items` not valid for
-                a DataRow.
-            ValueError: When the upload parameters are invalid
-        """
-        file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT
-        DataRow = Entity.DataRow
-        AssetAttachment = Entity.AssetAttachment
-
-        def upload_if_necessary(item):
-            if is_upsert and 'row_data' not in item:
-                # When upserting, row_data is not required
-                return item
-            row_data = item['row_data']
-            if isinstance(row_data, str) and os.path.exists(row_data):
-                item_url = client.upload_file(row_data)
-                item['row_data'] = item_url
-                if 'external_id' not in item:
-                    # Default `external_id` to local file name
-                    item['external_id'] = row_data
-            return item
-
-        def validate_attachments(item):
-            attachments = item.get('attachments')
-            if attachments:
-                if isinstance(attachments, list):
-                    if max_attachments_per_data_row and len(
-                            attachments) > max_attachments_per_data_row:
-                        raise ValueError(
-                            f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}."
-                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
-                        )
-                    for attachment in attachments:
-                        AssetAttachment.validate_attachment_json(attachment)
-                else:
-                    raise ValueError(
-                        f"Attachments must be a list. Found {type(attachments)}"
-                    )
-            return attachments
-
-        def validate_embeddings(item):
-            embeddings = item.get("embeddings")
-            if embeddings:
-                item["embeddings"] = [
-                    EmbeddingVector(**e).to_gql() for e in embeddings
-                ]
-
-        def validate_conversational_data(conversational_data: list) -> None:
-            """
-            Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
-
-            Args:
-                conversational_data (list): list of dictionaries.
-            """
-
-            def check_message_keys(message):
-                accepted_message_keys = set([
-                    "messageId", "timestampUsec", "content", "user", "align",
-                    "canLabel"
-                ])
-                for key in message.keys():
-                    if not key in accepted_message_keys:
-                        raise KeyError(
-                            f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}"
-                        )
-
-            if conversational_data and not isinstance(conversational_data,
-                                                      list):
-                raise ValueError(
-                    f"conversationalData must be a list. Found {type(conversational_data)}"
-                )
-
-            [check_message_keys(message) for message in conversational_data]
-
-        def parse_metadata_fields(item):
-            metadata_fields = item.get('metadata_fields')
-            if metadata_fields:
-                mdo = client.get_data_row_metadata_ontology()
-                item['metadata_fields'] = mdo.parse_upsert_metadata(
-                    metadata_fields)
-
-        def format_row(item):
-            # Formats user input into a consistent dict structure
-            if isinstance(item, dict):
-                # Convert fields to strings
-                item = {
-                    key.name if isinstance(key, Field) else key: value
-                    for key, value in item.items()
-                }
-            elif isinstance(item, str):
-                # The main advantage of using a string over a dict is that the user is specifying
-                # that the file should exist locally.
-                # That info is lost after this section so we should check for it here.
-                if not os.path.exists(item):
-                    raise ValueError(f"Filepath {item} does not exist.")
-                item = {"row_data": item, "external_id": item}
-            return item
-
-        def validate_keys(item):
-            if not is_upsert and 'row_data' not in item:
-                raise InvalidQueryError(
-                    "`row_data` missing when creating DataRow.")
-
-            if isinstance(item.get('row_data'),
-                          str) and item.get('row_data').startswith("s3:/"):
-                raise InvalidQueryError(
-                    "row_data: s3 assets must start with 'https'.")
-            allowed_extra_fields = {
-                'attachments', 'media_type', 'dataset_id', 'embeddings'
-            }
-            invalid_keys = set(item) - {f.name for f in DataRow.fields()
-                                       } - allowed_extra_fields
-            if invalid_keys:
-                raise InvalidAttributeError(DataRow, invalid_keys)
-            return item
-
-        def format_legacy_conversational_data(item):
-            messages = item.pop("conversationalData")
-            version = item.pop("version", 1)
-            type = item.pop("type", "application/vnd.labelbox.conversational")
-            if "externalId" in item:
-                external_id = item.pop("externalId")
-                item["external_id"] = external_id
-            if "globalKey" in item:
-                global_key = item.pop("globalKey")
-                item["globalKey"] = global_key
-            validate_conversational_data(messages)
-            one_conversation = \
-                {
-                    "type": type,
-                    "version": version,
-                    "messages": messages
-                }
-            item["row_data"] = one_conversation
-            return item
-
-        def convert_item(data_row_item):
-            if isinstance(data_row_item, DataRowUpsertItem):
-                item = data_row_item.payload
-            else:
-                item = data_row_item
-
-            if "tileLayerUrl" in item:
-                validate_attachments(item)
-                return item
-
-            if "conversationalData" in item:
-                format_legacy_conversational_data(item)
-
-            # Convert all payload variations into the same dict format
-            item = format_row(item)
-            # Make sure required keys exist (and there are no extra keys)
-            validate_keys(item)
-            # Make sure attachments are valid
-            validate_attachments(item)
-            # Make sure embeddings are valid
-            validate_embeddings(item)
-            # Parse metadata fields if they exist
-            parse_metadata_fields(item)
-            # Upload any local file paths
-            item = upload_if_necessary(item)
-
-            if isinstance(data_row_item, DataRowUpsertItem):
-                return {'id': data_row_item.id, 'payload': item}
-            else:
-                return item
-
-        if not isinstance(items, Iterable):
-            raise ValueError(
-                f"Must pass an iterable to create_data_rows. Found {type(items)}"
-            )
-
-        if len(items) > MAX_DATAROW_PER_API_OPERATION:
-            raise MalformedQueryException(
-                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
-            )
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [executor.submit(convert_item, item) for item in items]
-            items = [future.result() for future in as_completed(futures)]
-        # Prepare and upload the desciptor file
-        data = json.dumps(items)
-        return client.upload_data(data,
-                                  content_type="application/json",
-                                  filename="json_import.json")
-
-    @staticmethod
-    def upload_in_chunks(client, specs: List[DataRowUpsertItem],
-                         file_upload_thread_count: int,
-                         upsert_chunk_size: int) -> UploadManifest:
-        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))
-
-        if empty_specs:
-            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
-            raise ValueError(
-                f"The following items have an empty payload: {ids}")
-
-        chunks = [
-            specs[i:i + upsert_chunk_size]
-            for i in range(0, len(specs), upsert_chunk_size)
-        ]
-
-        def _upload_chunk(chunk):
-            return DataRowUploader.create_descriptor_file(client,
-                                                          chunk,
-                                                          is_upsert=True)
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [
-                executor.submit(_upload_chunk, chunk) for chunk in chunks
-            ]
-            chunk_uris = [future.result() for future in as_completed(futures)]
-
-        return UploadManifest(source="SDK",
-                              item_count=len(specs),
-                              chunk_uris=chunk_uris)
+    return UploadManifest(source=SOURCE_SDK,
+                          item_count=len(specs),
+                          chunk_uris=chunk_uris)
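
The new module-level `upload_in_chunks` no longer slices `specs` into fixed-count chunks; it hands the whole list to `DescriptorFileCreator(client).create(specs, max_chunk_size_bytes=...)`, whose implementation is not part of this diff. As a rough, standalone sketch of what size-based chunking can look like (an assumption about the approach, not the `DescriptorFileCreator` code; the helper name `chunk_by_size` and the sizes are hypothetical):

import json
from typing import Any, Dict, List

def chunk_by_size(items: List[Dict[str, Any]],
                  max_chunk_size_bytes: int) -> List[List[Dict[str, Any]]]:
    # Group items so each chunk's JSON payload stays under a byte budget.
    chunks: List[List[Dict[str, Any]]] = [[]]
    current_size = 0
    for item in items:
        item_size = len(json.dumps(item).encode("utf-8"))
        # Start a new chunk when the next item would exceed the budget;
        # a chunk always keeps at least one item, even an oversized one.
        if chunks[-1] and current_size + item_size > max_chunk_size_bytes:
            chunks.append([])
            current_size = 0
        chunks[-1].append(item)
        current_size += item_size
    return chunks

# ~1 KB payloads with a 10 KB budget -> roughly 9-10 items per chunk.
specs = [{"id": i, "row_data": "x" * 1000} for i in range(25)]
print([len(c) for c in chunk_by_size(specs, max_chunk_size_bytes=10_000)])

With a byte budget, chunk boundaries track actual payload size, so rows with heavy metadata or embeddings yield smaller chunks instead of oversized descriptor files.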
libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,3 @@
-MAX_DATAROW_PER_API_OPERATION = 150_000
 FILE_UPLOAD_THREAD_COUNT = 20
-UPSERT_CHUNK_SIZE = 10_000
+UPSERT_CHUNK_SIZE_BYTES = 10_000_000
 DOWNLOAD_RESULT_PAGE_SIZE = 5_000
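
The chunking budget is now expressed in bytes (10_000_000, roughly 10 MB per descriptor file) rather than as a fixed 10,000-row count, so the number of rows per chunk scales with how large each serialized spec is. A back-of-the-envelope sketch (the 2 KB average spec size is an assumed figure, not something measured in this commit):

UPSERT_CHUNK_SIZE_BYTES = 10_000_000

avg_spec_bytes = 2_000  # hypothetical average serialized size of one upsert spec
rows_per_chunk = UPSERT_CHUNK_SIZE_BYTES // avg_spec_bytes
print(rows_per_chunk)   # 5000 -> specs of ~2 KB pack about 5,000 rows per chunk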
