+ import asyncio
+ import concurrent.futures
import dataclasses
import logging
import subprocess
import time
- import asyncio
- import concurrent.futures
- from workflow import snow
from datetime import datetime
from enum import Enum
from itertools import groupby
from more_itertools import peekable

from workflow import query as q, paths, rai, constants
- from workflow.exception import StepTimeOutException, CommandExecutionException
+ from workflow import snow
from workflow.common import EnvConfig, RaiConfig, Source, BatchConfig, Export, FileType, ContainerType, Container, \
    FileMetadata
+ from workflow.exception import StepTimeOutException, CommandExecutionException
from workflow.manager import ResourceManager
from workflow.utils import save_csv_output, format_duration, build_models, extract_date_range, build_relation_path, \
    get_common_model_relative_path, get_or_create_eventloop
@@ -327,10 +327,12 @@ def _parse_sources(step: dict, env_config: EnvConfig) -> List[Source]:

class LoadDataWorkflowStep(WorkflowStep):
    collapse_partitions_on_load: bool
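+     # when True, the load queries for all simple (CSV/JSON) sources are combined and run in a single transaction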
+     load_jointly: bool

-     def __init__(self, idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load):
+     def __init__(self, idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load, load_jointly):
        super().__init__(idt, name, type_value, state, timing, engine_size)
        self.collapse_partitions_on_load = collapse_partitions_on_load
+         self.load_jointly = load_jointly

    def _execute(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig):
        rai.execute_query(logger, rai_config, env_config, q.DELETE_REFRESHED_SOURCES_DATA, readonly=False)
@@ -349,20 +351,22 @@ def _execute(self, logger: logging.Logger, env_config: EnvConfig, rai_config: Ra
        else:
            simple_resources.append(src)

-         for src in simple_resources:
-             self._load_source(logger, env_config, rai_config, src)
+         self._load_simple_resources(logger, env_config, rai_config, simple_resources)

        if async_resources:
-             for src in async_resources:
-                 self._load_source(logger, env_config, rai_config, src)
-             self.await_pending(env_config, logger, missed_resources)
+             self._load_async_resources(logger, env_config, rai_config, async_resources)

-     def await_pending(self, env_config, logger, missed_resources):
+     def _load_async_resources(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig,
+                               async_resources) -> None:
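+         # kick off the data sync for every async (Snowflake) source first, then wait for all of them to finish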
+         for src in async_resources:
+             self._load_async_resource(logger, env_config, rai_config, src["resources"], src)
+         self._await_pending(env_config, logger, async_resources)
+
+     def _await_pending(self, env_config, logger, pending_resources):
        loop = get_or_create_eventloop()
        if loop.is_running():
            raise Exception('Waiting for resource would interrupt unexpected event loop - aborting to avoid confusion')
-         pending = [src for src in missed_resources if self._resource_is_async(src)]
-         pending_cos = [self._await_async_resource(logger, env_config, resource) for resource in pending]
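+         # schedule one awaiting coroutine per pending resource and run them concurrently to completion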
+         pending_cos = [self._await_async_resource(logger, env_config, resource) for resource in pending_resources]
        loop.run_until_complete(asyncio.gather(*pending_cos))

    async def _await_async_resource(self, logger: logging.Logger, env_config: EnvConfig, src):
@@ -371,69 +375,103 @@ async def _await_async_resource(self, logger: logging.Logger, env_config: EnvCon
        if ContainerType.SNOWFLAKE == container.type:
            await snow.await_data_sync(logger, config, src["resources"])

-     def _load_source(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, src):
-         source_name = src["source"]
-         if 'is_date_partitioned' in src and src['is_date_partitioned'] == 'Y':
-             logger.info(f"Loading source '{source_name}' partitioned by date")
-             if self.collapse_partitions_on_load:
-                 srcs = src["dates"]
-                 first_date = srcs[0]["date"]
-                 last_date = srcs[-1]["date"]
-
-                 logger.info(
-                     f"Loading '{source_name}' all date partitions simultaneously, range {first_date} to {last_date}")
-
-                 resources = []
-                 for d in srcs:
-                     resources += d["resources"]
-                 self._load_resource(logger, env_config, rai_config, resources, src)
-             else:
-                 logger.info(f"Loading '{source_name}' one date partition at a time")
-                 for d in src["dates"]:
-                     logger.info(f"Loading partition for date {d['date']}")
-
-                     for res in d["resources"]:
-                         self._load_resource(logger, env_config, rai_config, [res], src)
+     def _load_simple_resources(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig,
+                                simple_resources) -> None:
+         # prepare queries for simple resources
+         query_batches = []
+         for src in simple_resources:
+             # add all the items returned by `_get_data_load_query` to the `query_batches` list
+             query_batches.extend(self._get_data_load_query(logger, env_config, src))
+
+         # execute queries for simple resources; if `load_jointly` is set to True, execute all queries in one txn
+         if self.load_jointly:
+             logger.info("Loading all CSV/JSON(L) sources jointly")
+             query = ""
+             inputs = {}
+             for query_with_input in query_batches:
+                 query += query_with_input.query
+                 inputs.update(query_with_input.inputs)
+             rai.execute_query(logger, rai_config, env_config, query, inputs, readonly=False)
        else:
-             logger.info(f"Loading source '{source_name}' not partitioned by date")
-             if self.collapse_partitions_on_load:
-                 logger.info(f"Loading '{source_name}' all chunk partitions simultaneously")
-                 self._load_resource(logger, env_config, rai_config, src["resources"], src)
-             else:
-                 logger.info(f"Loading '{source_name}' one chunk partition at a time")
-                 for res in src["resources"]:
-                     self._load_resource(logger, env_config, rai_config, [res], src)
-
-     @staticmethod
-     def _resource_is_async(src):
-         return True if ContainerType.SNOWFLAKE == ContainerType.from_source(src) else False
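+             # otherwise run each source's load query in its own transaction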
+             for query_with_input in query_batches:
+                 rai.execute_query(logger, rai_config, env_config, query_with_input.query, query_with_input.inputs,
+                                   readonly=False)

-     @staticmethod
-     def _load_resource(logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, resources, src) -> None:
+     def _get_data_load_query(self, logger: logging.Logger, env_config: EnvConfig, src) -> list:
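+         # returns a list of query/inputs batches for the source; date-partitioned sources may produce several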
        try:
            container = env_config.get_container(src["container"])
            config = EnvConfig.get_config(container)
-             if ContainerType.LOCAL == container.type or ContainerType.AZURE == container.type:
-                 query_with_input = q.load_resources(logger, config, resources, src)
-                 rai.execute_query(logger, rai_config, env_config, query_with_input.query, query_with_input.inputs,
-                                   readonly=False)
-             elif ContainerType.SNOWFLAKE == container.type:
-                 snow.begin_data_sync(logger, config, rai_config, resources, src)
+             if 'is_date_partitioned' in src and src['is_date_partitioned'] == 'Y':
+                 return self._get_date_part_load_query(logger, config, src)
+             else:
+                 return self._get_simple_src_load_query(logger, config, src)
        except KeyError as e:
            logger.error(f"Unsupported file type: {src['file_type']}. Skip the source: {src}", e)
        except ValueError as e:
            logger.error(f"Unsupported source type. Skip the source: {src}", e)
+         return []  # return empty list if source is not supported
+
+     def _get_date_part_load_query(self, logger: logging.Logger, config, src):
+         source_name = src["source"]
+         logger.info(f"Loading source '{source_name}' partitioned by date")
+         if self.collapse_partitions_on_load:
+             srcs = src["dates"]
+             first_date = srcs[0]["date"]
+             last_date = srcs[-1]["date"]
+
+             logger.info(
+                 f"Loading '{source_name}' all date partitions simultaneously, range {first_date} to {last_date}")
+
+             resources = []
+             for d in srcs:
+                 resources += d["resources"]
+             return [q.load_resources(logger, config, resources, src)]
+         else:
+             logger.info(f"Loading '{source_name}' one date partition at a time")
+             batch = []
+             for d in src["dates"]:
+                 logger.info(f"Loading partition for date {d['date']}")
+
+                 for res in d["resources"]:
+                     batch.append(q.load_resources(logger, config, [res], src))
+             return batch
+
+     def _get_simple_src_load_query(self, logger: logging.Logger, config, src):
+         source_name = src["source"]
+         logger.info(f"Loading source '{source_name}' not partitioned by date")
+         if self.collapse_partitions_on_load:
+             logger.info(f"Loading '{source_name}' all chunk partitions simultaneously")
+             return [q.load_resources(logger, config, src["resources"], src)]
+         else:
+             logger.info(f"Loading '{source_name}' one chunk partition at a time")
+             batch = []
+             for res in src["resources"]:
+                 batch.append(q.load_resources(logger, config, [res], src))
+             return batch
+
+     @staticmethod
+     def _resource_is_async(src):
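+         # Snowflake sources are synced asynchronously via snow.begin_data_sync / await_data_sync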
+         return True if ContainerType.SNOWFLAKE == ContainerType.from_source(src) else False
+
+     @staticmethod
+     def _load_async_resource(logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, resources, src) -> \
+             None:
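+         # only starts the sync; completion is awaited later by `_await_pending`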
+         container = env_config.get_container(src["container"])
+         config = EnvConfig.get_config(container)
+         snow.begin_data_sync(logger, config, rai_config, resources, src)


class LoadDataWorkflowStepFactory(WorkflowStepFactory):

    def _required_params(self, config: WorkflowConfig) -> List[str]:
-         return [constants.COLLAPSE_PARTITIONS_ON_LOAD]
+         return [constants.COLLAPSE_PARTITIONS_ON_LOAD, constants.LOAD_DATA_JOINTLY]

    def _get_step(self, logger: logging.Logger, config: WorkflowConfig, idt, name, type_value, state, timing,
                  engine_size, step: dict) -> WorkflowStep:
        collapse_partitions_on_load = config.step_params[constants.COLLAPSE_PARTITIONS_ON_LOAD]
-         return LoadDataWorkflowStep(idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load)
+         load_jointly = config.step_params[constants.LOAD_DATA_JOINTLY]
+         return LoadDataWorkflowStep(idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load,
+                                     load_jointly)


class MaterializeWorkflowStep(WorkflowStep):