Commit
Merge branch 'main' into feature/non-source-tracking
lakshmi2506 authored Jan 3, 2024
2 parents 7b94c49 + be0877f commit 8d74ec0
Showing 23 changed files with 1,138 additions and 121 deletions.
8 changes: 7 additions & 1 deletion cumulusci/salesforce_api/retrieve_profile_api.py
@@ -87,7 +87,9 @@ def __init__(
class RetrieveProfileApi(BaseSalesforceApiTask):
def _init_task(self):
super(RetrieveProfileApi, self)._init_task()
self.api_version = self.org_config.latest_api_version
self.api_version = self.project_config.config["project"]["package"][
"api_version"
]

def _retrieve_existing_profiles(self, profiles: List[str]):
query = self._build_query(["Name"], "Profile", {"Name": profiles})
@@ -97,6 +99,10 @@ def _retrieve_existing_profiles(self, profiles: List[str]):
for data in result["records"]:
existing_profiles.append(data["Name"])

# Since System Administrator is named Admin in Metadata API
if "Admin" in profiles:
existing_profiles.extend(["Admin", "System Administrator"])

return existing_profiles

def _run_query(self, query):
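A note on the Admin special case in _retrieve_existing_profiles: the SOQL query filters Profile by Name, and the admin profile's Name is "System Administrator", so a request for "Admin" (its Metadata API name) never comes back from the query; the unconditional extend keeps both spellings available to later lookups. A minimal sketch of the effect, using hypothetical profile names:

requested = ["Admin", "Custom Sales Profile"]
# Hypothetical result of SELECT Name FROM Profile WHERE Name IN (...):
# "Custom Sales Profile" matches, but the admin profile does not, because its
# Name is "System Administrator" rather than "Admin".
queried_names = ["Custom Sales Profile"]

existing_profiles = list(queried_names)
if "Admin" in requested:
    # Keep both spellings so later lookups succeed under either name.
    existing_profiles.extend(["Admin", "System Administrator"])

assert existing_profiles == ["Custom Sales Profile", "Admin", "System Administrator"]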
6 changes: 4 additions & 2 deletions cumulusci/salesforce_api/tests/test_retrieve_profile_api.py
@@ -18,7 +18,7 @@ def retrieve_profile_api_instance():
project_config = MagicMock()
task_config = MagicMock()
org_config = MagicMock()
org_config.latest_api_version = "58.0"
project_config.config = {"project": {"package": {"api_version": "58.0"}}}
sf_mock.query.return_value = {"records": []}
api = RetrieveProfileApi(
project_config=project_config, org_config=org_config, task_config=task_config
@@ -36,7 +36,7 @@ def test_init_task(retrieve_profile_api_instance):


def test_retrieve_existing_profiles(retrieve_profile_api_instance):
profiles = ["Profile1", "Profile2"]
profiles = ["Profile1", "Profile2", "Admin"]
result = {"records": [{"Name": "Profile1"}]}
with patch.object(
RetrieveProfileApi, "_build_query", return_value="some_query"
@@ -47,6 +47,8 @@ def test_retrieve_existing_profiles(retrieve_profile_api_instance):

assert "Profile1" in existing_profiles
assert "Profile2" not in existing_profiles
assert "Admin" in existing_profiles
assert "System Administrator" in existing_profiles


def test_run_query_sf(retrieve_profile_api_instance):
5 changes: 4 additions & 1 deletion cumulusci/tasks/bulkdata/extract.py
@@ -247,7 +247,10 @@ def strip_name_field(record):

if "RecordTypeId" in mapping.fields:
self._extract_record_types(
mapping.sf_object, mapping.get_source_record_type_table(), conn
mapping.sf_object,
mapping.get_source_record_type_table(),
conn,
self.org_config.is_person_accounts_enabled,
)

self.session.commit()
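The new fourth argument threads the org's Person Accounts setting into record-type extraction. The helper itself is outside this diff, so the sketch below is only a guess at what the flag changes — a hypothetical query builder, not the actual cumulusci.tasks.bulkdata.utils code — consistent with the IsPersonType join added to query_transformers.py further down:

def build_record_type_query(sobject: str, is_person_accounts_enabled: bool) -> str:
    # Assumption: with Person Accounts enabled, RecordType rows must also be
    # distinguished by IsPersonType, not just DeveloperName.
    fields = ["Id", "DeveloperName"]
    if is_person_accounts_enabled:
        fields.append("IsPersonType")
    return f"SELECT {', '.join(fields)} FROM RecordType WHERE SObjectType = '{sobject}'"

print(build_record_type_query("Account", True))
# SELECT Id, DeveloperName, IsPersonType FROM RecordType WHERE SObjectType = 'Account'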
203 changes: 195 additions & 8 deletions cumulusci/tasks/bulkdata/load.py
@@ -9,6 +9,7 @@
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

from cumulusci.core.enums import StrEnum
from cumulusci.core.exceptions import BulkDataException, TaskOptionsError
from cumulusci.core.utils import process_bool_arg
from cumulusci.salesforce_api.org_schema import get_org_schema
@@ -28,9 +29,11 @@
)
from cumulusci.tasks.bulkdata.step import (
DEFAULT_BULK_BATCH_SIZE,
DataApi,
DataOperationJobResult,
DataOperationStatus,
DataOperationType,
RestApiDmlOperation,
get_dml_operation,
)
from cumulusci.tasks.bulkdata.upsert_utils import (
@@ -88,6 +91,9 @@ class LoadData(SqlAlchemyMixin, BaseSalesforceApiTask):
"org_shape_match_only": {
"description": "When True, all path options are ignored and only a dataset matching the org shape name will be loaded. Defaults to False."
},
"enable_rollback": {
"description": "When True, performs rollback operation incase of error. Defaults to False"
},
}
row_warning_limit = 10
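enable_rollback is a boolean task option like ignore_row_errors and set_recently_viewed, and the next hunk normalizes it with process_bool_arg so string values coming from the CLI or cumulusci.yml behave the same as real booleans. A short sketch (the option value shown is illustrative):

from cumulusci.core.utils import process_bool_arg

options = {"enable_rollback": "True"}  # as it might arrive from the CLI or YAML
enable_rollback = process_bool_arg(options.get("enable_rollback", False))
assert enable_rollback is True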

@@ -115,6 +121,9 @@ def _init_options(self, kwargs):
self.options["set_recently_viewed"] = process_bool_arg(
self.options.get("set_recently_viewed", True)
)
self.options["enable_rollback"] = process_bool_arg(
self.options.get("enable_rollback", False)
)

def _init_dataset(self):
"""Find the dataset paths to use with the following sequence:
@@ -261,13 +270,33 @@ def _execute_step(
step, query = self.configure_step(mapping)

with tempfile.TemporaryFile(mode="w+t") as local_ids:
# Store the previous values of the records before the upsert/update
# so that we can roll them back if needed
if (
mapping.action
in [
DataOperationType.ETL_UPSERT,
DataOperationType.UPSERT,
DataOperationType.UPDATE,
]
and self.options["enable_rollback"]
):
UpdateRollback.prepare_for_rollback(
self, step, self._stream_queried_data(mapping, local_ids, query)
)
step.start()
step.load_records(self._stream_queried_data(mapping, local_ids, query))
step.end()

# Process Job Results
if step.job_result.status is not DataOperationStatus.JOB_FAILURE:
local_ids.seek(0)
self._process_job_results(mapping, step, local_ids)
elif (
step.job_result.status is DataOperationStatus.JOB_FAILURE
and self.options["enable_rollback"]
):
Rollback._perform_rollback(self)

return step.job_result
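Stripped of the CumulusCI step and mapping machinery, the control flow added to _execute_step reduces to the pattern below (every name in the sketch is an illustrative stand-in, not a CumulusCI API): previous values are captured before the load so updates and upserts can be undone, and a whole-job failure triggers a rollback of everything staged so far.

def execute_step_with_rollback(action, enable_rollback, capture_previous_values,
                               run_load, process_results, perform_rollback):
    # Capture prior state *before* the destructive work for update/upsert actions.
    if enable_rollback and action in ("update", "upsert", "etl_upsert"):
        capture_previous_values()
    job_failed = run_load()
    if not job_failed:
        process_results()
    elif enable_rollback:
        perform_rollback()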

@@ -367,7 +396,9 @@ def _load_record_types(self, sobjects, conn):
"""Persist record types for the given sObjects into the database."""
for sobject in sobjects:
table_name = sobject + "_rt_target_mapping"
self._extract_record_types(sobject, table_name, conn)
self._extract_record_types(
sobject, table_name, conn, self.org_config.is_person_accounts_enabled
)

def _get_statics(self, mapping):
"""Return the static values (not column names) to be appended to
@@ -452,7 +483,7 @@ def _process_job_results(self, mapping, step, local_ids):
id_table_name = self._initialize_id_table(mapping, self.reset_oids)
conn = self.session.connection()

results_generator = self._generate_results_id_map(step, local_ids)
sf_id_results = self._generate_results_id_map(step, local_ids)

# If we know we have no successful inserts, don't attempt to persist Ids.
# Do, however, drain the generator to get error-checking behavior.
@@ -463,11 +494,8 @@ def _process_job_results(self, mapping, step, local_ids):
connection=conn,
table=self.metadata.tables[id_table_name],
columns=("id", "sf_id"),
record_iterable=results_generator,
record_iterable=sf_id_results,
)
else:
for r in results_generator:
pass # Drain generator to validate results

# Contact records for Person Accounts are inserted during an Account
# sf_object step. Insert records into the Contact ID table for
@@ -494,16 +522,37 @@

def _generate_results_id_map(self, step, local_ids):
"""Consume results from load and prepare rows for id table.
Raise BulkDataException on row errors if configured to do so."""
Raise BulkDataException on row errors if configured to do so.
Adds created records to the insert_rollback table.
Performs a rollback if enable_rollback is True and any errors occur."""
error_checker = RowErrorChecker(
self.logger, self.options["ignore_row_errors"], self.row_warning_limit
)
local_ids = (lid.strip("\n") for lid in local_ids)
sf_id_results = []
created_results = []
failed_results = []
for result, local_id in zip(step.get_results(), local_ids):
if result.success:
yield (local_id, result.id)
sf_id_results.append([local_id, result.id])
if result.created:
created_results.append([result.id])
else:
failed_results.append([result, local_id])

# We record failed_results separately: if an unsuccessful record appears
# mid-batch, raising on it immediately would stop before all the successful ids were stored
for result, local_id in failed_results:
try:
error_checker.check_for_row_error(result, local_id)
except Exception as e:
if self.options["enable_rollback"]:
CreateRollback.prepare_for_rollback(self, step, created_results)
Rollback._perform_rollback(self)
raise e
if self.options["enable_rollback"]:
CreateRollback.prepare_for_rollback(self, step, created_results)
return sf_id_results
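Why failed rows are collected and only checked after the loop: raising on the first row error would stop before the ids of later successful rows were recorded, leaving the insert_rollback table incomplete. A stand-alone illustration with hypothetical results:

results = [("row-1", True), ("row-2", False), ("row-3", True)]  # (local id, success)

succeeded, failed = [], []
for local_id, success in results:
    (succeeded if success else failed).append(local_id)

# "row-3" is still captured even though "row-2" failed before it; errors are
# raised only after the full pass, mirroring the deferred error_checker loop.
assert succeeded == ["row-1", "row-3"]
assert failed == ["row-2"]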

def _initialize_id_table(self, mapping, should_reset_table):
"""initalize or find table to hold the inserted SF Ids
@@ -566,6 +615,9 @@ def _init_db(self):
self.metadata.bind = connection
self.inspector = inspect(parent_engine)

# empty the record of initialized tables
Rollback._initialized_rollback_tables_api = {}

# initialize the automap mapping
self.base = automap_base(bind=connection, metadata=self.metadata)
self.base.prepare(connection, reflect=True)
@@ -808,6 +860,141 @@ def _set_viewed(self) -> T.List["SetRecentlyViewedInfo"]:
return results


class RollbackType(StrEnum):
"""Enum to specify type of rollback"""

UPSERT = "upsert_rollback"
INSERT = "insert_rollback"
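RollbackType is a StrEnum, which is what lets the code below build table names with f-strings and recognize them again with table.name.endswith(RollbackType.INSERT). A stand-alone sketch using the standard library's str mix-in as a stand-in for cumulusci.core.enums.StrEnum (assuming it behaves the same way):

from enum import Enum

class RollbackTypeSketch(str, Enum):  # stand-in for cumulusci.core.enums.StrEnum
    UPSERT = "upsert_rollback"
    INSERT = "insert_rollback"

# Members behave as their string value, so they can be embedded in a table
# name and matched back with str.endswith().
table_name = "Account_" + RollbackTypeSketch.INSERT
assert table_name == "Account_insert_rollback"
assert table_name.endswith(RollbackTypeSketch.INSERT)
assert not table_name.endswith(RollbackTypeSketch.UPSERT)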


class Rollback:
# Store the table name and its corresponding API (rest or bulk)
_initialized_rollback_tables_api = {}

@staticmethod
def _create_tables_for_rollback(context, step, rollback_type: RollbackType) -> str:
"""Create the tables required for upsert and insert rollback"""
table_name = f"{step.sobject}_{rollback_type}"

if table_name not in Rollback._initialized_rollback_tables_api:
common_columns = [Column("Id", Unicode(255), primary_key=True)]

additional_columns = (
[Column(field, Unicode(255)) for field in step.fields if field != "Id"]
if rollback_type is RollbackType.UPSERT
else []
)

columns = common_columns + additional_columns

# Create the table
rollback_table = Table(table_name, context.metadata, *columns)
rollback_table.create()

# Store the API in the initialized tables dictionary
if isinstance(step, RestApiDmlOperation):
Rollback._initialized_rollback_tables_api[table_name] = DataApi.REST
else:
Rollback._initialized_rollback_tables_api[table_name] = DataApi.BULK

return table_name

@staticmethod
def _perform_rollback(context):
"""Perform total rollback"""
context.logger.info("--Initiated Rollback Procedure--")
for table in reversed(context.metadata.sorted_tables):
if table.name.endswith(RollbackType.INSERT):
CreateRollback._perform_rollback(context, table)
elif table.name.endswith(RollbackType.UPSERT):
UpdateRollback._perform_rollback(context, table)
context.logger.info("--Finished Rollback Procedure--")


class UpdateRollback:
@staticmethod
def prepare_for_rollback(context, step, records):
"""Retrieve previous values for records being updated"""
results, columns = step.get_prev_record_values(records)
if results:
table_name = Rollback._create_tables_for_rollback(
context, step, RollbackType.UPSERT
)
conn = context.session.connection()
sql_bulk_insert_from_records(
connection=conn,
table=context.metadata.tables[table_name],
columns=columns,
record_iterable=results,
)

@staticmethod
def _perform_rollback(context, table: Table) -> None:
"""Perform rollback for updated records"""
sf_object = table.name.split(f"_{RollbackType.UPSERT.value}")[0]
records = context.session.query(table).all()

if records:
context.logger.info(f"Reverting upserts for {sf_object}")
api_options = {"update_key": "Id"}

# Use get_dml_operation to create an UPSERT step
step = get_dml_operation(
sobject=sf_object,
operation=DataOperationType.UPSERT,
api_options=api_options,
context=context,
fields=[column.name for column in table.columns],
api=Rollback._initialized_rollback_tables_api[table.name],
volume=len(records),
)
step.start()
step.load_records(records)
step.end()
context.logger.info("Done")


class CreateRollback:
@staticmethod
def prepare_for_rollback(context, step, records):
"""Store the sf_ids of all records that were created
to prepare for rollback"""
if records:
table_name = Rollback._create_tables_for_rollback(
context, step, RollbackType.INSERT
)
conn = context.session.connection()
sql_bulk_insert_from_records(
connection=conn,
table=context.metadata.tables[table_name],
columns=["Id"],
record_iterable=records,
)

@staticmethod
def _perform_rollback(context, table: Table) -> None:
"""Perform rollback for insert operation"""
sf_object = table.name.split(f"_{RollbackType.INSERT.value}")[0]
records = context.session.query(table).all()

if records:
context.logger.info(f"Deleting {sf_object} records")
# Perform DELETE operation using get_dml_operation
step = get_dml_operation(
sobject=sf_object,
operation=DataOperationType.DELETE,
fields=["Id"],
api_options={},
context=context,
api=Rollback._initialized_rollback_tables_api[table.name],
volume=len(records),
)
step.start()
step.load_records(records)
step.end()
context.logger.info("Done")


class StepResultInfo(T.NamedTuple):
"""Represent a Step Result in a form easily convertible to JSON"""

11 changes: 8 additions & 3 deletions cumulusci/tasks/bulkdata/query_transformers.py
@@ -1,7 +1,7 @@
import typing as T
from functools import cached_property

from sqlalchemy import func, text
from sqlalchemy import and_, func, text
from sqlalchemy.orm import Query, aliased

from cumulusci.core.exceptions import BulkDataException
@@ -134,10 +134,15 @@ def outerjoins_to_add(self):
rt_source_table.columns.record_type_id
== getattr(self.model, self.mapping.fields["RecordTypeId"]),
),
# Combination of IsPersonType and DeveloperName is unique
(
rt_dest_table,
rt_dest_table.columns.developer_name
== rt_source_table.columns.developer_name,
and_(
rt_dest_table.columns.developer_name
== rt_source_table.columns.developer_name,
rt_dest_table.columns.is_person_type
== rt_source_table.columns.is_person_type,
),
),
]

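Why the join condition needs both columns: with Person Accounts enabled, an org can hold two RecordType rows that share a DeveloperName and differ only in IsPersonType, so matching on DeveloperName alone would pair one source record type with two targets. A hedged illustration with hypothetical rows:

source_rts = [
    {"id": "012S0000000001", "developer_name": "Customer", "is_person_type": False},
    {"id": "012S0000000002", "developer_name": "Customer", "is_person_type": True},
]
target_rts = [
    {"id": "012T0000000001", "developer_name": "Customer", "is_person_type": False},
    {"id": "012T0000000002", "developer_name": "Customer", "is_person_type": True},
]

# Joining on developer_name alone pairs every source row with both targets;
# adding is_person_type restores a one-to-one mapping, mirroring the and_() above.
mapping = {
    s["id"]: t["id"]
    for s in source_rts
    for t in target_rts
    if s["developer_name"] == t["developer_name"]
    and s["is_person_type"] == t["is_person_type"]
}
assert len(mapping) == 2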
(Diff for the remaining changed files is not shown.)