Commit
Merge pull request #3731 from SFDO-Tooling/feature/load_data
Feature/load data
aditya-balachander authored Jan 16, 2024
2 parents 38a6540 + 773501f commit 2fc9a52
Showing 5 changed files with 189 additions and 105 deletions.
80 changes: 42 additions & 38 deletions cumulusci/tasks/bulkdata/load.py
@@ -124,6 +124,8 @@ def _init_options(self, kwargs):
self.options["enable_rollback"] = process_bool_arg(
self.options.get("enable_rollback", False)
)
self._id_generators = {}
self._old_format = False

def _init_dataset(self):
"""Find the dataset paths to use with the following sequence:
@@ -205,10 +207,11 @@ def _run_task(self):
"No data will be loaded because this is a persistent org and no dataset was specified."
)
return
self.ID_TABLE_NAME = "cumulusci_id_table"
self._init_mapping()
with self._init_db():
self._expand_mapping()

self._initialize_id_table(self.reset_oids)
start_step = self.options.get("start_step")
started = False
results = {}
@@ -359,7 +362,8 @@ def check_simple_upsert(self, mapping):
def _stream_queried_data(self, mapping, local_ids, query):
"""Get data from the local db"""

statics = self._get_statics(mapping)
# statics = self._get_statics(mapping)
staticizer = self._add_statics_to_row(mapping)
total_rows = 0

if mapping.anchor_date:
@@ -372,13 +376,13 @@ def _stream_queried_data(self, mapping, local_ids, query):
batch_size = mapping.batch_size or DEFAULT_BULK_BATCH_SIZE
for row in query.yield_per(batch_size):
total_rows += 1
# Add static values to row
pkey = row[0]
row = list(row[1:]) + statics

if mapping.anchor_date and (date_context[0] or date_context[1]):
row = adjust_relative_dates(
mapping, date_context, row, DataOperationType.INSERT
)
pkey = row[0] # FIXME: This is a local-DB ordering assumption.
row = staticizer(list(row[1:]))
if mapping.action is DataOperationType.UPDATE:
if len(row) > 1 and all([f is None for f in row[1:]]):
# Skip update rows that contain no values
Expand All @@ -389,7 +393,7 @@ def _stream_queried_data(self, mapping, local_ids, query):
yield row

self.logger.info(
f"Prepared {total_rows} rows for {mapping['action']} to {mapping['sf_object']}."
f"Prepared {total_rows} rows for {mapping.action.value} to {mapping.sf_object}."
)

def _load_record_types(self, sobjects, conn):
@@ -400,10 +404,9 @@ def _load_record_types(self, sobjects, conn):
sobject, table_name, conn, self.org_config.is_person_accounts_enabled
)

def _get_statics(self, mapping):
"""Return the static values (not column names) to be appended to
records for this mapping."""
def _add_statics_to_row(self, mapping):
statics = list(mapping.static.values())

if mapping.record_type:
query = (
f"SELECT Id FROM RecordType WHERE SObjectType='{mapping.sf_object}'"
Expand All @@ -416,7 +419,10 @@ def _get_statics(self, mapping):
raise BulkDataException(f"Cannot find RecordType with query `{query}`")
statics.append(record_type_id)

return statics
def add_statics(row):
return row + statics

return add_statics
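
A minimal standalone sketch of the closure pattern _add_statics_to_row introduces (hypothetical values, not the task's own API): the static values, including any resolved RecordType Id, are computed once and then appended to every streamed row.

def make_staticizer(statics):
    # Capture the per-mapping static values once; each call only concatenates.
    def add_statics(row):
        return row + statics

    return add_statics

add_statics = make_staticizer(["0121x000000aBcDAAU"])  # e.g. a resolved RecordType Id (made up)
assert add_statics(["Jane", "Doe"]) == ["Jane", "Doe", "0121x000000aBcDAAU"]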

def _query_db(self, mapping):
"""Build a query to retrieve data from the local db.
@@ -439,12 +445,14 @@ def _query_db(self, mapping):
query = self.session.query(*columns)

classes = [
AddLookupsToQuery,
AddRecordTypesToQuery,
AddMappingFiltersToQuery,
AddUpsertsToQuery,
]
transformers = [cls(mapping, self.metadata, model) for cls in classes]
transformers.append(
AddLookupsToQuery(mapping, self.metadata, model, self._old_format)
)

if mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping):
transformers.append(AddPersonAccountsToQuery(mapping, self.metadata, model))
@@ -479,20 +487,25 @@ def _process_job_results(self, mapping, step, local_ids):
DataOperationType.UPSERT,
DataOperationType.ETL_UPSERT,
)
if is_insert_or_upsert:
id_table_name = self._initialize_id_table(mapping, self.reset_oids)
conn = self.session.connection()

conn = self.session.connection()
sf_id_results = self._generate_results_id_map(step, local_ids)

for i in range(len(sf_id_results)):
if str(sf_id_results[i][0]).isnumeric():
self._old_format = True
sf_id_results[i][0] = mapping.table + "-" + str(sf_id_results[i][0])
else:
break
# If we know we have no successful inserts, don't attempt to persist Ids.
# Do, however, drain the generator to get error-checking behavior.
if is_insert_or_upsert and (
step.job_result.records_processed - step.job_result.total_row_errors
):
table = self.metadata.tables[self.ID_TABLE_NAME]
sql_bulk_insert_from_records(
connection=conn,
table=self.metadata.tables[id_table_name],
table=table,
columns=("id", "sf_id"),
record_iterable=sf_id_results,
)
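
A plain-Python illustration (made-up data; "Account" stands in for mapping.table) of the old-format handling above: local ids written by databases that predate the shared id table are bare integers, so they are prefixed to match the new "<table>-<n>" key format.

sf_id_results = [["0", "001000000000001AAA"], ["1", "001000000000002AAA"]]  # made-up values
old_format = False
for result in sf_id_results:
    if str(result[0]).isnumeric():
        old_format = True
        result[0] = "Account" + "-" + str(result[0])
    else:
        break
assert old_format
assert sf_id_results == [["Account-0", "001000000000001AAA"], ["Account-1", "001000000000002AAA"]]
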
@@ -510,7 +523,7 @@ def _process_job_results(self, mapping, step, local_ids):
if account_id_lookup:
sql_bulk_insert_from_records(
connection=conn,
table=self.metadata.tables[id_table_name],
table=self.metadata.tables[self.ID_TABLE_NAME],
columns=("id", "sf_id"),
record_iterable=self._generate_contact_id_map_for_person_accounts(
mapping, account_id_lookup, conn
@@ -554,37 +567,28 @@ def _generate_results_id_map(self, step, local_ids):
CreateRollback.prepare_for_rollback(self, step, created_results)
return sf_id_results

def _initialize_id_table(self, mapping, should_reset_table):
def _initialize_id_table(self, should_reset_table):
"""initalize or find table to hold the inserted SF Ids
The table has a name like xxx_sf_ids and has just two columns, id and sf_id.
If the table already exists, should_reset_table determines whether to
drop and recreate it or not.
"""
id_table_name = f"{mapping['table']}_sf_ids"

already_exists = id_table_name in self.metadata.tables
already_exists = self.ID_TABLE_NAME in self.metadata.tables

if already_exists and not should_reset_table:
return id_table_name

if not hasattr(self, "_initialized_id_tables"):
self._initialized_id_tables = set()
if id_table_name not in self._initialized_id_tables:
if already_exists:
self.metadata.remove(self.metadata.tables[id_table_name])
id_table = Table(
id_table_name,
self.metadata,
Column("id", Unicode(255), primary_key=True),
Column("sf_id", Unicode(18)),
)
if self.inspector.has_table(id_table_name):
id_table.drop()
id_table.create()
self._initialized_id_tables.add(id_table_name)
return id_table_name
return
id_table = Table(
self.ID_TABLE_NAME,
self.metadata,
Column("id", Unicode(255), primary_key=True),
Column("sf_id", Unicode(18)),
)
if id_table.exists():
id_table.drop()
id_table.create()
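
A self-contained sketch (example data only) of what this method now manages: one shared cumulusci_id_table keyed by "<table>-<n>", replacing the per-object <table>_sf_ids tables.

from sqlalchemy import Column, MetaData, Table, Unicode, create_engine

metadata = MetaData()
id_table = Table(
    "cumulusci_id_table",
    metadata,
    Column("id", Unicode(255), primary_key=True),
    Column("sf_id", Unicode(18)),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        id_table.insert(),
        [
            {"id": "Account-0", "sf_id": "001000000000001AAA"},  # made-up Salesforce Ids
            {"id": "Contact-0", "sf_id": "003000000000001AAA"},
        ],
    )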

def _sqlite_load(self):
"""Read a SQLite script and initialize the in-memory database."""
@@ -655,7 +659,7 @@ def _init_mapping(self):
mapping=self.mapping,
sf=self.sf,
namespace=self.project_config.project__package__namespace,
data_operation=DataOperationType.INSERT,
data_operation=DataOperationType.QUERY,
inject_namespaces=self.options["inject_namespaces"],
drop_missing=self.options["drop_missing_schema"],
)
23 changes: 15 additions & 8 deletions cumulusci/tasks/bulkdata/query_transformers.py
@@ -7,6 +7,7 @@
from cumulusci.core.exceptions import BulkDataException

Criterion = T.Any
ID_TABLE_NAME = "cumulusci_id_table"


class LoadQueryExtender:
@@ -50,18 +51,17 @@ def add_outerjoins(self, query: Query):
class AddLookupsToQuery(LoadQueryExtender):
"""Adds columns and joins relatinng to lookups"""

def __init__(self, mapping, metadata, model) -> None:
def __init__(self, mapping, metadata, model, _old_format) -> None:
super().__init__(mapping, metadata, model)
self._old_format = _old_format
self.lookups = [
lookup for lookup in self.mapping.lookups.values() if not lookup.after
]

@cached_property
def columns_to_add(self):
for lookup in self.lookups:
lookup.aliased_table = aliased(
self.metadata.tables[f"{lookup.table}_sf_ids"]
)
lookup.aliased_table = aliased(self.metadata.tables[ID_TABLE_NAME])
return [lookup.aliased_table.columns.sf_id for lookup in self.lookups]

@cached_property
@@ -71,10 +71,17 @@ def outerjoins_to_add(self):
def join_for_lookup(lookup):
key_field = lookup.get_lookup_key_field(self.model)
value_column = getattr(self.model, key_field)
return (
lookup.aliased_table,
lookup.aliased_table.columns.id == value_column,
)
if self._old_format:
return (
lookup.aliased_table,
lookup.aliased_table.columns.id
== str(lookup.table) + "-" + value_column,
)
else:
return (
lookup.aliased_table,
lookup.aliased_table.columns.id == value_column,
)

return [join_for_lookup(lookup) for lookup in self.lookups]
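
For the _old_format branch above, a standalone sketch (hypothetical table and column names) of the join shape: prefixing the local key column with "<table>-" turns the predicate into a SQL string concatenation against the shared id table.

from sqlalchemy import Column, MetaData, Table, Unicode

metadata = MetaData()
contacts = Table("contacts", metadata, Column("account_id", Unicode(255)))
id_table = Table(
    "cumulusci_id_table",
    metadata,
    Column("id", Unicode(255), primary_key=True),
    Column("sf_id", Unicode(18)),
)

# With old-format (integer) local ids, the "<table>-" prefix is added in SQL.
predicate = id_table.c.id == "Account" + "-" + contacts.c.account_id
print(predicate)  # renders roughly as: cumulusci_id_table.id = :param_1 || contacts.account_id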

71 changes: 61 additions & 10 deletions cumulusci/tasks/bulkdata/utils.py
@@ -13,6 +13,8 @@
from cumulusci.core.exceptions import BulkDataException
from cumulusci.utils.iterators import iterate_in_chunks

ID_TABLE_NAME = "cumulusci_id_table"


class SqlAlchemyMixin:
logger: logging.Logger
@@ -73,16 +75,60 @@ def _database_url(self):
else:
return self._temp_database_url()

def _id_generator_for_object(self, sobject: str):
if sobject not in self._id_generators:

def _generate_ids():
counter = 0
while True:
yield f"{sobject}-{counter}"
counter += 1

self._id_generators[sobject] = _generate_ids()

return self._id_generators[sobject]
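
A standalone adaptation (a plain dict stands in for self._id_generators) showing the behaviour of the per-object id generators: each sObject keeps its own counter, so ids come out as "Account-0", "Account-1", and so on.

def id_generator_for_object(cache, sobject):
    if sobject not in cache:

        def _generate_ids():
            counter = 0
            while True:
                yield f"{sobject}-{counter}"
                counter += 1

        cache[sobject] = _generate_ids()

    return cache[sobject]

cache = {}
assert next(id_generator_for_object(cache, "Account")) == "Account-0"
assert next(id_generator_for_object(cache, "Account")) == "Account-1"
assert next(id_generator_for_object(cache, "Contact")) == "Contact-0"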

def _update_column(
self, *, source_model, target_model, key_field, join_field, target_field
):
key_attr = getattr(source_model, key_field)
join_attr = getattr(target_model, join_field)
target_attr = getattr(target_model, target_field)

id_column = inspect(source_model).primary_key[0].name

try:
self.session.query(source_model).filter(
key_attr.isnot(None), key_attr == join_attr
).update({key_attr: target_attr}, synchronize_session=False)
except NotImplementedError:
# Some databases, such as SQLite, don't support multitable update
# TODO: review memory consumption of this routine.
mappings = []
for row, lookup_id in self.session.query(source_model, target_attr).join(
target_model, key_attr == join_attr
):
mappings.append(
{id_column: getattr(row, id_column), key_field: lookup_id}
)
self.session.bulk_update_mappings(source_model, mappings)

def _update_sf_id_column(self, model, key_field):
self._update_column(
source_model=model,
target_model=self.models[self.ID_TABLE_NAME],
key_field=key_field,
join_field="sf_id",
target_field="id",
)
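
A plain-Python sketch (made-up rows) of what _update_sf_id_column does: a lookup column that currently holds Salesforce Ids is rewritten to the local ids recorded in the shared id table, joining on sf_id and copying id back.

id_rows = [
    {"id": "Account-0", "sf_id": "001000000000001AAA"},  # made-up values
    {"id": "Account-1", "sf_id": "001000000000002AAA"},
]
by_sf_id = {row["sf_id"]: row["id"] for row in id_rows}

contacts = [{"LastName": "Doe", "AccountId": "001000000000002AAA"}]
for row in contacts:  # join on sf_id, copy id back into the lookup column
    if row["AccountId"] is not None:
        row["AccountId"] = by_sf_id[row["AccountId"]]
assert contacts[0]["AccountId"] == "Account-1"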

def _handle_primary_key(mapping, fields):
"""Provide support for legacy mappings which used the OID as the pk but
default to using an autoincrementing int pk and a separate sf_id column"""
def _is_autopk_database(self):
# If the type of the Id column on a mapping is INTEGER,
# this is an autopk database.

if mapping.get_oid_as_pk():
id_column = mapping.fields["Id"]
fields.append(Column(id_column, Unicode(255), primary_key=True))
else:
fields.append(Column("id", Integer(), primary_key=True, autoincrement=True))
        mapping = next(iter(self.mapping.values()))  # dict views aren't indexable
id_field = mapping.fields["Id"]
return isinstance(getattr(self.models[mapping.table], id_field).type, Integer)


def create_table(mapping, metadata) -> Table:
@@ -92,10 +138,15 @@ def create_table(mapping, metadata) -> Table:
Mapping should be a MappingStep instance"""

fields = []
_handle_primary_key(mapping, fields)
# _handle_primary_key(mapping, fields)
id_column = mapping.fields["Id"] # Guaranteed to be present by mapping parser.
fields.append(Column(id_column, Unicode(255), primary_key=True))

# make a field list to create
for field, db in mapping.get_complete_field_map().items():
# for field, db in mapping.get_complete_field_map().items():
for field, db in zip(
mapping.get_extract_field_list(), mapping.get_database_column_list()
    ):  # zip() already yields (field, column) pairs
if field == "Id":
continue

23 changes: 17 additions & 6 deletions datasets/mapping.yml
@@ -35,12 +35,23 @@ Contact:
lookups:
AccountId:
table: Account
Opportunity:
sf_object: Opportunity
Lead:
sf_object: Lead
api: rest
batch_size: 2
fields:
- LastName
- Company
Event:
sf_object: Event
api: bulk
batch_size: 2
fields:
- Name
- CloseDate
- Amount
- StageName
- Subject
- DurationInMinutes
- ActivityDateTime
lookups:
WhoId:
table:
- Contact
- Lead