Commit 46f8588

Authored Mar 21, 2024
adding joint data loading flag (#79)
1 parent e3838c8 · commit 46f8588

9 files changed: +138 -77 lines
 
cli-e2e-test/rel/json_schema_mapping.rel (+4)

@@ -9,12 +9,16 @@
 
 def JsonSourceConfigKey(k) { json_source_config(k, x...) from x... }
 
+@function
 def json_source:value[src in JsonSourceConfigKey] = source_catalog[src, _, :value]
+@function
 def json_source:child[src in JsonSourceConfigKey] = source_catalog[src, _, :child]
 def json_source:root[src in JsonSourceConfigKey] = source_catalog[src, _, :root]
 def json_source:array[src in JsonSourceConfigKey] = source_catalog[src, _, :array]
 
+@function
 def json_source:value[src in JsonSourceConfigKey] = simple_source_catalog[src, :value]
+@function
 def json_source:child[src in JsonSourceConfigKey] = simple_source_catalog[src, :child]
 def json_source:root[src in JsonSourceConfigKey] = simple_source_catalog[src, :root]
 def json_source:array[src in JsonSourceConfigKey] = simple_source_catalog[src, :array]

cli-e2e-test/test_e2e.py (+11 -1)

@@ -40,6 +40,16 @@ def test_scenario1_model(self):
         self.assertNotEqual(rsp, 1)
         self.assert_output_dir_files(self.test_scenario1_model.__name__)
 
+    def test_scenario1_load_data_jointly(self):
+        # when
+        test_args = ["--batch-config", "./config/model/scenario1.json",
+                     "--end-date", "20220105",
+                     "--drop-db", "--load-data-jointly"]
+        rsp = call(self.cmd_with_common_arguments + test_args)
+        # then
+        self.assertNotEqual(rsp, 1)
+        self.assert_output_dir_files(self.test_scenario1_model.__name__)
+
     def test_scenario1_model_yaml(self):
         # when
         test_args = ["--batch-config", "./config/model/scenario1.yaml",
@@ -122,7 +132,7 @@ def test_scenario3_model_single_partition_change_for_date_partitioned(self):
         rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, self.env_config, RESOURCES_TO_DELETE_REL)
         self.assertEqual(rsp_json, [{'partition': 2023090800001, 'relation': 'city_data'}])
 
-    def test_scenario3_model_two_partitions_overriden_by_one_for_date_partitioned(self):
+    def test_scenario3_model_two_partitions_overridden_by_one_for_date_partitioned(self):
         # when
         test_args = ["--batch-config", "./config/model/scenario3.json",
                      "--start-date", "20230908",

cli/args.py (+6)

@@ -81,6 +81,12 @@ def parse() -> Namespace:
         action=BooleanOptionalAction,
         default=True
     )
+    parser.add_argument(
+        "--load-data-jointly",
+        help="When loading data, load all sources and partitions in one transaction",
+        action=BooleanOptionalAction,
+        default=False
+    )
     parser.add_argument(
         "--log-level",
         help="Set log level",
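
Note on the new flag: because it uses argparse's BooleanOptionalAction, a matching --no-load-data-jointly negation is generated automatically and the default keeps the existing one-query-per-source behaviour. A minimal, self-contained sketch (not repo code) of how the option parses:

    from argparse import ArgumentParser, BooleanOptionalAction

    parser = ArgumentParser()
    parser.add_argument(
        "--load-data-jointly",
        help="When loading data, load all sources and partitions in one transaction",
        action=BooleanOptionalAction,
        default=False,
    )

    print(parser.parse_args([]).load_data_jointly)                          # False (default)
    print(parser.parse_args(["--load-data-jointly"]).load_data_jointly)     # True
    print(parser.parse_args(["--no-load-data-jointly"]).load_data_jointly)  # False (explicit opt-out)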

cli/runner.py (+2 -1)

@@ -51,7 +51,8 @@ def start(factories: dict[str, workflow.executor.WorkflowStepFactory] = MappingP
         workflow.constants.END_DATE: args.end_date,
         workflow.constants.FORCE_REIMPORT: args.force_reimport,
         workflow.constants.FORCE_REIMPORT_NOT_CHUNK_PARTITIONED: args.force_reimport_not_chunk_partitioned,
-        workflow.constants.COLLAPSE_PARTITIONS_ON_LOAD: args.collapse_partitions_on_load
+        workflow.constants.COLLAPSE_PARTITIONS_ON_LOAD: args.collapse_partitions_on_load,
+        workflow.constants.LOAD_DATA_JOINTLY: args.load_data_jointly
     }
     config = workflow.executor.WorkflowConfig(env_config, workflow.common.BatchConfig(args.batch_config_name,
                                                                                       batch_config_json),

test/test_cfg_src_step_factory.py (+2 -1)

@@ -76,7 +76,8 @@ def _create_wf_cfg(env_config: EnvConfig, batch_config: BatchConfig) -> Workflow
         constants.END_DATE: "2021-01-01",
         constants.FORCE_REIMPORT: False,
         constants.FORCE_REIMPORT_NOT_CHUNK_PARTITIONED: False,
-        constants.COLLAPSE_PARTITIONS_ON_LOAD: False
+        constants.COLLAPSE_PARTITIONS_ON_LOAD: False,
+        constants.LOAD_DATA_JOINTLY: False
     }
     return WorkflowConfig(
         env=env_config,

workflow/__init__.py (+1 -1)

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version_info__ = (0, 0, 44)
+__version_info__ = (0, 0, 45)
 __version__ = ".".join(map(str, __version_info__))

workflow/constants.py (+1)

@@ -80,6 +80,7 @@
 FORCE_REIMPORT = "force_reimport"
 FORCE_REIMPORT_NOT_CHUNK_PARTITIONED = "force_reimport_not_chunk_partitioned"
 COLLAPSE_PARTITIONS_ON_LOAD = "collapse_partitions_on_load"
+LOAD_DATA_JOINTLY = "load_data_jointly"
 
 # Snowflake constants

workflow/executor.py (+96 -58)

@@ -1,10 +1,9 @@
+import asyncio
+import concurrent.futures
 import dataclasses
 import logging
 import subprocess
 import time
-import asyncio
-import concurrent.futures
-from workflow import snow
 from datetime import datetime
 from enum import Enum
 from itertools import groupby
@@ -14,9 +13,10 @@
 from more_itertools import peekable
 
 from workflow import query as q, paths, rai, constants
-from workflow.exception import StepTimeOutException, CommandExecutionException
+from workflow import snow
 from workflow.common import EnvConfig, RaiConfig, Source, BatchConfig, Export, FileType, ContainerType, Container, \
     FileMetadata
+from workflow.exception import StepTimeOutException, CommandExecutionException
 from workflow.manager import ResourceManager
 from workflow.utils import save_csv_output, format_duration, build_models, extract_date_range, build_relation_path, \
     get_common_model_relative_path, get_or_create_eventloop
@@ -327,10 +327,12 @@ def _parse_sources(step: dict, env_config: EnvConfig) -> List[Source]:
 
 class LoadDataWorkflowStep(WorkflowStep):
     collapse_partitions_on_load: bool
+    load_jointly: bool
 
-    def __init__(self, idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load):
+    def __init__(self, idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load, load_jointly):
         super().__init__(idt, name, type_value, state, timing, engine_size)
         self.collapse_partitions_on_load = collapse_partitions_on_load
+        self.load_jointly = load_jointly
 
     def _execute(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig):
         rai.execute_query(logger, rai_config, env_config, q.DELETE_REFRESHED_SOURCES_DATA, readonly=False)
@@ -349,20 +351,22 @@ def _execute(self, logger: logging.Logger, env_config: EnvConfig, rai_config: Ra
         else:
             simple_resources.append(src)
 
-        for src in simple_resources:
-            self._load_source(logger, env_config, rai_config, src)
+        self._load_simple_resources(logger, env_config, rai_config, simple_resources)
 
         if async_resources:
-            for src in async_resources:
-                self._load_source(logger, env_config, rai_config, src)
-            self.await_pending(env_config, logger, missed_resources)
+            self._load_async_resources(logger, env_config, rai_config, async_resources)
 
-    def await_pending(self, env_config, logger, missed_resources):
+    def _load_async_resources(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig,
+                              async_resources) -> None:
+        for src in async_resources:
+            self._load_async_resource(logger, env_config, rai_config, src)
+        self._await_pending(env_config, logger, async_resources)
+
+    def _await_pending(self, env_config, logger, pending_resources):
         loop = get_or_create_eventloop()
         if loop.is_running():
             raise Exception('Waiting for resource would interrupt unexpected event loop - aborting to avoid confusion')
-        pending = [src for src in missed_resources if self._resource_is_async(src)]
-        pending_cos = [self._await_async_resource(logger, env_config, resource) for resource in pending]
+        pending_cos = [self._await_async_resource(logger, env_config, resource) for resource in pending_resources]
         loop.run_until_complete(asyncio.gather(*pending_cos))
 
     async def _await_async_resource(self, logger: logging.Logger, env_config: EnvConfig, src):
@@ -371,69 +375,103 @@ async def _await_async_resource(self, logger: logging.Logger, env_config: EnvCon
         if ContainerType.SNOWFLAKE == container.type:
             await snow.await_data_sync(logger, config, src["resources"])
 
-    def _load_source(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, src):
-        source_name = src["source"]
-        if 'is_date_partitioned' in src and src['is_date_partitioned'] == 'Y':
-            logger.info(f"Loading source '{source_name}' partitioned by date")
-            if self.collapse_partitions_on_load:
-                srcs = src["dates"]
-                first_date = srcs[0]["date"]
-                last_date = srcs[-1]["date"]
-
-                logger.info(
-                    f"Loading '{source_name}' all date partitions simultaneously, range {first_date} to {last_date}")
-
-                resources = []
-                for d in srcs:
-                    resources += d["resources"]
-                self._load_resource(logger, env_config, rai_config, resources, src)
-            else:
-                logger.info(f"Loading '{source_name}' one date partition at a time")
-                for d in src["dates"]:
-                    logger.info(f"Loading partition for date {d['date']}")
-
-                    for res in d["resources"]:
-                        self._load_resource(logger, env_config, rai_config, [res], src)
+    def _load_simple_resources(self, logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig,
+                               simple_resources) -> None:
+        # prepare queries for simple resources
+        query_batches = []
+        for src in simple_resources:
+            # add all the items returned by `_get_data_load_query` to the `query_batches` list
+            query_batches.extend(self._get_data_load_query(logger, env_config, src))
+
+        # execute queries for simple resources; if `load_jointly` is set to True then execute all queries in one txn
+        if self.load_jointly:
+            logger.info("Loading all CSV/JSON(L) sources jointly")
+            query = ""
+            inputs = {}
+            for query_with_input in query_batches:
+                query += query_with_input.query
+                inputs.update(query_with_input.inputs)
+            rai.execute_query(logger, rai_config, env_config, query, inputs, readonly=False)
         else:
-            logger.info(f"Loading source '{source_name}' not partitioned by date")
-            if self.collapse_partitions_on_load:
-                logger.info(f"Loading '{source_name}' all chunk partitions simultaneously")
-                self._load_resource(logger, env_config, rai_config, src["resources"], src)
-            else:
-                logger.info(f"Loading '{source_name}' one chunk partition at a time")
-                for res in src["resources"]:
-                    self._load_resource(logger, env_config, rai_config, [res], src)
-
-    @staticmethod
-    def _resource_is_async(src):
-        return True if ContainerType.SNOWFLAKE == ContainerType.from_source(src) else False
+            for query_with_input in query_batches:
+                rai.execute_query(logger, rai_config, env_config, query_with_input.query, query_with_input.inputs,
+                                  readonly=False)
 
-    @staticmethod
-    def _load_resource(logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, resources, src) -> None:
+    def _get_data_load_query(self, logger: logging.Logger, env_config: EnvConfig, src) -> list:
         try:
             container = env_config.get_container(src["container"])
             config = EnvConfig.get_config(container)
-            if ContainerType.LOCAL == container.type or ContainerType.AZURE == container.type:
-                query_with_input = q.load_resources(logger, config, resources, src)
-                rai.execute_query(logger, rai_config, env_config, query_with_input.query, query_with_input.inputs,
-                                  readonly=False)
-            elif ContainerType.SNOWFLAKE == container.type:
-                snow.begin_data_sync(logger, config, rai_config, resources, src)
+            if 'is_date_partitioned' in src and src['is_date_partitioned'] == 'Y':
+                return self._get_date_part_load_query(logger, config, src)
+            else:
+                return self._get_simple_src_load_query(logger, config, src)
         except KeyError as e:
             logger.error(f"Unsupported file type: {src['file_type']}. Skip the source: {src}", e)
         except ValueError as e:
             logger.error(f"Unsupported source type. Skip the source: {src}", e)
+        return []  # return empty list if source is not supported
+
+    def _get_date_part_load_query(self, logger: logging.Logger, config, src):
+        source_name = src["source"]
+        logger.info(f"Loading source '{source_name}' partitioned by date")
+        if self.collapse_partitions_on_load:
+            srcs = src["dates"]
+            first_date = srcs[0]["date"]
+            last_date = srcs[-1]["date"]
+
+            logger.info(
+                f"Loading '{source_name}' all date partitions simultaneously, range {first_date} to {last_date}")
+
+            resources = []
+            for d in srcs:
+                resources += d["resources"]
+            return [q.load_resources(logger, config, resources, src)]
+        else:
+            logger.info(f"Loading '{source_name}' one date partition at a time")
+            batch = []
+            for d in src["dates"]:
+                logger.info(f"Loading partition for date {d['date']}")
+
+                for res in d["resources"]:
+                    batch.append(q.load_resources(logger, config, [res], src))
+            return batch
+
+    def _get_simple_src_load_query(self, logger: logging.Logger, config, src):
+        source_name = src["source"]
+        logger.info(f"Loading source '{source_name}' not partitioned by date")
+        if self.collapse_partitions_on_load:
+            logger.info(f"Loading '{source_name}' all chunk partitions simultaneously")
+            return [q.load_resources(logger, config, src["resources"], src)]
+        else:
+            logger.info(f"Loading '{source_name}' one chunk partition at a time")
+            batch = []
+            for res in src["resources"]:
+                batch.append(q.load_resources(logger, config, [res], src))
            return batch
+
+    @staticmethod
+    def _resource_is_async(src):
+        return True if ContainerType.SNOWFLAKE == ContainerType.from_source(src) else False
+
+    @staticmethod
+    def _load_async_resource(logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, resources, src) -> \
+            None:
+        container = env_config.get_container(src["container"])
+        config = EnvConfig.get_config(container)
+        snow.begin_data_sync(logger, config, rai_config, resources, src)
 
 
 class LoadDataWorkflowStepFactory(WorkflowStepFactory):
 
     def _required_params(self, config: WorkflowConfig) -> List[str]:
-        return [constants.COLLAPSE_PARTITIONS_ON_LOAD]
+        return [constants.COLLAPSE_PARTITIONS_ON_LOAD, constants.LOAD_DATA_JOINTLY]
 
     def _get_step(self, logger: logging.Logger, config: WorkflowConfig, idt, name, type_value, state, timing,
                   engine_size, step: dict) -> WorkflowStep:
         collapse_partitions_on_load = config.step_params[constants.COLLAPSE_PARTITIONS_ON_LOAD]
-        return LoadDataWorkflowStep(idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load)
+        load_jointly = config.step_params[constants.LOAD_DATA_JOINTLY]
+        return LoadDataWorkflowStep(idt, name, type_value, state, timing, engine_size, collapse_partitions_on_load,
+                                    load_jointly)
 
 
 class MaterializeWorkflowStep(WorkflowStep):
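
The heart of the change is _load_simple_resources above: every source still produces its own QueryWithInputs batch, but with load_jointly enabled the batches are concatenated and sent to rai.execute_query as a single transaction instead of one call per batch. A standalone sketch of that merge step, with QueryWithInputs stubbed out for illustration (the real class lives in workflow/query.py):

    from dataclasses import dataclass, field

    @dataclass
    class QueryWithInputs:  # stand-in for workflow.query.QueryWithInputs
        query: str
        inputs: dict = field(default_factory=dict)

    def merge_batches(query_batches: list[QueryWithInputs]) -> QueryWithInputs:
        # Concatenate per-source queries and union their inputs so the whole
        # set can be executed in one transaction.
        query, inputs = "", {}
        for qwi in query_batches:
            query += qwi.query
            inputs.update(qwi.inputs)
        return QueryWithInputs(query, inputs)

    batches = [
        QueryWithInputs("def insert:source_a = ...\n", {"a.csv": "..."}),
        QueryWithInputs("def insert:source_b = ...\n", {"b.csv": "..."}),
    ]
    joint = merge_batches(batches)
    print(joint.query)   # both load queries, executed together when load_jointly is True
    print(joint.inputs)  # {'a.csv': '...', 'b.csv': '...'}

This is also why a trailing "\n" is appended to the generated multipart queries in workflow/query.py below: concatenated queries must end with a newline so they do not run together.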

workflow/query.py (+15 -15)

@@ -307,10 +307,10 @@ def _local_load_multipart_query(rel_name: str, file_type: FileType, parts) -> Qu
     load_config = _multi_part_load_config_query(rel_name, file_type,
                                                 _local_multipart_config_integration(raw_data_rel_name))
 
-    query = f"{_part_index_relation(part_indexes)}\n" \
+    query = f"{_part_index_relation(rel_name, part_indexes)}\n" \
             f"{raw_text}\n" \
             f"{load_config}\n" \
-            f"{insert_text}"
+            f"{insert_text}\n"
 
     return QueryWithInputs(query, inputs)
 
@@ -330,10 +330,10 @@ def _azure_load_multipart_query(rel_name: str, file_type: FileType, parts, confi
     load_config = _multi_part_load_config_query(rel_name, file_type,
                                                 _azure_multipart_config_integration(path_rel_name, config))
 
-    return f"{_part_index_relation(part_indexes)}\n" \
+    return f"{_part_index_relation(rel_name, part_indexes)}\n" \
            f"{_path_rel_name_relation(path_rel_name, part_uri_map)}\n" \
            f"{load_config}\n" \
-           f"{insert_text}"
+           f"{insert_text}\n"
 
 
 def _multi_part_load_config_query(rel_name: str, file_type: FileType, config_integration: str) -> str:
@@ -345,7 +345,7 @@ def _multi_part_load_config_query(rel_name: str, file_type: FileType, config_int
     return f"""
 bound {IMPORT_CONFIG_REL}:{rel_name}:schema
 bound {IMPORT_CONFIG_REL}:{rel_name}:syntax:header
-module {_config_rel_name(rel_name)}[i in part_indexes]
+module {_config_rel_name(rel_name)}[i in part_indexes:{rel_name}]
     {schema}
     {config_integration}
 end
@@ -383,25 +383,25 @@ def _indexed_literal(raw_data_rel_name: str, index: int) -> str:
 def _config_rel_name(rel: str) -> str: return f"load_{rel}_config"
 
 
-def _part_index_relation(part_index: str) -> str:
-    return f"def part_index_config:schema:INDEX = \"int\"\n" \
-           f"def part_index_config:data = \"\"\"\n" \
+def _part_index_relation(rel_name: str, part_index: str) -> str:
+    return f"def part_index_config:{rel_name}:schema:INDEX = \"int\"\n" \
+           f"def part_index_config:{rel_name}:data = \"\"\"\n" \
            f"INDEX\n" \
            f"{part_index}\n" \
            f"\"\"\"" \
-           f"def part_indexes_csv = load_csv[part_index_config]\n" \
-           f"def part_indexes = part_indexes_csv:INDEX[_]"
+           f"def part_indexes_csv:{rel_name} = load_csv[part_index_config:{rel_name}]\n" \
+           f"def part_indexes:{rel_name} = part_indexes_csv:{rel_name}:INDEX[_]"
 
 
 def _path_rel_name_relation(path_rel_name: str, part_uri_map: str) -> str:
-    return f"def part_uri_map_config:schema:INDEX = \"int\"\n" \
-           f"def part_uri_map_config:schema:URI = \"string\"\n" \
-           f"def part_uri_map_config:data = \"\"\"\n" \
+    return f"def part_uri_map_config:{path_rel_name}:schema:INDEX = \"int\"\n" \
+           f"def part_uri_map_config:{path_rel_name}:schema:URI = \"string\"\n" \
+           f"def part_uri_map_config:{path_rel_name}:data = \"\"\"\n" \
            f"INDEX,URI\n" \
           f"{part_uri_map}\n" \
           f"\"\"\"" \
-           f"def part_uri_map_csv = load_csv[part_uri_map_config]\n" \
-           f"def {path_rel_name}(i, u) {{ part_uri_map_csv:INDEX(row, i) and part_uri_map_csv:URI(row, u) from row }}"
+           f"def part_uri_map_csv:{path_rel_name} = load_csv[part_uri_map_config:{path_rel_name}]\n" \
+           f"def {path_rel_name}(i, u) {{ part_uri_map_csv:{path_rel_name}:INDEX(row, i) and part_uri_map_csv:{path_rel_name}:URI(row, u) from row }}"
 
 
 def _export_relation_as_csv_local(rel_name) -> str:
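
These query.py changes are what make the joint load safe: every helper relation generated for a multipart source (part_index_config, part_indexes, part_uri_map_config, ...) is now namespaced by the relation or path name, so two sources concatenated into one transaction no longer redefine the same part_indexes relation. A hedged sketch (source names here are illustrative only) of what the updated _part_index_relation emits for two different sources:

    def part_index_relation(rel_name: str, part_index: str) -> str:
        # Mirrors the updated helper above: every generated relation carries rel_name.
        return (
            f'def part_index_config:{rel_name}:schema:INDEX = "int"\n'
            f'def part_index_config:{rel_name}:data = """\n'
            f"INDEX\n"
            f"{part_index}\n"
            f'"""'
            f"def part_indexes_csv:{rel_name} = load_csv[part_index_config:{rel_name}]\n"
            f"def part_indexes:{rel_name} = part_indexes_csv:{rel_name}:INDEX[_]"
        )

    # Before this commit both calls would have defined one shared `part_indexes`;
    # now the definitions stay disjoint and can share a single transaction.
    print(part_index_relation("city_data", "1\n2"))
    print(part_index_relation("country_data", "1"))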
