-import json
-import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-from typing import Iterable, List
+from typing import List
 
-from labelbox.exceptions import InvalidQueryError
-from labelbox.exceptions import InvalidAttributeError
-from labelbox.exceptions import MalformedQueryException
-from labelbox.orm.model import Entity
-from labelbox.orm.model import Field
-from labelbox.schema.embedding import EmbeddingVector
-from labelbox.pydantic_compat import BaseModel
-from labelbox.schema.internal.datarow_upload_constants import (
-    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT)
+from labelbox import pydantic_compat
 from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
+from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator
 
 
-class UploadManifest(BaseModel):
+class UploadManifest(pydantic_compat.BaseModel):
     source: str
     item_count: int
     chunk_uris: List[str]
 
 
-class DataRowUploader:
+SOURCE_SDK = "SDK"
 
-    @staticmethod
-    def create_descriptor_file(client,
-                               items,
-                               max_attachments_per_data_row=None,
-                               is_upsert=False):
-        """
-        This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`.
-        It is used to prepare the input file. The user defined input is validated, processed, and json stringified.
-        Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows
 
-        Each element in `items` can be either a `str` or a `dict`. If
-        it is a `str`, then it is interpreted as a local file path. The file
-        is uploaded to Labelbox and a DataRow referencing it is created.
+def upload_in_chunks(client, specs: List[DataRowUpsertItem],
+                     file_upload_thread_count: int,
+                     max_chunk_size_bytes: int) -> UploadManifest:
+    empty_specs = list(filter(lambda spec: spec.is_empty(), specs))
 
-        If an item is a `dict`, then it could support one of the two following structures
-            1. For static imagery, video, and text it should map `DataRow` field names to values.
-               At the minimum an `items` passed as a `dict` must contain a `row_data` key and value.
-               If the value for row_data is a local file path and the path exists,
-               then the local file will be uploaded to labelbox.
+    if empty_specs:
+        ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
+        raise ValueError(f"The following items have an empty payload: {ids}")
 
-            2. For tiled imagery the dict must match the import structure specified in the link below
-               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import
+    chunk_uris = DescriptorFileCreator(client).create(
+        specs, max_chunk_size_bytes=max_chunk_size_bytes)
 
-        >>> dataset.create_data_rows([
-        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
-        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
-        >>>     "path/to/file2.jpg",
-        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
-        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
-        >>>     ])
-
-        For an example showing how to upload tiled data_rows see the following notebook:
-            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb
-
-        Args:
-            items (iterable of (dict or str)): See above for details.
-            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
-                if the user has provided too many attachments.
-
-        Returns:
-            uri (string): A reference to the uploaded json data.
-
-        Raises:
-            InvalidQueryError: If the `items` parameter does not conform to
-                the specification above or if the server did not accept the
-                DataRow creation request (unknown reason).
-            InvalidAttributeError: If there are fields in `items` not valid for
-                a DataRow.
-            ValueError: When the upload parameters are invalid
-        """
-        file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT
-        DataRow = Entity.DataRow
-        AssetAttachment = Entity.AssetAttachment
-
-        def upload_if_necessary(item):
-            if is_upsert and 'row_data' not in item:
-                # When upserting, row_data is not required
-                return item
-            row_data = item['row_data']
-            if isinstance(row_data, str) and os.path.exists(row_data):
-                item_url = client.upload_file(row_data)
-                item['row_data'] = item_url
-                if 'external_id' not in item:
-                    # Default `external_id` to local file name
-                    item['external_id'] = row_data
-            return item
-
-        def validate_attachments(item):
-            attachments = item.get('attachments')
-            if attachments:
-                if isinstance(attachments, list):
-                    if max_attachments_per_data_row and len(
-                            attachments) > max_attachments_per_data_row:
-                        raise ValueError(
-                            f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}."
-                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
-                        )
-                    for attachment in attachments:
-                        AssetAttachment.validate_attachment_json(attachment)
-                else:
-                    raise ValueError(
-                        f"Attachments must be a list. Found {type(attachments)}"
-                    )
-            return attachments
-
-        def validate_embeddings(item):
-            embeddings = item.get("embeddings")
-            if embeddings:
-                item["embeddings"] = [
-                    EmbeddingVector(**e).to_gql() for e in embeddings
-                ]
-
-        def validate_conversational_data(conversational_data: list) -> None:
-            """
-            Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
-
-            Args:
-                conversational_data (list): list of dictionaries.
-            """
-
-            def check_message_keys(message):
-                accepted_message_keys = set([
-                    "messageId", "timestampUsec", "content", "user", "align",
-                    "canLabel"
-                ])
-                for key in message.keys():
-                    if not key in accepted_message_keys:
-                        raise KeyError(
-                            f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}"
-                        )
-
-            if conversational_data and not isinstance(conversational_data,
-                                                      list):
-                raise ValueError(
-                    f"conversationalData must be a list. Found {type(conversational_data)}"
-                )
-
-            [check_message_keys(message) for message in conversational_data]
-
-        def parse_metadata_fields(item):
-            metadata_fields = item.get('metadata_fields')
-            if metadata_fields:
-                mdo = client.get_data_row_metadata_ontology()
-                item['metadata_fields'] = mdo.parse_upsert_metadata(
-                    metadata_fields)
-
-        def format_row(item):
-            # Formats user input into a consistent dict structure
-            if isinstance(item, dict):
-                # Convert fields to strings
-                item = {
-                    key.name if isinstance(key, Field) else key: value
-                    for key, value in item.items()
-                }
-            elif isinstance(item, str):
-                # The main advantage of using a string over a dict is that the user is specifying
-                # that the file should exist locally.
-                # That info is lost after this section so we should check for it here.
-                if not os.path.exists(item):
-                    raise ValueError(f"Filepath {item} does not exist.")
-                item = {"row_data": item, "external_id": item}
-            return item
-
-        def validate_keys(item):
-            if not is_upsert and 'row_data' not in item:
-                raise InvalidQueryError(
-                    "`row_data` missing when creating DataRow.")
-
-            if isinstance(item.get('row_data'),
-                          str) and item.get('row_data').startswith("s3:/"):
-                raise InvalidQueryError(
-                    "row_data: s3 assets must start with 'https'.")
-            allowed_extra_fields = {
-                'attachments', 'media_type', 'dataset_id', 'embeddings'
-            }
-            invalid_keys = set(item) - {f.name for f in DataRow.fields()
-                                       } - allowed_extra_fields
-            if invalid_keys:
-                raise InvalidAttributeError(DataRow, invalid_keys)
-            return item
-
-        def format_legacy_conversational_data(item):
-            messages = item.pop("conversationalData")
-            version = item.pop("version", 1)
-            type = item.pop("type", "application/vnd.labelbox.conversational")
-            if "externalId" in item:
-                external_id = item.pop("externalId")
-                item["external_id"] = external_id
-            if "globalKey" in item:
-                global_key = item.pop("globalKey")
-                item["globalKey"] = global_key
-            validate_conversational_data(messages)
-            one_conversation = \
-                {
-                    "type": type,
-                    "version": version,
-                    "messages": messages
-                }
-            item["row_data"] = one_conversation
-            return item
-
-        def convert_item(data_row_item):
-            if isinstance(data_row_item, DataRowUpsertItem):
-                item = data_row_item.payload
-            else:
-                item = data_row_item
-
-            if "tileLayerUrl" in item:
-                validate_attachments(item)
-                return item
-
-            if "conversationalData" in item:
-                format_legacy_conversational_data(item)
-
-            # Convert all payload variations into the same dict format
-            item = format_row(item)
-            # Make sure required keys exist (and there are no extra keys)
-            validate_keys(item)
-            # Make sure attachments are valid
-            validate_attachments(item)
-            # Make sure embeddings are valid
-            validate_embeddings(item)
-            # Parse metadata fields if they exist
-            parse_metadata_fields(item)
-            # Upload any local file paths
-            item = upload_if_necessary(item)
-
-            if isinstance(data_row_item, DataRowUpsertItem):
-                return {'id': data_row_item.id, 'payload': item}
-            else:
-                return item
-
-        if not isinstance(items, Iterable):
-            raise ValueError(
-                f"Must pass an iterable to create_data_rows. Found {type(items)}"
-            )
-
-        if len(items) > MAX_DATAROW_PER_API_OPERATION:
-            raise MalformedQueryException(
-                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
-            )
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [executor.submit(convert_item, item) for item in items]
-            items = [future.result() for future in as_completed(futures)]
-        # Prepare and upload the descriptor file
-        data = json.dumps(items)
-        return client.upload_data(data,
-                                  content_type="application/json",
-                                  filename="json_import.json")
-
-    @staticmethod
-    def upload_in_chunks(client, specs: List[DataRowUpsertItem],
-                         file_upload_thread_count: int,
-                         upsert_chunk_size: int) -> UploadManifest:
-        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))
-
-        if empty_specs:
-            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
-            raise ValueError(
-                f"The following items have an empty payload: {ids}")
-
-        chunks = [
-            specs[i:i + upsert_chunk_size]
-            for i in range(0, len(specs), upsert_chunk_size)
-        ]
-
-        def _upload_chunk(chunk):
-            return DataRowUploader.create_descriptor_file(client,
-                                                          chunk,
-                                                          is_upsert=True)
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [
-                executor.submit(_upload_chunk, chunk) for chunk in chunks
-            ]
-            chunk_uris = [future.result() for future in as_completed(futures)]
-
-        return UploadManifest(source="SDK",
-                              item_count=len(specs),
-                              chunk_uris=chunk_uris)
+    return UploadManifest(source=SOURCE_SDK,
+                          item_count=len(specs),
+                          chunk_uris=chunk_uris)
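For context, a minimal sketch of how the new module-level upload_in_chunks entry point might be called. This is illustrative only: the module path, the Client setup, the DataRowUpsertItem id/payload shapes, and the thread-count and chunk-size values are assumptions, not part of this change.

# Hypothetical usage sketch -- names marked below are assumptions, not taken from this diff.
from labelbox import Client
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
from labelbox.schema.internal.data_row_uploader import upload_in_chunks  # assumed module path

client = Client(api_key="<API_KEY>")  # assumed client setup

specs = [
    # Placeholder id/payload values; real specs are built elsewhere in the SDK.
    DataRowUpsertItem(
        id={"type": "GlobalKey", "value": "example-global-key"},
        payload={"row_data": "https://example.com/image_01.jpg"},
    ),
]

manifest = upload_in_chunks(
    client,
    specs,
    file_upload_thread_count=20,            # assumed value
    max_chunk_size_bytes=10 * 1024 * 1024,  # assumed value
)
print(manifest.item_count, manifest.chunk_uris)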