diff --git a/cumulusci/core/tests/test_datasets_e2e.py b/cumulusci/core/tests/test_datasets_e2e.py index c5140d3609..387ad696ad 100644 --- a/cumulusci/core/tests/test_datasets_e2e.py +++ b/cumulusci/core/tests/test_datasets_e2e.py @@ -304,6 +304,7 @@ def write_yaml(filename: str, json: Any): "after": "Insert Account", } }, + "select_options": {}, }, "Insert Event": { "sf_object": "Event", @@ -316,16 +317,19 @@ def write_yaml(filename: str, json: Any): "after": "Insert Lead", } }, + "select_options": {}, }, "Insert Account": { "sf_object": "Account", "table": "Account", "fields": ["Name"], + "select_options": {}, }, "Insert Lead": { "sf_object": "Lead", "table": "Lead", "fields": ["Company", "LastName"], + "select_options": {}, }, } assert tuple(actual.items()) == tuple(expected.items()), actual.items() diff --git a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py index 95d6b9ff97..cec42d0bd9 100644 --- a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py +++ b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py @@ -5,7 +5,7 @@ from pydantic import Field, validator from cumulusci.core.enums import StrEnum -from cumulusci.tasks.bulkdata.step import DataApi +from cumulusci.tasks.bulkdata.utils import DataApi from cumulusci.utils.yaml.model_parser import CCIDictModel, HashableBaseModel object_decl = re.compile(r"objects\((\w+)\)", re.IGNORECASE) diff --git a/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py b/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py index 7dbaefc740..69dd0e361d 100644 --- a/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py +++ b/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py @@ -41,6 +41,7 @@ def test_simple_generate_mapping_from_declarations(self, org_config): "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, } } @@ -74,11 +75,13 @@ def test_generate_mapping_from_both_kinds_of_declarations(self, org_config): "sf_object": "Contact", "table": "Contact", "fields": ["FirstName", "LastName"], + "select_options": {}, }, "Insert Account": { "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, }.items() ) @@ -111,6 +114,7 @@ def test_generate_load_mapping_from_declarations__lookups(self, org_config): "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -119,6 +123,7 @@ def test_generate_load_mapping_from_declarations__lookups(self, org_config): "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, } @@ -157,6 +162,7 @@ def test_generate_load_mapping_from_declarations__polymorphic_lookups( "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -165,11 +171,13 @@ def test_generate_load_mapping_from_declarations__polymorphic_lookups( "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, "Insert Lead": { "sf_object": "Lead", "table": "Lead", "fields": ["LastName", "Company"], + "select_options": {}, }, "Insert Event": { "sf_object": "Event", @@ -178,6 +186,7 @@ def 
test_generate_load_mapping_from_declarations__polymorphic_lookups( "lookups": { "WhoId": {"table": ["Contact", "Lead"], "key_field": "WhoId"} }, + "select_options": {}, }, } @@ -221,6 +230,7 @@ def test_generate_load_mapping_from_declarations__circular_lookups( }, "sf_object": "Account", "table": "Account", + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -229,6 +239,7 @@ def test_generate_load_mapping_from_declarations__circular_lookups( "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, }, mf @@ -252,11 +263,13 @@ def test_generate_load_mapping__with_load_declarations(self, org_config): "sf_object": "Account", "api": DataApi.REST, "table": "Account", + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", "api": DataApi.BULK, "table": "Contact", + "select_options": {}, }, }, mf @@ -288,6 +301,7 @@ def test_generate_load_mapping__with_upserts(self, org_config): "Insert Account": { "sf_object": "Account", "table": "Account", + "select_options": {}, }, "Upsert Account Name": { "sf_object": "Account", @@ -295,6 +309,7 @@ def test_generate_load_mapping__with_upserts(self, org_config): "action": DataOperationType.UPSERT, "update_key": ("Name",), "fields": ["Name"], + "select_options": {}, }, "Etl_Upsert Account AccountNumber_Name": { "sf_object": "Account", @@ -302,10 +317,12 @@ def test_generate_load_mapping__with_upserts(self, org_config): "action": DataOperationType.ETL_UPSERT, "update_key": ("AccountNumber", "Name"), "fields": ["AccountNumber", "Name"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", "table": "Contact", + "select_options": {}, }, }, mf diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 4ae0dcf31a..0732d57777 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -27,6 +27,7 @@ AddMappingFiltersToQuery, AddPersonAccountsToQuery, AddRecordTypesToQuery, + DynamicLookupQueryExtender, ) from cumulusci.tasks.bulkdata.step import ( DEFAULT_BULK_BATCH_SIZE, @@ -289,7 +290,12 @@ def _execute_step( self, step, self._stream_queried_data(mapping, local_ids, query) ) step.start() - step.load_records(self._stream_queried_data(mapping, local_ids, query)) + if mapping.action == DataOperationType.SELECT: + step.select_records( + self._stream_queried_data(mapping, local_ids, query) + ) + else: + step.load_records(self._stream_queried_data(mapping, local_ids, query)) step.end() # Process Job Results @@ -304,10 +310,108 @@ def _execute_step( return step.job_result + def process_lookup_fields(self, mapping, fields, polymorphic_fields): + """Modify fields and priority fields based on lookup and polymorphic checks.""" + # Store the lookups and their original order for re-insertion at the end + original_lookups = [name for name in fields if name in mapping.lookups] + max_insert_index = -1 + for name, lookup in mapping.lookups.items(): + if name in fields: + # Get the index of the lookup field before removing it + insert_index = fields.index(name) + max_insert_index = max(max_insert_index, insert_index) + # Remove the lookup field from fields + fields.remove(name) + + # Do the same for priority fields + lookup_in_priority_fields = False + if name in mapping.select_options.priority_fields: + # Set flag to True + lookup_in_priority_fields = True + # Remove the lookup field from priority fields + del mapping.select_options.priority_fields[name] + + # Check if this lookup field is polymorphic + if ( + name in 
polymorphic_fields + and len(polymorphic_fields[name]["referenceTo"]) > 1 + ): + # Convert to list if string + if not isinstance(lookup.table, list): + lookup.table = [lookup.table] + # Polymorphic field handling + polymorphic_references = lookup.table + relationship_name = polymorphic_fields[name]["relationshipName"] + + # Loop through each polymorphic type (e.g., Contact, Lead) + for ref_type in polymorphic_references: + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.table == ref_type + ), + None, + ) + if lookup_mapping_step: + lookup_fields = lookup_mapping_step.fields.keys() + # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} + for field in lookup_fields: + fields.insert( + insert_index, + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", + ) + insert_index += 1 + max_insert_index = max(max_insert_index, insert_index) + if lookup_in_priority_fields: + mapping.select_options.priority_fields[ + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}" + ] = f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}" + + else: + # Non-polymorphic field handling + lookup_table = lookup.table + + if isinstance(lookup_table, list): + lookup_table = lookup_table[0] + + # Get the mapping step for the non-polymorphic reference + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.table == lookup_table + ), + None, + ) + + if lookup_mapping_step: + relationship_name = polymorphic_fields[name]["relationshipName"] + lookup_fields = lookup_mapping_step.fields.keys() + + # Insert the new fields at the same position as the removed lookup field + for field in lookup_fields: + fields.insert(insert_index, f"{relationship_name}.{field}") + insert_index += 1 + max_insert_index = max(max_insert_index, insert_index) + if lookup_in_priority_fields: + mapping.select_options.priority_fields[ + f"{relationship_name}.{field}" + ] = f"{relationship_name}.{field}" + + # Append the original lookups at the end in the same order + for name in original_lookups: + if name not in fields: + fields.insert(max_insert_index, name) + max_insert_index += 1 + def configure_step(self, mapping): """Create a step appropriate to the action""" bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" api_options = {"batch_size": mapping.batch_size, "bulk_mode": bulk_mode} + num_records_in_target = None + content_type = None fields = mapping.get_load_field_list() @@ -336,11 +440,45 @@ def configure_step(self, mapping): self.check_simple_upsert(mapping) api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT + elif mapping.action == DataOperationType.SELECT: + # Set content type to json + content_type = "JSON" + # Bulk process expects DataOpertionType to be QUERY + action = DataOperationType.QUERY + # Determine number of records in the target org + record_count_response = self.sf.restful( + f"limits/recordCount?sObjects={mapping.sf_object}" + ) + sobject_map = { + entry["name"]: entry["count"] + for entry in record_count_response["sObjects"] + } + num_records_in_target = sobject_map.get(mapping.sf_object, None) + + # Check for similarity selection strategy and modify fields accordingly + if mapping.select_options.strategy == "similarity": + # Describe the object to determine polymorphic lookups + describe_result = self.sf.restful( + f"sobjects/{mapping.sf_object}/describe" + ) + polymorphic_fields = { + field["name"]: field + for 
field in describe_result["fields"] + if field["type"] == "reference" + } + self.process_lookup_fields(mapping, fields, polymorphic_fields) else: action = mapping.action query = self._query_db(mapping) + # Set volume + volume = ( + num_records_in_target + if num_records_in_target is not None + else query.count() + ) + step = get_dml_operation( sobject=mapping.sf_object, operation=action, @@ -348,7 +486,12 @@ def configure_step(self, mapping): context=self, fields=fields, api=mapping.api, - volume=query.count(), + volume=volume, + selection_strategy=mapping.select_options.strategy, + selection_filter=mapping.select_options.filter, + selection_priority_fields=mapping.select_options.priority_fields, + content_type=content_type, + threshold=mapping.select_options.threshold, ) return step, query @@ -448,9 +591,20 @@ def _query_db(self, mapping): AddMappingFiltersToQuery, AddUpsertsToQuery, ] - transformers = [ + transformers = [] + if ( + mapping.action == DataOperationType.SELECT + and mapping.select_options.strategy == "similarity" + ): + transformers.append( + DynamicLookupQueryExtender( + mapping, self.mapping, self.metadata, model, self._old_format + ) + ) + transformers.append( AddLookupsToQuery(mapping, self.metadata, model, self._old_format) - ] + ) + transformers.extend([cls(mapping, self.metadata, model) for cls in classes]) if mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping): @@ -481,10 +635,11 @@ def _process_job_results(self, mapping, step, local_ids): """Get the job results and process the results. If we're raising for row-level errors, do so; if we're inserting, store the new Ids.""" - is_insert_or_upsert = mapping.action in ( + is_insert_upsert_or_select = mapping.action in ( DataOperationType.INSERT, DataOperationType.UPSERT, DataOperationType.ETL_UPSERT, + DataOperationType.SELECT, ) conn = self.session.connection() @@ -500,7 +655,7 @@ def _process_job_results(self, mapping, step, local_ids): break # If we know we have no successful inserts, don't attempt to persist Ids. # Do, however, drain the generator to get error-checking behavior. - if is_insert_or_upsert and ( + if is_insert_upsert_or_select and ( step.job_result.records_processed - step.job_result.total_row_errors ): table = self.metadata.tables[self.ID_TABLE_NAME] @@ -516,7 +671,7 @@ def _process_job_results(self, mapping, step, local_ids): # person account Contact records so lookups to # person account Contact records get populated downstream as expected. 
if ( - is_insert_or_upsert + is_insert_upsert_or_select and mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping) ): @@ -531,7 +686,7 @@ def _process_job_results(self, mapping, step, local_ids): ), ) - if is_insert_or_upsert: + if is_insert_upsert_or_select: self.session.commit() def _generate_results_id_map(self, step, local_ids): diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index bb59fc6647..59c7d630a2 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -8,33 +8,21 @@ from typing import IO, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from pydantic import Field, ValidationError, root_validator, validator -from requests.structures import CaseInsensitiveDict as RequestsCaseInsensitiveDict from simple_salesforce import Salesforce from typing_extensions import Literal from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.dates import iso_to_date +from cumulusci.tasks.bulkdata.select_utils import SelectOptions, SelectStrategy from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType +from cumulusci.tasks.bulkdata.utils import CaseInsensitiveDict from cumulusci.utils import convert_to_snake_case from cumulusci.utils.yaml.model_parser import CCIDictModel logger = getLogger(__name__) -class CaseInsensitiveDict(RequestsCaseInsensitiveDict): - def __init__(self, *args, **kwargs): - self._canonical_keys = {} - super().__init__(*args, **kwargs) - - def canonical_key(self, name): - return self._canonical_keys[name.lower()] - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._canonical_keys[key.lower()] = key - - class MappingLookup(CCIDictModel): "Lookup relationship between two tables." table: Union[str, List[str]] # Support for polymorphic lookups @@ -43,6 +31,7 @@ class MappingLookup(CCIDictModel): join_field: Optional[str] = None after: Optional[str] = None aliased_table: Optional[Any] = None + parent_tables: Optional[Any] = None name: Optional[str] = None # populated by parent def get_lookup_key_field(self, model=None): @@ -107,6 +96,9 @@ class MappingStep(CCIDictModel): ] = None # default should come from task options anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property + select_options: Optional[SelectOptions] = Field( + default_factory=lambda: SelectOptions(strategy=SelectStrategy.STANDARD) + ) update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts @validator("bulk_mode", "api", "action", pre=True) @@ -129,6 +121,27 @@ def split_update_key(cls, val): ), "`update_key` should be a field name or list of field names." 
assert False, "Should be unreachable" # pragma: no cover + @root_validator + def validate_priority_fields(cls, values): + select_options = values.get("select_options") + fields_ = values.get("fields_", {}) + lookups = values.get("lookups", {}) + + if select_options and select_options.priority_fields: + priority_field_names = set(select_options.priority_fields.keys()) + field_names = set(fields_.keys()) + lookup_names = set(lookups.keys()) + + # Check if all priority fields are present in the fields + missing_fields = priority_field_names - field_names + missing_fields = missing_fields - lookup_names + if missing_fields: + raise ValueError( + f"Priority fields {missing_fields} are not present in 'fields' or 'lookups'" + ) + + return values + def get_oid_as_pk(self): """Returns True if using Salesforce Ids as primary keys.""" return "Id" in self.fields @@ -673,7 +686,9 @@ def _infer_and_validate_lookups(mapping: Dict, sf: Salesforce): if len(target_objects) == 1: # This is a non-polymorphic lookup. target_index = list(sf_objects.values()).index(target_objects[0]) - if target_index > idx or target_index == idx: + if ( + target_index > idx or target_index == idx + ) and m.action != DataOperationType.SELECT: # This is a non-polymorphic after step. lookup.after = list(mapping.keys())[idx] else: @@ -725,7 +740,7 @@ def validate_and_inject_mapping( if drop_missing: # Drop any steps with sObjects that are not present. - for (include, step_name) in zip(should_continue, list(mapping.keys())): + for include, step_name in zip(should_continue, list(mapping.keys())): if not include: del mapping[step_name] diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index aef23f5dc3..181736a4bc 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -3,6 +3,7 @@ from sqlalchemy import String, and_, func, text from sqlalchemy.orm import Query, aliased +from sqlalchemy.sql import literal_column from cumulusci.core.exceptions import BulkDataException @@ -86,6 +87,81 @@ def join_for_lookup(lookup): return [join_for_lookup(lookup) for lookup in self.lookups] +class DynamicLookupQueryExtender(LoadQueryExtender): + """Dynamically adds columns and joins for all fields in lookup tables, handling polymorphic lookups""" + + def __init__( + self, mapping, all_mappings, metadata, model, _old_format: bool + ) -> None: + super().__init__(mapping, metadata, model) + self._old_format = _old_format + self.all_mappings = all_mappings + self.lookups = [ + lookup for lookup in self.mapping.lookups.values() if not lookup.after + ] + + @cached_property + def columns_to_add(self): + """Add all relevant fields from lookup tables directly without CASE, with support for polymorphic lookups.""" + columns = [] + for lookup in self.lookups: + tables = lookup.table if isinstance(lookup.table, list) else [lookup.table] + lookup.parent_tables = [ + aliased( + self.metadata.tables[table], name=f"{lookup.name}_{table}_alias" + ) + for table in tables + ] + + for parent_table, table_name in zip(lookup.parent_tables, tables): + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.all_mappings.values() + if step.table == table_name + ), + None, + ) + if lookup_mapping_step: + load_fields = lookup_mapping_step.fields.keys() + for field in load_fields: + if field in lookup_mapping_step.fields: + matching_column = next( + ( + col + for col in parent_table.columns + if col.name == 
lookup_mapping_step.fields[field] + ) + ) + columns.append( + matching_column.label(f"{parent_table.name}_{field}") + ) + else: + # Append an empty string if the field is not present + columns.append( + literal_column("''").label( + f"{parent_table.name}_{field}" + ) + ) + return columns + + @cached_property + def outerjoins_to_add(self): + """Add outer joins for each lookup table directly, including handling for polymorphic lookups.""" + + def join_for_lookup(lookup, parent_table): + key_field = lookup.get_lookup_key_field(self.model) + value_column = getattr(self.model, key_field) + return (parent_table, parent_table.columns.id == value_column) + + joins = [] + for lookup in self.lookups: + for parent_table in lookup.parent_tables: + joins.append(join_for_lookup(lookup, parent_table)) + return joins + + class AddRecordTypesToQuery(LoadQueryExtender): """Adds columns, joins and filters relatinng to recordtypes""" diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py new file mode 100644 index 0000000000..7412a38ae4 --- /dev/null +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -0,0 +1,769 @@ +import random +import re +import typing as T +from enum import Enum + +import numpy as np +import pandas as pd +from annoy import AnnoyIndex +from pydantic import Field, root_validator, validator +from sklearn.feature_extraction.text import HashingVectorizer +from sklearn.preprocessing import StandardScaler + +from cumulusci.core.enums import StrEnum +from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( + DEFAULT_DECLARATIONS, +) +from cumulusci.tasks.bulkdata.utils import CaseInsensitiveDict +from cumulusci.utils.yaml.model_parser import CCIDictModel + + +class SelectStrategy(StrEnum): + """Enum defining the different selection strategies requested.""" + + STANDARD = "standard" + SIMILARITY = "similarity" + RANDOM = "random" + + +class SelectRecordRetrievalMode(StrEnum): + """Enum defining whether you need all records or match the + number of records of the local sql file""" + + ALL = "all" + MATCH = "match" + + +ENUM_VALUES = { + v.value.lower(): v.value + for enum in [SelectStrategy] + for v in enum.__members__.values() +} + + +class SelectOptions(CCIDictModel): + filter: T.Optional[str] = None # Optional filter for selection + strategy: SelectStrategy = SelectStrategy.STANDARD # Strategy for selection + priority_fields: T.Dict[str, str] = Field({}) + threshold: T.Optional[float] = None + + @validator("strategy", pre=True) + def validate_strategy(cls, value): + if isinstance(value, Enum): + return value + + if value: + matched_strategy = ENUM_VALUES.get(value.lower()) + if matched_strategy: + return matched_strategy + + raise ValueError(f"Invalid strategy value: {value}") + + @validator("priority_fields", pre=True) + def standardize_fields_to_dict(cls, values): + if values is None: + values = {} + if type(values) is list: + values = {elem: elem for elem in values} + return CaseInsensitiveDict(values) + + @root_validator + def validate_threshold_and_strategy(cls, values): + threshold = values.get("threshold") + strategy = values.get("strategy") + + if threshold is not None: + values["threshold"] = float(threshold) # Convert to float + + if not (0 <= values["threshold"] <= 1): + raise ValueError( + f"Threshold must be between 0 and 1, got {values['threshold']}." + ) + + if strategy != SelectStrategy.SIMILARITY: + raise ValueError( + "If a threshold is specified, the strategy must be set to 'similarity'." 
+ ) + + return values + + +class SelectOperationExecutor: + def __init__(self, strategy: SelectStrategy): + self.strategy = strategy + self.retrieval_mode = ( + SelectRecordRetrievalMode.ALL + if strategy == SelectStrategy.SIMILARITY + else SelectRecordRetrievalMode.MATCH + ) + + def select_generate_query( + self, + sobject: str, + fields: T.List[str], + user_filter: str, + limit: T.Union[int, None], + offset: T.Union[int, None], + ): + _, select_fields = split_and_filter_fields(fields=fields) + # For STANDARD strategy + if self.strategy == SelectStrategy.STANDARD: + return standard_generate_query( + sobject=sobject, user_filter=user_filter, limit=limit, offset=offset + ) + # For SIMILARITY strategy + elif self.strategy == SelectStrategy.SIMILARITY: + return similarity_generate_query( + sobject=sobject, + fields=select_fields, + user_filter=user_filter, + limit=limit, + offset=offset, + ) + # For RANDOM strategy + elif self.strategy == SelectStrategy.RANDOM: + return standard_generate_query( + sobject=sobject, user_filter=user_filter, limit=limit, offset=offset + ) + + def select_post_process( + self, + load_records, + query_records: list, + fields: list, + num_records: int, + sobject: str, + weights: list, + threshold: T.Union[float, None], + ): + # For STANDARD strategy + if self.strategy == SelectStrategy.STANDARD: + return standard_post_process( + query_records=query_records, num_records=num_records, sobject=sobject + ) + # For SIMILARITY strategy + elif self.strategy == SelectStrategy.SIMILARITY: + return similarity_post_process( + load_records=load_records, + query_records=query_records, + fields=fields, + sobject=sobject, + weights=weights, + threshold=threshold, + ) + # For RANDOM strategy + elif self.strategy == SelectStrategy.RANDOM: + return random_post_process( + query_records=query_records, num_records=num_records, sobject=sobject + ) + + +def standard_generate_query( + sobject: str, + user_filter: str, + limit: T.Union[int, None], + offset: T.Union[int, None], +) -> T.Tuple[str, T.List[str]]: + """Generates the SOQL query for the standard (as well as random) selection strategy""" + + query = f"SELECT Id FROM {sobject}" + # If user specifies user_filter + if user_filter: + query += add_limit_offset_to_user_filter( + filter_clause=user_filter, limit_clause=limit, offset_clause=offset + ) + else: + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + query += f" WHERE {declaration.where}" + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" + return query, ["Id"] + + +def standard_post_process( + query_records: list, num_records: int, sobject: str +) -> T.Tuple[T.List[dict], None, T.Union[str, None]]: + """Processes the query results for the standard selection strategy""" + # Handle case where query returns 0 records + if not query_records: + error_message = f"No records found for {sobject} in the target org." 
+ return [], None, error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"id": record[0], "success": True, "created": False} for record in query_records + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records, None, None # Return selected records and None for error + + +def similarity_generate_query( + sobject: str, + fields: T.List[str], + user_filter: str, + limit: T.Union[int, None], + offset: T.Union[int, None], +) -> T.Tuple[str, T.List[str]]: + """Generates the SOQL query for the similarity selection strategy, with support for TYPEOF on polymorphic fields.""" + + # Pre-process the new fields format to create a nested dict structure for TYPEOF clauses + nested_fields = {} + regular_fields = [] + + for field in fields: + components = field.split(".") + if len(components) >= 3: + # Handle polymorphic fields (format: {relationship_name}.{ref_obj}.{ref_field}) + relationship, ref_obj, ref_field = ( + components[0], + components[1], + components[2], + ) + if relationship not in nested_fields: + nested_fields[relationship] = {} + if ref_obj not in nested_fields[relationship]: + nested_fields[relationship][ref_obj] = [] + nested_fields[relationship][ref_obj].append(ref_field) + else: + # Handle regular fields (format: {field}) + regular_fields.append(field) + + # Construct the query fields + query_fields = [] + + # Build TYPEOF clauses for polymorphic fields + for relationship, references in nested_fields.items(): + type_clauses = [] + for ref_obj, ref_fields in references.items(): + fields_clause = ", ".join(ref_fields) + type_clauses.append(f"WHEN {ref_obj} THEN {fields_clause}") + type_clause = f"TYPEOF {relationship} {' '.join(type_clauses)} ELSE Id END" + query_fields.append(type_clause) + + # Add regular fields to the query + query_fields.extend(regular_fields) + + # Ensure "Id" is included in the fields list for identification + if "Id" not in query_fields: + query_fields.insert(0, "Id") + + # Build the main SOQL query + fields_to_query = ", ".join(query_fields) + query = f"SELECT {fields_to_query} FROM {sobject}" + + # Add the user-defined filter clause or default clause + if user_filter: + query += add_limit_offset_to_user_filter( + filter_clause=user_filter, limit_clause=limit, offset_clause=offset + ) + else: + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + query += f" WHERE {declaration.where}" + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" + + # Return the original input fields with "Id" added if needed + if "Id" not in fields: + fields.insert(0, "Id") + + return query, fields + + +def similarity_post_process( + load_records, + query_records: list, + fields: list, + sobject: str, + weights: list, + threshold: T.Union[float, None], +) -> T.Tuple[ + T.List[T.Union[dict, None]], T.List[T.Union[list, None]], T.Union[str, None] +]: + """Processes the query results for the similarity selection strategy""" + # Handle case where query returns 0 records + if not query_records and not threshold: + error_message = f"No records found for {sobject} in the target org." 
+ return [], [], error_message + + load_records = list(load_records) + # Replace None values in each row with empty strings + for idx, row in enumerate(load_records): + row = [value if value is not None else "" for value in row] + load_records[idx] = row + load_record_count, query_record_count = len(load_records), len(query_records) + + complexity_constant = load_record_count * query_record_count + + select_records = [] + insert_records = [] + + if complexity_constant < 1000: + select_records, insert_records = levenshtein_post_process( + load_records, query_records, fields, weights, threshold + ) + else: + select_records, insert_records = annoy_post_process( + load_records, query_records, fields, weights, threshold + ) + + return select_records, insert_records, None + + +def annoy_post_process( + load_records: list, + query_records: list, + all_fields: list, + similarity_weights: list, + threshold: T.Union[float, None], +) -> T.Tuple[T.List[dict], list]: + """Processes the query results for the similarity selection strategy using Annoy algorithm for large number of records""" + selected_records = [] + insertion_candidates = [] + + # Split fields into load and select categories + load_field_list, select_field_list = split_and_filter_fields(fields=all_fields) + # Only select those weights for select field list + similarity_weights = [ + similarity_weights[idx] + for idx, field in enumerate(all_fields) + if field in select_field_list + ] + load_shaped_records = reorder_records( + records=load_records, original_fields=all_fields, new_fields=load_field_list + ) + select_shaped_records = reorder_records( + records=load_records, original_fields=all_fields, new_fields=select_field_list + ) + + if not query_records: + # Directly append to load record for insertion if target_records is empty + selected_records = [None for _ in load_records] + insertion_candidates = load_shaped_records + return selected_records, insertion_candidates + + query_records = replace_empty_strings_with_missing(query_records) + select_shaped_records = replace_empty_strings_with_missing(select_shaped_records) + + hash_features = 100 + num_trees = 10 + + query_record_ids = [record[0] for record in query_records] + query_record_data = [record[1:] for record in query_records] + + record_to_id_map = { + tuple(query_record_data[i]): query_record_ids[i] + for i in range(len(query_records)) + } + + final_load_vectors, final_query_vectors = vectorize_records( + select_shaped_records, + query_record_data, + hash_features=hash_features, + weights=similarity_weights, + ) + + # Create Annoy index for nearest neighbor search + vector_dimension = final_query_vectors.shape[1] + annoy_index = AnnoyIndex(vector_dimension, "euclidean") + + for i in range(len(final_query_vectors)): + annoy_index.add_item(i, final_query_vectors[i]) + + # Build the index + annoy_index.build(num_trees) + + # Find nearest neighbors for each query vector + n_neighbors = 1 + + for i, load_vector in enumerate(final_load_vectors): + # Get nearest neighbors' indices and distances + nearest_neighbors = annoy_index.get_nns_by_vector( + load_vector, n_neighbors, include_distances=True + ) + neighbor_indices = nearest_neighbors[0] # Indices of nearest neighbors + neighbor_distances = [ + distance / 2 for distance in nearest_neighbors[1] + ] # Distances sqrt(2(1-cos(u,v)))/2 lies between [0,1] + + for idx, neighbor_index in enumerate(neighbor_indices): + # Retrieve the corresponding record from the database + record = query_record_data[neighbor_index] + closest_record_id = 
record_to_id_map[tuple(record)] + if threshold and (neighbor_distances[idx] >= threshold): + selected_records.append(None) + insertion_candidates.append(load_shaped_records[i]) + else: + selected_records.append( + {"id": closest_record_id, "success": True, "created": False} + ) + + return selected_records, insertion_candidates + + +def levenshtein_post_process( + source_records: list, + target_records: list, + all_fields: list, + similarity_weights: list, + distance_threshold: T.Union[float, None], +) -> T.Tuple[T.List[T.Optional[dict]], T.List[T.Optional[list]]]: + """Processes query results using Levenshtein algorithm for similarity selection with a small number of records.""" + selected_records = [] + insertion_candidates = [] + + # Split fields into load and select categories + load_field_list, select_field_list = split_and_filter_fields(fields=all_fields) + # Only select those weights for select field list + similarity_weights = [ + similarity_weights[idx] + for idx, field in enumerate(all_fields) + if field in select_field_list + ] + load_shaped_records = reorder_records( + records=source_records, original_fields=all_fields, new_fields=load_field_list + ) + select_shaped_records = reorder_records( + records=source_records, original_fields=all_fields, new_fields=select_field_list + ) + + if not target_records: + # Directly append to load record for insertion if target_records is empty + selected_records = [None for _ in source_records] + insertion_candidates = load_shaped_records + return selected_records, insertion_candidates + + for select_record, load_record in zip(select_shaped_records, load_shaped_records): + closest_match, match_distance = find_closest_record( + select_record, target_records, similarity_weights + ) + + if distance_threshold and match_distance > distance_threshold: + # Append load record for insertion if distance exceeds threshold + insertion_candidates.append(load_record) + selected_records.append(None) + elif closest_match: + # Append match details if distance is within threshold + selected_records.append( + {"id": closest_match[0], "success": True, "created": False} + ) + + return selected_records, insertion_candidates + + +def random_post_process( + query_records: list, num_records: int, sobject: str +) -> T.Tuple[T.List[dict], None, T.Union[str, None]]: + """Processes the query results for the random selection strategy""" + + if not query_records: + error_message = f"No records found for {sobject} in the target org." 
+ return [], None, error_message + + selected_records = [] + for _ in range(num_records): # Loop 'num_records' times + # Randomly select one record from query_records + random_record = random.choice(query_records) + selected_records.append( + {"id": random_record[0], "success": True, "created": False} + ) + + return selected_records, None, None + + +def find_closest_record(load_record: list, query_records: list, weights: list): + closest_distance = float("inf") + closest_record = query_records[0] + + for record in query_records: + distance = calculate_levenshtein_distance(load_record, record[1:], weights) + if distance < closest_distance: + closest_distance = distance + closest_record = record + + return closest_record, closest_distance + + +def levenshtein_distance(str1: str, str2: str): + """Calculate the Levenshtein distance between two strings""" + len_str1 = len(str1) + 1 + len_str2 = len(str2) + 1 + + dp = [[0 for _ in range(len_str2)] for _ in range(len_str1)] + + for i in range(len_str1): + dp[i][0] = i + for j in range(len_str2): + dp[0][j] = j + + for i in range(1, len_str1): + for j in range(1, len_str2): + cost = 0 if str1[i - 1] == str2[j - 1] else 1 + dp[i][j] = min( + dp[i - 1][j] + 1, # Deletion + dp[i][j - 1] + 1, # Insertion + dp[i - 1][j - 1] + cost, + ) # Substitution + + return dp[-1][-1] + + +def calculate_levenshtein_distance(record1: list, record2: list, weights: list): + if len(record1) != len(record2): + raise ValueError("Records must have the same number of fields.") + elif len(record1) != len(weights): + raise ValueError("Records must be same size as fields (weights).") + + total_distance = 0 + + for field1, field2, weight in zip(record1, record2, weights): + field1 = field1.lower() + field2 = field2.lower() + + if len(field1) == 0 and len(field2) == 0: + # If both fields are blank, distance is 0 + distance = 0 + else: + # Average distance per character + distance = levenshtein_distance(field1, field2) / max( + len(field1), len(field2) + ) + if len(field1) == 0 or len(field2) == 0: + # If one field is blank, reduce the impact of the distance + distance = distance * 0.05 # Fixed value for blank vs non-blank + + # Multiply the distance by the corresponding weight + total_distance += distance * weight + + # Average distance per character with weights + return total_distance / sum(weights) if len(weights) else 0 + + +def add_limit_offset_to_user_filter( + filter_clause: str, + limit_clause: T.Union[float, None] = None, + offset_clause: T.Union[float, None] = None, +) -> str: + + # Extract existing LIMIT and OFFSET from filter_clause if present + existing_limit_match = re.search(r"LIMIT\s+(\d+)", filter_clause, re.IGNORECASE) + existing_offset_match = re.search(r"OFFSET\s+(\d+)", filter_clause, re.IGNORECASE) + + if existing_limit_match: + existing_limit = int(existing_limit_match.group(1)) + if limit_clause is not None: # Only apply limit_clause if it's provided + limit_clause = min(existing_limit, limit_clause) + else: + limit_clause = existing_limit + + if existing_offset_match: + existing_offset = int(existing_offset_match.group(1)) + if offset_clause is not None: + offset_clause = existing_offset + offset_clause + else: + offset_clause = existing_offset + + # Remove existing LIMIT and OFFSET from filter_clause, handling potential extra spaces + filter_clause = re.sub( + r"\s+OFFSET\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + filter_clause = re.sub( + r"\s+LIMIT\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + + if 
limit_clause is not None: + filter_clause += f" LIMIT {limit_clause}" + if offset_clause is not None: + filter_clause += f" OFFSET {offset_clause}" + + return f" {filter_clause}" + + +def determine_field_types(df, weights): + numerical_features = [] + boolean_features = [] + categorical_features = [] + + numerical_weights = [] + boolean_weights = [] + categorical_weights = [] + + for col, weight in zip(df.columns, weights): + # Check if the column can be converted to numeric + try: + # Attempt to convert to numeric + df[col] = pd.to_numeric(df[col], errors="raise") + numerical_features.append(col) + numerical_weights.append(weight) + except ValueError: + # Check for boolean values + if df[col].str.lower().isin(["true", "false"]).all(): + # Map to actual boolean values + df[col] = df[col].str.lower().map({"true": True, "false": False}) + boolean_features.append(col) + boolean_weights.append(weight) + else: + categorical_features.append(col) + categorical_weights.append(weight) + + return ( + numerical_features, + boolean_features, + categorical_features, + numerical_weights, + boolean_weights, + categorical_weights, + ) + + +def vectorize_records(db_records, query_records, hash_features, weights): + # Convert database records and query records to DataFrames + df_db = pd.DataFrame(db_records) + df_query = pd.DataFrame(query_records) + + # Determine field types and corresponding weights + # Modifies boolean columns to True or False + ( + numerical_features, + boolean_features, + categorical_features, + numerical_weights, + boolean_weights, + categorical_weights, + ) = determine_field_types(df_db, weights) + + # Modify query dataframe boolean columns to True or False + for col in df_query.columns: + if df_query[col].str.lower().isin(["true", "false"]).all(): + df_query[col] = ( + df_query[col].str.lower().map({"true": True, "false": False}) + ) + + # Fit StandardScaler on the numerical features of the database records + scaler = StandardScaler() + if numerical_features: + df_db[numerical_features] = scaler.fit_transform(df_db[numerical_features]) + df_query[numerical_features] = scaler.transform(df_query[numerical_features]) + + # Use HashingVectorizer to transform the categorical features + hashing_vectorizer = HashingVectorizer( + n_features=hash_features, alternate_sign=False + ) + + # For db_records + hashed_categorical_data_db = [] + for idx, col in enumerate(categorical_features): + hashed_db = hashing_vectorizer.fit_transform(df_db[col]).toarray() + # Apply weight to the hashed vector for this categorical feature + hashed_db_weighted = hashed_db * categorical_weights[idx] + hashed_categorical_data_db.append(hashed_db_weighted) + + # For query_records + hashed_categorical_data_query = [] + for idx, col in enumerate(categorical_features): + hashed_query = hashing_vectorizer.transform(df_query[col]).toarray() + # Apply weight to the hashed vector for this categorical feature + hashed_query_weighted = hashed_query * categorical_weights[idx] + hashed_categorical_data_query.append(hashed_query_weighted) + + # Combine all feature types into a single vector for the database records + db_vectors = [] + if numerical_features: + db_vectors.append(df_db[numerical_features].values * numerical_weights) + if boolean_features: + db_vectors.append(df_db[boolean_features].astype(int).values * boolean_weights) + if hashed_categorical_data_db: + db_vectors.append(np.hstack(hashed_categorical_data_db)) + + # Concatenate database vectors + final_db_vectors = np.hstack(db_vectors) + + # Combine all feature 
types into a single vector for the query records + query_vectors = [] + if numerical_features: + query_vectors.append(df_query[numerical_features].values * numerical_weights) + if boolean_features: + query_vectors.append( + df_query[boolean_features].astype(int).values * boolean_weights + ) + if hashed_categorical_data_query: + query_vectors.append(np.hstack(hashed_categorical_data_query)) + + # Concatenate query vectors + final_query_vectors = np.hstack(query_vectors) + + return final_db_vectors, final_query_vectors + + +def replace_empty_strings_with_missing(records): + return [ + [(field if field != "" else "missing") for field in record] + for record in records + ] + + +def split_and_filter_fields(fields: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]: + # List to store non-lookup fields (load fields) + load_fields = [] + + # Set to store unique first components of select fields + unique_components = set() + # Keep track of last flattened lookup index + last_flat_lookup_index = -1 + + # Iterate through the fields + for idx, field in enumerate(fields): + if "." in field: + # Split the field by '.' and add the first component to the set + first_component = field.split(".")[0] + unique_components.add(first_component) + last_flat_lookup_index = max(last_flat_lookup_index, idx) + else: + # Add the field to the load_fields list + load_fields.append(field) + + # Number of unique components + num_unique_components = len(unique_components) + + # Adjust select_fields by removing only the field at last_flat_lookup_index + 1 + if last_flat_lookup_index + 1 < len( + fields + ) and last_flat_lookup_index + num_unique_components < len(fields): + select_fields = ( + fields[: last_flat_lookup_index + 1] + + fields[last_flat_lookup_index + num_unique_components + 1 :] + ) + else: + select_fields = fields + + return load_fields, select_fields + + +# Function to reorder records based on the new field list +def reorder_records(records, original_fields, new_fields): + if not original_fields: + raise KeyError("original_fields should not be empty") + # Map the original field indices + field_index_map = {field: i for i, field in enumerate(original_fields)} + reordered_records = [] + + for record in records: + reordered_records.append( + [ + record[field_index_map[field]] + for field in new_fields + if field in field_index_map + ] + ) + + return reordered_records diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index edcb62afbb..b2a13bf966 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -7,7 +7,8 @@ import time from abc import ABCMeta, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, List, NamedTuple, Optional +from itertools import tee +from typing import Any, Dict, List, NamedTuple, Optional, Union import requests import salesforce_bulk @@ -15,13 +16,21 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg -from cumulusci.tasks.bulkdata.utils import iterate_in_chunks +from cumulusci.tasks.bulkdata.select_utils import ( + SelectOperationExecutor, + SelectRecordRetrievalMode, + SelectStrategy, + split_and_filter_fields, +) +from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict from cumulusci.utils.xml import lxml_parse_string DEFAULT_BULK_BATCH_SIZE = 10_000 DEFAULT_REST_BATCH_SIZE = 200 MAX_REST_BATCH_SIZE = 200 
+HIGH_PRIORITY_VALUE = 3 +LOW_PRIORITY_VALUE = 0.5 csv.field_size_limit(2**27) # 128 MB @@ -36,14 +45,7 @@ class DataOperationType(StrEnum): UPSERT = "upsert" ETL_UPSERT = "etl_upsert" SMART_UPSERT = "smart_upsert" # currently undocumented - - -class DataApi(StrEnum): - """Enum defining requested Salesforce data API for an operation.""" - - BULK = "bulk" - REST = "rest" - SMART = "smart" + SELECT = "select" class DataOperationStatus(StrEnum): @@ -320,6 +322,11 @@ def get_prev_record_values(self, records): """Get the previous records values in case of UPSERT and UPDATE to prepare for rollback""" pass + @abstractmethod + def select_records(self, records): + """Perform the requested DML operation on the supplied row iterator.""" + pass + @abstractmethod def load_records(self, records): """Perform the requested DML operation on the supplied row iterator.""" @@ -338,7 +345,20 @@ def get_results(self): class BulkApiDmlOperation(BaseDmlOperation, BulkJobMixin): """Operation class for all DML operations run using the Bulk API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.STANDARD, + selection_filter=None, + selection_priority_fields=None, + content_type=None, + threshold=None, + ): super().__init__( sobject=sobject, operation=operation, @@ -353,18 +373,27 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.csv_buff = io.StringIO(newline="") self.csv_writer = csv.writer(self.csv_buff, quoting=csv.QUOTE_ALL) + self.select_operation_executor = SelectOperationExecutor(selection_strategy) + self.selection_filter = selection_filter + self.weights = assign_weights( + priority_fields=selection_priority_fields, fields=fields + ) + self.content_type = content_type if content_type else "CSV" + self.threshold = threshold + def start(self): self.job_id = self.bulk.create_job( self.sobject, self.operation.value, - contentType="CSV", + contentType=self.content_type, concurrency=self.api_options.get("bulk_mode", "Parallel"), external_id_name=self.api_options.get("update_key"), ) def end(self): self.bulk.close_job(self.job_id) - self.job_result = self._wait_for_job(self.job_id) + if not self.job_result: + self.job_result = self._wait_for_job(self.job_id) def get_prev_record_values(self, records): """Get the previous values of the records based on the update key @@ -424,6 +453,161 @@ def load_records(self, records): self.context.logger.info(f"Uploading batch {count + 1}") self.batch_ids.append(self.bulk.post_batch(self.job_id, iter(csv_batch))) + def select_records(self, records): + """Executes a SOQL query to select records and adds them to results""" + + self.select_results = [] # Store selected records + query_records = [] + # Create a copy of the generator using tee + records, records_copy = tee(records) + # Count total number of records to fetch using the copy + total_num_records = sum(1 for _ in records_copy) + limit_clause = self._determine_limit_clause(total_num_records=total_num_records) + + # Generate and execute SOQL query + # (not passing offset as it is not supported in Bulk) + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + user_filter=self.selection_filter if self.selection_filter else None, + limit=limit_clause, + offset=None, + ) + + # Execute the main select query using Bulk API + select_query_records = 
self._execute_select_query( + select_query=select_query, query_fields=query_fields + ) + + query_records.extend(select_query_records) + # Post-process the query results + ( + selected_records, + insert_records, + error_message, + ) = self.select_operation_executor.select_post_process( + load_records=records, + query_records=query_records, + fields=self.fields, + num_records=total_num_records, + sobject=self.sobject, + weights=self.weights, + threshold=self.threshold, + ) + + # Log the number of selected and prepared for insertion records + num_selected = sum(1 for record in selected_records if record) + num_prepared = len(insert_records) if insert_records else 0 + + self.logger.info( + f"{num_selected} records selected." + + ( + f" {num_prepared} records prepared for insertion." + if num_prepared > 0 + else "" + ) + ) + + if insert_records: + self._process_insert_records(insert_records, selected_records) + + if not error_message: + self.select_results.extend(selected_records) + + # Update job result based on selection outcome + self.job_result = DataOperationJobResult( + status=( + DataOperationStatus.SUCCESS + if len(self.select_results) + else DataOperationStatus.JOB_FAILURE + ), + job_errors=[error_message] if error_message else [], + records_processed=len(self.select_results), + total_row_errors=0, + ) + + def _process_insert_records(self, insert_records, selected_records): + """Processes and inserts records if necessary.""" + insert_fields, _ = split_and_filter_fields(fields=self.fields) + insert_step = BulkApiDmlOperation( + sobject=self.sobject, + operation=DataOperationType.INSERT, + api_options=self.api_options, + context=self.context, + fields=insert_fields, + ) + insert_step.start() + insert_step.load_records(insert_records) + insert_step.end() + # Retrieve insert results + insert_results = [] + for batch_id in insert_step.batch_ids: + try: + results_url = f"{insert_step.bulk.endpoint}/job/{insert_step.job_id}/batch/{batch_id}/result" + # Download entire result file to a temporary file first + # to avoid the server dropping connections + with download_file(results_url, insert_step.bulk) as f: + self.logger.info(f"Downloaded results for batch {batch_id}") + reader = csv.reader(f) + next(reader) # Skip header row + for row in reader: + success = process_bool_arg(row[1]) + created = process_bool_arg(row[2]) + insert_results.append( + {"id": row[0], "success": success, "created": created} + ) + except Exception as e: + raise BulkDataException( + f"Failed to download results for batch {batch_id} ({str(e)})" + ) + + insert_index = 0 + for idx, record in enumerate(selected_records): + if record is None: + selected_records[idx] = insert_results[insert_index] + insert_index += 1 + + def _determine_limit_clause(self, total_num_records): + """Determines the LIMIT clause based on the retrieval mode.""" + if ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.ALL + ): + return None + elif ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.MATCH + ): + return total_num_records + + def _execute_select_query(self, select_query: str, query_fields: List[str]): + """Executes the select Bulk API query, retrieves results in JSON, and converts to CSV format if needed.""" + self.batch_id = self.bulk.query(self.job_id, select_query) + self.bulk.wait_for_batch(self.job_id, self.batch_id) + result_ids = self.bulk.get_query_batch_result_ids( + self.batch_id, job_id=self.job_id + ) + select_query_records = [] + + for result_id in result_ids: + # 
Modify URI to request JSON format + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}?format=json" + # Download JSON data + with download_file(uri, self.bulk) as f: + data = json.load(f) + # Get headers from fields, expanding nested structures for TYPEOF results + self.headers = query_fields + + # Convert each record to a flat row + for record in data: + flat_record = flatten_record(record, self.headers) + select_query_records.append(flat_record) + + return select_query_records + def _batch(self, records, n, char_limit=10000000): """Given an iterator of records, yields batches of records serialized in .csv format. @@ -472,6 +656,29 @@ def _serialize_csv_record(self, record): return serialized def get_results(self): + """ + Retrieves and processes the results of a Bulk API operation. + """ + + if self.operation is DataOperationType.QUERY: + yield from self._get_query_results() + else: + yield from self._get_batch_results() + + def _get_query_results(self): + """Handles results for QUERY (select) operations""" + for row in self.select_results: + success = process_bool_arg(row["success"]) + created = process_bool_arg(row["created"]) + yield DataOperationResult( + row["id"] if success else "", + success, + "", + created, + ) + + def _get_batch_results(self): + """Handles results for other DataOperationTypes (insert, update, etc.)""" for batch_id in self.batch_ids: try: results_url = ( @@ -481,29 +688,46 @@ def get_results(self): # to avoid the server dropping connections with download_file(results_url, self.bulk) as f: self.logger.info(f"Downloaded results for batch {batch_id}") + yield from self._parse_batch_results(f) - reader = csv.reader(f) - next(reader) # skip header - - for row in reader: - success = process_bool_arg(row[1]) - created = process_bool_arg(row[2]) - yield DataOperationResult( - row[0] if success else None, - success, - row[3] if not success else None, - created, - ) except Exception as e: raise BulkDataException( f"Failed to download results for batch {batch_id} ({str(e)})" ) + def _parse_batch_results(self, f): + """Parses batch results from the downloaded file""" + reader = csv.reader(f) + next(reader) # Skip header row + + for row in reader: + success = process_bool_arg(row[1]) + created = process_bool_arg(row[2]) + yield DataOperationResult( + row[0] if success else None, + success, + row[3] if not success else None, + created, + ) + class RestApiDmlOperation(BaseDmlOperation): """Operation class for all DML operations run using the REST API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.STANDARD, + selection_filter=None, + selection_priority_fields=None, + content_type=None, + threshold=None, + ): super().__init__( sobject=sobject, operation=operation, @@ -517,7 +741,9 @@ def __init__(self, *, sobject, operation, api_options, context, fields): field["name"]: field for field in getattr(context.sf, sobject).describe()["fields"] } - self.boolean_fields = [f for f in fields if describe[f]["type"] == "boolean"] + self.boolean_fields = [ + f for f in fields if "." 
not in f and describe[f]["type"] == "boolean" + ] self.api_options = api_options.copy() self.api_options["batch_size"] = ( self.api_options.get("batch_size") or DEFAULT_REST_BATCH_SIZE @@ -526,6 +752,14 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.api_options["batch_size"], MAX_REST_BATCH_SIZE ) + self.select_operation_executor = SelectOperationExecutor(selection_strategy) + self.selection_filter = selection_filter + self.weights = assign_weights( + priority_fields=selection_priority_fields, fields=fields + ) + self.content_type = content_type + self.threshold = threshold + def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) for boolean_field in self.boolean_fields: @@ -623,14 +857,151 @@ def load_records(self, records): row_errors = len([res for res in self.results if not res["success"]]) self.job_result = DataOperationJobResult( - DataOperationStatus.SUCCESS - if not row_errors - else DataOperationStatus.ROW_FAILURE, + ( + DataOperationStatus.SUCCESS + if not row_errors + else DataOperationStatus.ROW_FAILURE + ), [], len(self.results), row_errors, ) + def select_records(self, records): + """Executes a SOQL query to select records and adds them to results""" + + self.results = [] + query_records = [] + + # Create a copy of the generator using tee + records, records_copy = tee(records) + + # Count total number of records to fetch using the copy + total_num_records = sum(1 for _ in records_copy) + + # Set LIMIT condition + limit_clause = self._determine_limit_clause(total_num_records) + + # Generate the SOQL query based on the selection strategy + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + user_filter=self.selection_filter or None, + limit=limit_clause, + offset=None, + ) + + # Execute the query and gather the records + query_records = self._execute_soql_query(select_query, query_fields) + + # Post-process the query results for this batch + ( + selected_records, + insert_records, + error_message, + ) = self.select_operation_executor.select_post_process( + load_records=records, + query_records=query_records, + fields=self.fields, + num_records=total_num_records, + sobject=self.sobject, + weights=self.weights, + threshold=self.threshold, + ) + + # Log the number of selected and prepared for insertion records + num_selected = sum(1 for record in selected_records if record) + num_prepared = len(insert_records) if insert_records else 0 + + self.logger.info( + f"{num_selected} records selected." + + ( + f" {num_prepared} records prepared for insertion." 
+ if num_prepared > 0 + else "" + ) + ) + + if insert_records: + self._process_insert_records(insert_records, selected_records) + + if not error_message: + # Add selected records from this batch to the overall results + self.results.extend(selected_records) + + # Update the job result based on the overall selection outcome + self._update_job_result(error_message) + + def _determine_limit_clause(self, total_num_records): + """Determines the LIMIT clause based on the retrieval mode.""" + if ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.ALL + ): + return None + elif ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.MATCH + ): + return total_num_records + + def _execute_soql_query(self, select_query, query_fields): + """Executes the SOQL query and returns the flattened records.""" + query_records = [] + response = self.sf.restful( + requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" + ) + query_records.extend(self._flatten_response_records(response, query_fields)) + + while not response["done"]: + response = self.sf.query_more( + response["nextRecordsUrl"], identifier_is_url=True + ) + query_records.extend(self._flatten_response_records(response, query_fields)) + + return query_records + + def _flatten_response_records(self, response, query_fields): + """Flattens the response records and returns them as a list.""" + return [flatten_record(record, query_fields) for record in response["records"]] + + def _process_insert_records(self, insert_records, selected_records): + """Processes and inserts records if necessary.""" + insert_fields, _ = split_and_filter_fields(fields=self.fields) + insert_step = RestApiDmlOperation( + sobject=self.sobject, + operation=DataOperationType.INSERT, + api_options=self.api_options, + context=self.context, + fields=insert_fields, + ) + insert_step.start() + insert_step.load_records(insert_records) + insert_step.end() + insert_results = insert_step.results + + insert_index = 0 + for idx, record in enumerate(selected_records): + if record is None: + selected_records[idx] = insert_results[insert_index] + insert_index += 1 + + def _update_job_result(self, error_message): + """Updates the job result based on the selection outcome.""" + self.job_result = DataOperationJobResult( + status=( + DataOperationStatus.SUCCESS + if len(self.results) + else DataOperationStatus.JOB_FAILURE + ), + job_errors=[error_message] if error_message else [], + records_processed=len(self.results), + total_row_errors=0, + ) + def get_results(self): """Return a generator of DataOperationResult objects.""" @@ -712,6 +1083,11 @@ def get_dml_operation( context: Any, volume: int, api: Optional[DataApi] = DataApi.SMART, + selection_strategy: SelectStrategy = SelectStrategy.STANDARD, + selection_filter: Union[str, None] = None, + selection_priority_fields: Union[dict, None] = None, + content_type: Union[str, None] = None, + threshold: Union[float, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting between REST and Bulk APIs based upon volume (Bulk used at volumes over 2000 records, @@ -745,4 +1121,96 @@ def get_dml_operation( api_options=api_options, context=context, fields=fields, + selection_strategy=selection_strategy, + selection_filter=selection_filter, + selection_priority_fields=selection_priority_fields, + content_type=content_type, + threshold=threshold, ) + + +def extract_flattened_headers(query_fields): + """Extract headers from query 
fields, including handling of TYPEOF fields.""" + headers = [] + + for field in query_fields: + if isinstance(field, dict): + # Handle TYPEOF / polymorphic fields + for lookup, references in field.items(): + # Assuming each reference is a list of dictionaries + for ref_type in references: + for ref_obj, ref_fields in ref_type.items(): + for nested_field in ref_fields: + headers.append( + f"{lookup}.{ref_obj}.{nested_field}" + ) # Flatten the structure + else: + # Regular fields + headers.append(field) + + return headers + + +def flatten_record(record, headers): + """Flatten each record to match headers, handling nested fields.""" + flat_record = [] + + for field in headers: + components = field.split(".") + value = "" + + # Handle lookup fields with two or three components + if len(components) >= 2: + lookup_field = components[0] + lookup = record.get(lookup_field, None) + + # Check if lookup field exists in the record + if lookup is None: + value = "" + else: + if len(components) == 2: + # Handle fields with two components: {lookup}.{ref_field} + ref_field = components[1] + value = lookup.get(ref_field, "") + elif len(components) == 3: + # Handle fields with three components: {lookup}.{ref_obj}.{ref_field} + ref_obj, ref_field = components[1], components[2] + # Check if the type matches the specified ref_obj + if lookup.get("attributes", {}).get("type") == ref_obj: + value = lookup.get(ref_field, "") + else: + value = "" + + else: + # Regular fields or non-polymorphic fields + value = record.get(field, "") + + # Set None values to empty string + if value is None: + value = "" + elif not isinstance(value, str): + value = str(value) + + # Append the resolved value to the flattened record + flat_record.append(value) + + return flat_record + + +def assign_weights( + priority_fields: Union[Dict[str, str], None], fields: List[str] +) -> list: + # If priority_fields is None or an empty dictionary, set all weights to 1 + if not priority_fields: + return [1] * len(fields) + + # Initialize the weight list with LOW_PRIORITY_VALUE + weights = [LOW_PRIORITY_VALUE] * len(fields) + + # Iterate over the fields and assign weights based on priority_fields + for i, field in enumerate(fields): + if field in priority_fields: + # Set weight to HIGH_PRIORITY_VALUE if field is in priority_fields + weights[i] = HIGH_PRIORITY_VALUE + + return weights diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select.yml b/cumulusci/tasks/bulkdata/tests/mapping_select.yml new file mode 100644 index 0000000000..e549d7a474 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select.yml @@ -0,0 +1,20 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml new file mode 100644 index 0000000000..6ab196fda6 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml @@ -0,0 +1,20 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: invalid_strategy + filter: WHEN Name in ('Sample Account') + 
priority_fields: + Name: name + AccountNumber: account_number + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml new file mode 100644 index 0000000000..1bad614b1d --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: 1.5 + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml new file mode 100644 index 0000000000..71958848c5 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: standard + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: 0.5 + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml new file mode 100644 index 0000000000..2ff1482f3d --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: invalid threshold + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml new file mode 100644 index 0000000000..34011945ad --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml @@ -0,0 +1,22 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + - Name + - AccountNumber + - ParentId + - Email + fields: + - Name + - AccountNumber + - Description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml new file mode 100644 index 0000000000..1559848b48 --- /dev/null +++ 
b/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml @@ -0,0 +1,18 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + fields: + - Name + - AccountNumber + - Description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/test_load.py b/cumulusci/tasks/bulkdata/tests/test_load.py index 6649ff202e..8fb8ee0756 100644 --- a/cumulusci/tasks/bulkdata/tests/test_load.py +++ b/cumulusci/tasks/bulkdata/tests/test_load.py @@ -806,6 +806,111 @@ def test_stream_queried_data__skips_empty_rows(self): ["001000000006", "001000000008"], ] == records + def test_process_lookup_fields_polymorphic(self): + task = _make_task( + LoadData, + { + "options": { + "sql_path": Path(__file__).parent + / "test_query_db_joins_lookups.sql", + "mapping": Path(__file__).parent + / "test_query_db_joins_lookups_select.yml", + } + }, + ) + polymorphic_fields = { + "WhoId": { + "name": "WhoId", + "referenceTo": ["Contact", "Lead"], + "relationshipName": "Who", + }, + "WhatId": { + "name": "WhatId", + "referenceTo": ["Account"], + "relationshipName": "What", + }, + } + + expected_fields = [ + "Subject", + "Who.Contact.FirstName", + "Who.Contact.LastName", + "Who.Lead.LastName", + "WhoId", + ] + expected_priority_fields_keys = { + "Who.Contact.FirstName", + "Who.Contact.LastName", + "Who.Lead.LastName", + } + with mock.patch( + "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping" + ), mock.patch.object(task, "sf", create=True): + task._init_mapping() + with task._init_db(): + task._old_format = mock.Mock(return_value=False) + mapping = task.mapping["Select Event"] + fields = mapping.get_load_field_list() + task.process_lookup_fields( + mapping=mapping, fields=fields, polymorphic_fields=polymorphic_fields + ) + assert fields == expected_fields + assert ( + set(mapping.select_options.priority_fields.keys()) + == expected_priority_fields_keys + ) + + def test_process_lookup_fields_non_polymorphic(self): + task = _make_task( + LoadData, + { + "options": { + "sql_path": Path(__file__).parent + / "test_query_db_joins_lookups.sql", + "mapping": Path(__file__).parent + / "test_query_db_joins_lookups_select.yml", + } + }, + ) + non_polymorphic_fields = { + "AccountId": { + "name": "AccountId", + "referenceTo": ["Account"], + "relationshipName": "Account", + } + } + + expected_fields = [ + "FirstName", + "LastName", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ] + expected_priority_fields_keys = { + "FirstName", + "Account.Name", + "Account.AccountNumber", + } + with mock.patch( + "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping" + ), mock.patch.object(task, "sf", create=True): + task._init_mapping() + with task._init_db(): + task._old_format = mock.Mock(return_value=False) + mapping = task.mapping["Select Contact"] + fields = mapping.get_load_field_list() + task.process_lookup_fields( + mapping=mapping, + fields=fields, + polymorphic_fields=non_polymorphic_fields, + ) + assert fields == expected_fields + assert ( + set(mapping.select_options.priority_fields.keys()) + == expected_priority_fields_keys + ) + @responses.activate def test_stream_queried_data__adjusts_relative_dates(self): mock_describe_calls() @@ -878,6 +983,15 @@ def test_query_db__joins_self_lookups(self): old_format=True, ) + def test_query_db__joins_select_lookups(self): + """SQL File in New 
Format (Select)""" + _validate_query_for_mapping_step( + sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql", + mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml", + mapping_step_name="Select Event", + expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''', + ) + def test_query_db__joins_polymorphic_lookups(self): """SQL File in New Format (Polymorphic)""" _validate_query_for_mapping_step( diff --git a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py index c1419f300b..8ce38ff5a8 100644 --- a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py +++ b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py @@ -17,6 +17,7 @@ parse_from_yaml, validate_and_inject_mapping, ) +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType from cumulusci.tests.util import DummyOrgConfig, mock_describe_calls @@ -213,6 +214,70 @@ def test_get_relative_date_e2e(self): date.today(), ) + def test_select_options__success(self): + base_path = Path(__file__).parent / "mapping_select.yml" + result = parse_from_yaml(base_path) + + step = result["Select Accounts"] + select_options = step.select_options + assert select_options + assert select_options.strategy == SelectStrategy.SIMILARITY + assert select_options.filter == "WHEN Name in ('Sample Account')" + assert select_options.priority_fields + + def test_select_options__invalid_strategy(self): + base_path = Path(__file__).parent / "mapping_select_invalid_strategy.yml" + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "Invalid strategy value: invalid_strategy" in str(e.value) + + def test_select_options__invalid_threshold__non_float(self): + base_path = ( + Path(__file__).parent / "mapping_select_invalid_threshold__non_float.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "value is not a valid float" in str(e.value) + + def test_select_options__invalid_threshold__invalid_strategy(self): + base_path = ( + Path(__file__).parent + / "mapping_select_invalid_threshold__invalid_strategy.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert ( + "If a threshold is specified, the strategy must be set to 'similarity'." 
+ in str(e.value) + ) + + def test_select_options__invalid_threshold__invalid_number(self): + base_path = ( + Path(__file__).parent + / "mapping_select_invalid_threshold__invalid_number.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "Threshold must be between 0 and 1, got 1.5" in str(e.value) + + def test_select_options__missing_priority_fields(self): + base_path = Path(__file__).parent / "mapping_select_missing_priority_fields.yml" + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + print(str(e.value)) + assert ( + "Priority fields {'Email'} are not present in 'fields' or 'lookups'" + in str(e.value) + ) + + def test_select_options__no_priority_fields(self): + base_path = Path(__file__).parent / "mapping_select_no_priority_fields.yml" + result = parse_from_yaml(base_path) + + step = result["Select Accounts"] + select_options = step.select_options + assert select_options.priority_fields == {} + # Start of FLS/Namespace Injection Unit Tests def test_is_injectable(self): diff --git a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql index 113e5cebe5..ed7f0e694a 100644 --- a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql +++ b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql @@ -1,13 +1,23 @@ BEGIN TRANSACTION; +CREATE TABLE "accounts" ( + id VARCHAR(255) NOT NULL, + "Name" VARCHAR(255), + "AccountNumber" VARCHAR(255), + PRIMARY KEY (id) +); +INSERT INTO "accounts" VALUES("Account-1",'Bluth Company','123456'); +INSERT INTO "accounts" VALUES("Account-2",'Sampson PLC','567890'); + CREATE TABLE "contacts" ( id VARCHAR(255) NOT NULL, "FirstName" VARCHAR(255), - "LastName" VARCHAR(255), + "LastName" VARCHAR(255), + "AccountId" VARCHAR(255), PRIMARY KEY (id) ); -INSERT INTO "contacts" VALUES("Contact-1",'Alpha','gamma'); -INSERT INTO "contacts" VALUES("Contact-2",'Temp','Bluth'); +INSERT INTO "contacts" VALUES("Contact-1",'Alpha','gamma', 'Account-2'); +INSERT INTO "contacts" VALUES("Contact-2",'Temp','Bluth', 'Account-1'); CREATE TABLE "events" ( id VARCHAR(255) NOT NULL, diff --git a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml new file mode 100644 index 0000000000..4b37f491eb --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml @@ -0,0 +1,48 @@ +Insert Account: + sf_object: Account + table: accounts + api: rest + fields: + - Name + - AccountNumber + +Insert Lead: + sf_object: Lead + table: leads + api: bulk + fields: + - LastName + +Select Contact: + sf_object: Contact + table: contacts + api: bulk + action: select + select_options: + strategy: similarity + priority_fields: + - FirstName + - AccountId + fields: + - FirstName + - LastName + lookups: + AccountId: + table: accounts + +Select Event: + sf_object: Event + table: events + api: rest + action: select + select_options: + strategy: similarity + priority_fields: + - WhoId + fields: + - Subject + lookups: + WhoId: + table: + - contacts + - leads diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py new file mode 100644 index 0000000000..a0b5a3fcad --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -0,0 +1,1006 @@ +import pandas as pd +import pytest + +from cumulusci.tasks.bulkdata.select_utils import ( + SelectOperationExecutor, + SelectStrategy, + 
add_limit_offset_to_user_filter, + annoy_post_process, + calculate_levenshtein_distance, + determine_field_types, + find_closest_record, + levenshtein_distance, + reorder_records, + replace_empty_strings_with_missing, + split_and_filter_fields, + vectorize_records, +) + + +# Test Cases for standard_generate_query +def test_standard_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + limit = 5 + offset = 2 + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset + ) + + assert "WHERE" in query # Ensure WHERE clause is included + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query + assert fields == ["Id"] + + +def test_standard_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query + assert fields == ["Id"] + + +def test_standard_generate_query_with_user_filter(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + user_filter = "WHERE Name IN ('Sample Contact')" + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter=user_filter, limit=limit, offset=offset + ) + + assert "WHERE" in query + assert "Sample Contact" in query + assert "LIMIT" in query + assert "OFFSET" not in query + assert fields == ["Id"] + + +# Test Cases for random generate query +def test_random_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + limit = 5 + offset = 2 + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset + ) + + assert "WHERE" in query # Ensure WHERE clause is included + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query + assert fields == ["Id"] + + +def test_random_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query + assert fields == ["Id"] + + +# Test Cases for standard_post_process +def test_standard_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + records = [["001"], ["002"], ["003"]] + num_records = 3 + sobject = "Contact" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, + ) + + assert error_message is None + assert 
len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + assert all(record["id"] in ["001", "002", "003"] for record in selected_records) + + +def test_standard_post_process_with_fewer_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + records = [["001"]] + num_records = 3 + sobject = "Opportunity" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, + ) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + # Check if records are repeated to match num_records + assert selected_records.count({"id": "001", "success": True, "created": False}) == 3 + + +def test_standard_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + records = [] + num_records = 2 + sobject = "Lead" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, + ) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." + + +# Test cases for Random Post Process +def test_random_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) + records = [["001"], ["002"], ["003"]] + num_records = 3 + sobject = "Contact" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, + ) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + + +def test_random_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) + records = [] + num_records = 2 + sobject = "Lead" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, + ) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." 
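The STANDARD and RANDOM post-process tests above pin down the contract of SelectOperationExecutor.select_post_process: it returns a list of {"id", "success", "created"} dicts (reusing matches when the org has fewer records than requested), a list of records to insert, and an optional error message. A minimal sketch of that contract, hedged: the call signature and expected output below are taken from the tests in this patch, not from select_utils.py itself.

from cumulusci.tasks.bulkdata.select_utils import (
    SelectOperationExecutor,
    SelectStrategy,
)

executor = SelectOperationExecutor(SelectStrategy.STANDARD)
selected, inserts, err = executor.select_post_process(
    load_records=None,
    query_records=[["001"]],   # only one matching record found in the org
    num_records=3,             # three load records need target IDs
    sobject="Opportunity",
    weights=[],
    fields=[],
    threshold=None,
)
# Per test_standard_post_process_with_fewer_records, the single match is reused:
# selected == [{"id": "001", "success": True, "created": False}] * 3, err is None.
# With an empty query_records list, selected == [] and err becomes
# "No records found for Opportunity in the target org." instead.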
+ + +# Test Cases for Similarity Generate Query +def test_similarity_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + limit = 5 + offset = 2 + query, fields = select_operator.select_generate_query( + sobject, ["Name"], [], limit, offset + ) + + assert "WHERE" in query # Ensure WHERE clause is included + assert fields == ["Id", "Name"] + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query + + +def test_similarity_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + query, fields = select_operator.select_generate_query( + sobject, ["Name"], [], limit, offset + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert fields == ["Id", "Name"] + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query + + +def test_similarity_generate_query_with_nested_fields(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Event" # Assuming no declaration for this object + limit = 3 + offset = None + fields = [ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ] + query, query_fields = select_operator.select_generate_query( + sobject, fields, [], limit, offset + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert query_fields == [ + "Id", + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ] + assert f"LIMIT {limit}" in query + assert "TYPEOF Who" in query + assert "WHEN Contact" in query + assert "WHEN Lead" in query + assert "OFFSET" not in query + + +def test_random_generate_query_with_user_filter(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + user_filter = "WHERE Name IN ('Sample Contact')" + query, fields = select_operator.select_generate_query( + sobject=sobject, + fields=["Name"], + user_filter=user_filter, + limit=limit, + offset=offset, + ) + + assert "WHERE" in query + assert "Sample Contact" in query + assert "LIMIT" in query + assert "OFFSET" not in query + assert fields == ["Id", "Name"] + + +def test_levenshtein_distance(): + assert levenshtein_distance("kitten", "kitten") == 0 # Identical strings + assert levenshtein_distance("kitten", "sitten") == 1 # One substitution + assert levenshtein_distance("kitten", "kitte") == 1 # One deletion + assert levenshtein_distance("kitten", "sittin") == 2 # Two substitutions + assert levenshtein_distance("kitten", "dog") == 6 # Completely different strings + assert levenshtein_distance("kitten", "") == 6 # One string is empty + assert levenshtein_distance("", "") == 0 # Both strings are empty + assert levenshtein_distance("Kitten", "kitten") == 1 # Case sensitivity + assert levenshtein_distance("kit ten", "kitten") == 1 # Strings with spaces + assert ( + levenshtein_distance("levenshtein", "meilenstein") == 4 + ) # Longer strings with multiple differences + + +def test_find_closest_record_different_weights(): + load_record = ["hello", "world"] + query_records = [ + ["record1", "hello", "word"], # Levenshtein distance = 1 + ["record2", "hullo", "word"], # Levenshtein distance = 1 + ["record3", "hello", "word"], # 
Levenshtein distance = 1 + ] + weights = [2.0, 0.5] + + # With different weights, the first field will have more impact + closest_record, _ = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record1", + "hello", + "word", + ], "The closest record should be 'record1'." + + +def test_find_closest_record_basic(): + load_record = ["hello", "world"] + query_records = [ + ["record1", "hello", "word"], # Levenshtein distance = 1 + ["record2", "hullo", "word"], # Levenshtein distance = 1 + ["record3", "hello", "word"], # Levenshtein distance = 1 + ] + weights = [1.0, 1.0] + + closest_record, _ = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record1", + "hello", + "word", + ], "The closest record should be 'record1'." + + +def test_find_closest_record_multiple_matches(): + load_record = ["cat", "dog"] + query_records = [ + ["record1", "bat", "dog"], # Levenshtein distance = 1 + ["record2", "cat", "dog"], # Levenshtein distance = 0 + ["record3", "dog", "cat"], # Levenshtein distance = 3 + ] + weights = [1.0, 1.0] + + closest_record, _ = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record2", + "cat", + "dog", + ], "The closest record should be 'record2'." + + +def test_similarity_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + num_records = 1 + sobject = "Contact" + load_records = [["Tom Cruise", "62", "Actor"]] + query_records = [ + ["001", "Bob Hanks", "62", "Actor"], + ["002", "Tom Cruise", "63", "Actor"], # Slight difference + ["003", "Jennifer Aniston", "30", "Actress"], + ] + + weights = [1.0, 1.0, 1.0] # Adjust weights to match your data structure + + selected_records, _, error_message = select_operator.select_post_process( + load_records=load_records, + query_records=query_records, + num_records=num_records, + sobject=sobject, + weights=weights, + fields=["Name", "Age", "Occupation"], + threshold=None, + ) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + x = [record["id"] for record in selected_records] + print(x) + assert all(record["id"] in ["002"] for record in selected_records) + + +def test_similarity_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + records = [] + num_records = 2 + sobject = "Lead" + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[1, 1, 1], + fields=[], + threshold=None, + ) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." + + +def test_calculate_levenshtein_distance_basic(): + record1 = ["hello", "world"] + record2 = ["hullo", "word"] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "word") = 1 + expected_distance = (1 / 5 * 1.0 + 1 / 5 * 1.0) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation failed." 
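    # Note on the arithmetic above (inferred from the expected values in this
    # test, not from the implementation): each field's raw Levenshtein distance
    # appears to be normalized by the longer string's length, scaled by that
    # field's weight, and the weighted sum divided by the total weight. For
    # ("hello", "hullo") with weight 1.0 that is 1/5 * 1.0, averaged with the
    # second field over a total weight of 2.0; the weighted test further below
    # divides by 2.5, the sum of weights [2.0, 0.5].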
+ + # Empty fields + record1 = ["hello", ""] + record2 = ["hullo", ""] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("", "") = 0 + expected_distance = (1 / 5 * 1.0 + 0 * 1.0) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation with empty fields failed." + + # Partial empty fields + record1 = ["hello", "world"] + record2 = ["hullo", ""] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "") = 5 + expected_distance = ( + 1 / 5 * 1.0 + 5 / 5 * 0.05 * 1.0 + ) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation with partial empty fields failed." + + +def test_calculate_levenshtein_distance_weighted(): + record1 = ["cat", "dog"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5] + + # Levenshtein("cat", "bat") = 1, Levenshtein("dog", "fog") = 1 + expected_distance = ( + 1 / 3 * 2.0 + 1 / 3 * 0.5 + ) / 2.5 # Weighted average over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Weighted distance calculation failed." + + +def test_calculate_levenshtein_distance_records_length_doesnt_match(): + record1 = ["cat", "dog", "cow"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5] + + with pytest.raises(ValueError) as e: + calculate_levenshtein_distance(record1, record2, weights) + assert "Records must have the same number of fields." in str(e.value) + + +def test_calculate_levenshtein_distance_weights_length_doesnt_match(): + record1 = ["cat", "dog"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5, 3.0] + + with pytest.raises(ValueError) as e: + calculate_levenshtein_distance(record1, record2, weights) + assert "Records must be same size as fields (weights)." 
in str(e.value) + + +def test_replace_empty_strings_with_missing(): + # Case 1: Normal case with some empty strings + records = [ + ["Alice", "", "New York"], + ["Bob", "Engineer", ""], + ["", "Teacher", "Chicago"], + ] + expected = [ + ["Alice", "missing", "New York"], + ["Bob", "Engineer", "missing"], + ["missing", "Teacher", "Chicago"], + ] + assert replace_empty_strings_with_missing(records) == expected + + # Case 2: No empty strings, so the output should be the same as input + records = [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]] + expected = [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]] + assert replace_empty_strings_with_missing(records) == expected + + # Case 3: List with all empty strings + records = [["", "", ""], ["", "", ""]] + expected = [["missing", "missing", "missing"], ["missing", "missing", "missing"]] + assert replace_empty_strings_with_missing(records) == expected + + # Case 4: Empty list (should return an empty list) + records = [] + expected = [] + assert replace_empty_strings_with_missing(records) == expected + + # Case 5: List with some empty sublists + records = [[], ["Alice", ""], []] + expected = [[], ["Alice", "missing"], []] + assert replace_empty_strings_with_missing(records) == expected + + +def test_all_numeric_columns(): + df = pd.DataFrame({"A": [1, 2, 3], "B": [4.5, 5.5, 6.5]}) + weights = [0.1, 0.2] + expected_output = ( + ["A", "B"], # numerical_features + [], # boolean_features + [], # categorical_features + [0.1, 0.2], # numerical_weights + [], # boolean_weights + [], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_all_boolean_columns(): + df = pd.DataFrame({"A": ["true", "false", "true"], "B": ["false", "true", "false"]}) + weights = [0.3, 0.4] + expected_output = ( + [], # numerical_features + ["A", "B"], # boolean_features + [], # categorical_features + [], # numerical_weights + [0.3, 0.4], # boolean_weights + [], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_all_categorical_columns(): + df = pd.DataFrame( + {"A": ["apple", "banana", "cherry"], "B": ["dog", "cat", "mouse"]} + ) + weights = [0.5, 0.6] + expected_output = ( + [], # numerical_features + [], # boolean_features + ["A", "B"], # categorical_features + [], # numerical_weights + [], # boolean_weights + [0.5, 0.6], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_mixed_types(): + df = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["true", "false", "true"], + "C": ["apple", "banana", "cherry"], + } + ) + weights = [0.7, 0.8, 0.9] + expected_output = ( + ["A"], # numerical_features + ["B"], # boolean_features + ["C"], # categorical_features + [0.7], # numerical_weights + [0.8], # boolean_weights + [0.9], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_vectorize_records_mixed_numerical_boolean_categorical(): + # Test data with mixed types: numerical and categorical only + db_records = [["1.0", "true", "apple"], ["2.0", "false", "banana"]] + query_records = [["1.5", "true", "apple"], ["2.5", "false", "cherry"]] + weights = [1.0, 1.0, 1.0] # Equal weights for numerical and categorical columns + hash_features = 4 # Number of hashing vectorizer features for categorical columns + + final_db_vectors, final_query_vectors = vectorize_records( + db_records, query_records, hash_features, weights + ) + + # Check the shape of the output 
vectors + assert final_db_vectors.shape[0] == len(db_records), "DB vectors row count mismatch" + assert final_query_vectors.shape[0] == len( + query_records + ), "Query vectors row count mismatch" + + # Expected dimensions: numerical (1) + categorical hashed features (4) + expected_feature_count = 2 + hash_features + assert ( + final_db_vectors.shape[1] == expected_feature_count + ), "DB vectors column count mismatch" + assert ( + final_query_vectors.shape[1] == expected_feature_count + ), "Query vectors column count mismatch" + + +def test_annoy_post_process(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]] + weights = [1.0, 1.0, 1.0] # Example weights + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=None, + ) + + # Assert the closest records + assert ( + len(closest_records) == 2 + ) # We expect two results (one for each query record) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + + # No errors expected + assert not insert_records + + +def test_annoy_post_process__insert_records(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]] + weights = [1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (one record and one None) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + assert closest_records[1] is None # The second query record should be None + assert insert_records[0] == [ + "Bob", + "Doctor", + ] # The first insert record should match the second load record + + +def test_annoy_post_process__no_query_records(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [] + weights = [1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (both None) + assert all(rec is None for rec in closest_records) # Both should be None + assert insert_records[0] == [ + "Alice", + "Engineer", + ] # The first insert record should match the second load record + assert insert_records[1] == [ + "Bob", + "Doctor", + ] # The first insert record should match the second load record + + +def test_annoy_post_process__insert_records_with_polymorphic_fields(): + # Test data + load_records = [ + ["Alice", "Engineer", "Alice_Contact", "abcd1234"], + ["Bob", "Doctor", "Bob_Contact", "qwer1234"], + ] + query_records = [ + ["q1", "Alice", "Engineer", "Alice_Contact"], + ["q2", "Charlie", "Artist", "Charlie_Contact"], + ] + weights = [1.0, 1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + all_fields = ["Name", "Occupation", "Contact.Name", "ContactId"] + + closest_records, insert_records = annoy_post_process( + 
load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=all_fields, + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (one record and one None) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + assert closest_records[1] is None # The second query record should be None + assert insert_records[0] == [ + "Bob", + "Doctor", + "qwer1234", + ] # The first insert record should match the second load record + + +def test_single_record_match_annoy_post_process(): + # Mock data where only the first query record matches the first load record + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"]] + weights = [1.0, 1.0, 1.0] + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=None, + ) + + # Both the load records should be matched with the only query record we have + assert len(closest_records) == 2 + assert closest_records[0]["id"] == "q1" + assert not insert_records + + +@pytest.mark.parametrize( + "filter_clause, limit_clause, offset_clause, expected", + [ + # Test: No existing LIMIT/OFFSET and no new clauses + ("SELECT * FROM users", None, None, " SELECT * FROM users"), + # Test: Existing LIMIT and no new limit provided + ("SELECT * FROM users LIMIT 100", None, None, "SELECT * FROM users LIMIT 100"), + # Test: Existing OFFSET and no new offset provided + ("SELECT * FROM users OFFSET 20", None, None, "SELECT * FROM users OFFSET 20"), + # Test: Existing LIMIT/OFFSET and new clauses provided + ( + "SELECT * FROM users LIMIT 100 OFFSET 20", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + # Test: Existing LIMIT, new limit larger than existing (should keep the smaller one) + ("SELECT * FROM users LIMIT 100", 150, None, "SELECT * FROM users LIMIT 100"), + # Test: New limit smaller than existing (should use the new one) + ("SELECT * FROM users LIMIT 100", 50, None, "SELECT * FROM users LIMIT 50"), + # Test: Existing OFFSET, adding a new offset (should sum the offsets) + ("SELECT * FROM users OFFSET 20", None, 30, "SELECT * FROM users OFFSET 50"), + # Test: Existing LIMIT/OFFSET and new values set to None + ( + "SELECT * FROM users LIMIT 100 OFFSET 20", + None, + None, + "SELECT * FROM users LIMIT 100 OFFSET 20", + ), + # Test: Removing existing LIMIT and adding a new one + ("SELECT * FROM users LIMIT 200", 50, None, "SELECT * FROM users LIMIT 50"), + # Test: Removing existing OFFSET and adding a new one + ("SELECT * FROM users OFFSET 40", None, 20, "SELECT * FROM users OFFSET 60"), + # Edge case: Filter clause with mixed cases + ( + "SELECT * FROM users LiMiT 100 oFfSeT 20", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + # Test: Filter clause with trailing/leading spaces + ( + " SELECT * FROM users LIMIT 100 OFFSET 20 ", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + ], +) +def test_add_limit_offset_to_user_filter( + filter_clause, limit_clause, offset_clause, expected +): + result = add_limit_offset_to_user_filter(filter_clause, limit_clause, offset_clause) + assert result.strip() == expected.strip() + + +def test_reorder_records_basic_reordering(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["job", 
"name"] + + expected = [ + ["Engineer", "Alice"], + ["Designer", "Bob"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_partial_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["age"] + + expected = [ + [30], + [25], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_missing_fields_in_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["nonexistent", "job"] + + expected = [ + ["Engineer"], + ["Designer"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_records(): + records = [] + original_fields = ["name", "age", "job"] + new_fields = ["job", "name"] + + expected = [] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = [] + + expected = [ + [], + [], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_original_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = [] + new_fields = ["job", "name"] + + with pytest.raises(KeyError): + reorder_records(records, original_fields, new_fields) + + +def test_reorder_records_no_common_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["nonexistent_field"] + + expected = [ + [], + [], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_duplicate_fields_in_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["job", "job", "name"] + + expected = [ + ["Engineer", "Engineer", "Alice"], + ["Designer", "Designer", "Bob"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_all_fields_in_order(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["name", "age", "job"] + + expected = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_split_and_filter_fields_basic_case(): + fields = [ + "Account.Name", + "Account.Industry", + "Contact.Name", + "AccountId", + "ContactId", + "CreatedDate", + ] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["AccountId", "ContactId", "CreatedDate"] + assert select_fields == [ + "Account.Name", + "Account.Industry", + "Contact.Name", + "CreatedDate", + ] + + +def test_split_and_filter_fields_all_non_lookup_fields(): + fields = ["Name", "CreatedDate"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Name", "CreatedDate"] + assert select_fields == fields + + +def test_split_and_filter_fields_all_lookup_fields(): + fields = ["Account.Name", "Account.Industry", 
"Contact.Name"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == fields + + +def test_split_and_filter_fields_empty_fields(): + fields = [] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == [] + + +def test_split_and_filter_fields_single_non_lookup_field(): + fields = ["Id"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Id"] + assert select_fields == ["Id"] + + +def test_split_and_filter_fields_single_lookup_field(): + fields = ["Account.Name"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == ["Account.Name"] + + +def test_split_and_filter_fields_multiple_unique_lookups(): + fields = [ + "Account.Name", + "Account.Industry", + "Contact.Email", + "Contact.Phone", + "Id", + ] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Id"] + assert ( + select_fields == fields + ) # No filtering applied since all components are unique diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index fc8cea7013..e94e91f226 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1,5 +1,6 @@ import io import json +from itertools import tee from unittest import mock import pytest @@ -7,7 +8,10 @@ from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.load import LoadData +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import ( + HIGH_PRIORITY_VALUE, + LOW_PRIORITY_VALUE, BulkApiDmlOperation, BulkApiQueryOperation, BulkJobMixin, @@ -18,7 +22,10 @@ DataOperationType, RestApiDmlOperation, RestApiQueryOperation, + assign_weights, download_file, + extract_flattened_headers, + flatten_record, get_dml_operation, get_query_operation, ) @@ -534,242 +541,1894 @@ def test_get_prev_record_values(self): ) step.bulk.get_all_results_for_query_batch.assert_called_once_with("BATCH_ID") - def test_batch(self): + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_standard_strategy_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation context = mock.Mock() - step = BulkApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={"batch_size": 2}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + content_type="JSON", ) - records = iter([["Test"], ["Test2"], ["Test3"]]) - results = list(step._batch(records, n=2)) + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] - assert len(results) == 2 - assert list(results[0]) == [ - '"LastName"\r\n'.encode("utf-8"), - '"Test"\r\n'.encode("utf-8"), - '"Test2"\r\n'.encode("utf-8"), - ] - assert list(results[1]) == [ - '"LastName"\r\n'.encode("utf-8"), - '"Test3"\r\n'.encode("utf-8"), - ] + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO('[{"Id":"003000000000001"}]') - def test_batch__character_limit(self): - context = mock.Mock() + # Mock the _wait_for_job method to simulate a successful job + 
step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_standard_strategy_failure__no_records(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() step = BulkApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={"batch_size": 2}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, ) - records = [["Test"], ["Test2"], ["Test3"]] + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] - csv_rows = [step._serialize_csv_record(step.fields)] - for r in records: - csv_rows.append(step._serialize_csv_record(r)) + # Mock the downloaded CSV content indicating no records found + download_mock.return_value = io.StringIO("[]") - char_limit = sum([len(r) for r in csv_rows]) - 1 + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) - # Ask for batches of three, but we - # should get batches of 2 back - results = list(step._batch(iter(records), n=3, char_limit=char_limit)) + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) - assert len(results) == 2 - assert list(results[0]) == [ - '"LastName"\r\n'.encode("utf-8"), - '"Test"\r\n'.encode("utf-8"), - '"Test2"\r\n'.encode("utf-8"), - ] - assert list(results[1]) == [ - '"LastName"\r\n'.encode("utf-8"), - '"Test3"\r\n'.encode("utf-8"), - ] + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." 
+ ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_get_results(self, download_mock): + def test_select_records_user_selection_filter_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation context = mock.Mock() - context.bulk.endpoint = "https://test" - download_mock.side_effect = [ - io.StringIO( - """id,success,created,error -003000000000001,true,true, -003000000000002,true,true,""" - ), - io.StringIO( - """id,success,created,error -003000000000003,false,false,error""" - ), - ] - step = BulkApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName in ("Sample Name")', ) - step.job_id = "JOB" - step.batch_ids = ["BATCH1", "BATCH2"] - results = step.get_results() + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] - assert list(results) == [ - DataOperationResult("003000000000001", True, None, True), - DataOperationResult("003000000000002", True, None, True), - DataOperationResult(None, False, "error", False), - ] - download_mock.assert_has_calls( - [ - mock.call("https://test/job/JOB/batch/BATCH1/result", context.bulk), - mock.call("https://test/job/JOB/batch/BATCH2/result", context.bulk), - ] + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO('[{"Id":"003000000000001"}]') + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 ) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_get_results__failure(self, download_mock): + def test_select_records_user_selection_filter_order_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation context = mock.Mock() - context.bulk.endpoint = "https://test" - download_mock.return_value.side_effect = Exception + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="ORDER BY CreatedDate", + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + '[{"Id":"003000000000003"}, 
{"Id":"003000000000001"}, {"Id":"003000000000002"}]' + ) + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results are in the order given by user query + assert results[0].id == "003000000000003" + assert results[1].id == "003000000000001" + assert results[2].id == "003000000000002" + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_user_selection_filter_failure(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() step = BulkApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName in ("Sample Name")', ) - step.job_id = "JOB" - step.batch_ids = ["BATCH1", "BATCH2"] + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.side_effect = BulkDataException("MALFORMED QUERY") + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() with pytest.raises(BulkDataException): - list(step.get_results()) + step.select_records(records) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_end_to_end(self, download_mock): + def test_select_records_similarity_strategy_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation context = mock.Mock() - context.bulk.endpoint = "https://test" - context.bulk.create_job.return_value = "JOB" - context.bulk.post_batch.side_effect = ["BATCH1", "BATCH2"] + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record download_mock.return_value = io.StringIO( - """id,success,created,error -003000000000001,true,true, -003000000000002,true,true, -003000000000003,false,false,error""" + """[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}, {"Id":"003000000000002", "Name":"Aditya", "Email":"aditya@example.com"}, {"Id":"003000000000003", "Name":"Tom", "Email":"tom@example.com"}]""" + ) + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records 
+ records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_failure__no_records( + self, download_mock + ): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() step = BulkApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, context=context, - fields=["LastName"], + fields=["Id", "Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content indicating no records found + download_mock.return_value = io.StringIO("[]") + + # Mock the _wait_for_job method to simulate a successful job step._wait_for_job = mock.Mock() step._wait_for_job.return_value = DataOperationJobResult( DataOperationStatus.SUCCESS, [], 0, 0 ) + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Execute the select_records operation step.start() - step.load_records(iter([["Test"], ["Test2"], ["Test3"]])) + step.select_records(records) step.end() - assert step.job_result.status is DataOperationStatus.SUCCESS - results = step.get_results() - - assert list(results) == [ - DataOperationResult("003000000000001", True, None, True), - DataOperationResult("003000000000002", True, None, True), - DataOperationResult(None, False, "error", False), - ] - + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." 
+ ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 -class TestRestApiQueryOperation: - def test_query(self): + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_parent_level_records__polymorphic( + self, download_mock + ): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation context = mock.Mock() - context.sf.query.return_value = { - "totalSize": 2, - "done": True, - "records": [ - { - "Id": "003000000000001", - "LastName": "Narvaez", - "Email": "wayne@example.com", - }, - {"Id": "003000000000002", "LastName": "De Vries", "Email": None}, + step = BulkApiDmlOperation( + sobject="Event", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=[ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + "WhoId", ], - } + selection_strategy=SelectStrategy.SIMILARITY, + ) - query_op = RestApiQueryOperation( + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + download_mock.return_value = io.StringIO( + """[ + {"Id": "003000000000001", "Subject": "Sample Event 1", "Who":{ "attributes": {"type": "Contact"}, "Id": "abcd1234", "Name": "Sample Contact", "Email": "contact@example.com"}}, + { "Id": "003000000000002", "Subject": "Sample Event 2", "Who":{ "attributes": {"type": "Lead"}, "Id": "qwer1234", "Name": "Sample Lead", "Company": "Salesforce"}} + ]""" + ) + + records = iter( + [ + [ + "Sample Event 1", + "Sample Contact", + "contact@example.com", + "", + "", + "lkjh1234", + ], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce", "poiu1234"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_parent_level_records__non_polymorphic( + self, download_mock + ): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( sobject="Contact", - fields=["Id", "LastName", "Email"], - api_options={}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, context=context, - query="SELECT Id, LastName, Email FROM Contact", + fields=["Name", "Account.Name", "Account.AccountNumber", "AccountId"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + download_mock.return_value = io.StringIO( + """[ + {"Id": "003000000000001", "Name": "Sample Contact 1", "Account":{ "attributes": {"type": "Account"}, "Id": "abcd1234", "Name": "Sample Account", "AccountNumber": 123456}}, + { "Id": "003000000000002", "Subject": "Sample Contact 2", "Account": null} + ]""" + ) + + 
records = iter( + [ + ["Sample Contact 3", "Sample Account", "123456", "poiu1234"], + ["Sample Contact 4", "", "", ""], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_priority_fields(self, download_mock): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step_1 = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={"Name": "Name", "Email": "Email"}, + ) + + step_2 = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={ + "Account.Name": "Account.Name", + "Account.AccountNumber": "Account.AccountNumber", + }, + ) + + # Mock Bulk API responses + step_1.bulk.endpoint = "https://test" + step_1.bulk.create_query_job.return_value = "JOB" + step_1.bulk.query.return_value = "BATCH" + step_1.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + step_2.bulk.endpoint = "https://test" + step_2.bulk.create_query_job.return_value = "JOB" + step_2.bulk.query.return_value = "BATCH" + step_2.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + sample_response = [ + { + "Id": "003000000000001", + "Name": "Bob The Builder", + "Email": "bob@yahoo.org", + "Account": { + "attributes": {"type": "Account"}, + "Id": "abcd1234", + "Name": "Jawad TP", + "AccountNumber": 567890, + }, + }, + { + "Id": "003000000000002", + "Name": "Tom Cruise", + "Email": "tom@exmaple.com", + "Account": { + "attributes": {"type": "Account"}, + "Id": "qwer1234", + "Name": "Aditya B", + "AccountNumber": 123456, + }, + }, + ] + + download_mock.side_effect = [ + io.StringIO(f"""{json.dumps(sample_response)}"""), + io.StringIO(f"""{json.dumps(sample_response)}"""), + ] + + records = iter( + [ + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456", "poiu1234"], + ] + ) + records_1, records_2 = tee(records) + step_1.start() + step_1.select_records(records_1) + step_1.end() + + step_2.start() + step_2.select_records(records_2) + step_2.end() + + # Get the results and assert their properties + results_1 = list(step_1.get_results()) + results_2 = list(step_2.get_results()) + assert ( + len(results_1) == 1 + ) # Expect 1 results (matching the input records count) + assert ( + len(results_2) == 1 + ) # Expect 1 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + # Prioritizes Name and Email + assert results_1[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + # Prioritizes 
Account.Name and Account.AccountNumber + assert results_2[0] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_process_insert_records_success(self, download_mock): + # Mock context and insert records + context = mock.Mock() + insert_records = iter([["John", "Doe"], ["Jane", "Smith"]]) + selected_records = [None, None] + + # Mock insert fields splitting + insert_fields = ["FirstName", "LastName"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["FirstName", "LastName"], + ) + + # Mock Bulk API + step.bulk.endpoint = "https://test" + step.bulk.create_insert_job.return_value = "JOB" + step.bulk.get_insert_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with successful results + download_mock.return_value = io.StringIO( + "Id,Success,Created\n0011k00003E8xAaAAI,true,true\n0011k00003E8xAbAAJ,true,true\n" + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + step._process_insert_records(insert_records, selected_records) + + # Assertions for split fields and sub-operation + split_mock.assert_called_once_with(fields=["FirstName", "LastName"]) + insert_step.start.assert_called_once() + insert_step.load_records.assert_called_once_with(insert_records) + insert_step.end.assert_called_once() + + # Validate the download file interactions + download_mock.assert_called_once_with( + "https://test/job/JOB/batch/BATCH1/result", insert_step.bulk + ) + + # Validate that selected_records is updated with insert results + assert selected_records == [ + {"id": "0011k00003E8xAaAAI", "success": True, "created": True}, + {"id": "0011k00003E8xAbAAJ", "success": True, "created": True}, + ] + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_process_insert_records_failure(self, download_mock): + # Mock context and insert records + context = mock.Mock() + insert_records = iter([["John", "Doe"], ["Jane", "Smith"]]) + selected_records = [None, None] + + # Mock insert fields splitting + insert_fields = ["FirstName", "LastName"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ): + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["FirstName", "LastName"], + ) + + # Mock failure during results download + download_mock.side_effect = Exception("Download failed") + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + 
"cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + with pytest.raises(BulkDataException) as excinfo: + step._process_insert_records(insert_records, selected_records) + + # Validate that the exception is raised with the correct message + assert "Failed to download results for batch BATCH1" in str( + excinfo.value + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy__insert_records(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + # Add step with threshold + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + select_results = io.StringIO( + """[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}]""" + ) + insert_results = io.StringIO( + "Id,Success,Created\n003000000000002,true,true\n003000000000003,true,true\n" + ) + download_mock.side_effect = [select_results, insert_results] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy__insert_records__no_select_records( + self, download_mock + ): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + # Add step with threshold + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, + ) + 
+
+        # Mock Bulk API responses
+        step.bulk.endpoint = "https://test"
+        step.bulk.create_query_job.return_value = "JOB"
+        step.bulk.query.return_value = "BATCH"
+        step.bulk.get_query_batch_result_ids.return_value = ["RESULT"]
+
+        # Mock the downloaded query results (no matching records) and the insert job results
+        select_results = io.StringIO("""[]""")
+        insert_results = io.StringIO(
+            "Id,Success,Created\n003000000000001,true,true\n003000000000002,true,true\n003000000000003,true,true\n"
+        )
+        download_mock.side_effect = [select_results, insert_results]
+
+        # Mock the _wait_for_job method to simulate a successful job
+        step._wait_for_job = mock.Mock()
+        step._wait_for_job.return_value = DataOperationJobResult(
+            DataOperationStatus.SUCCESS, [], 0, 0
+        )
+
+        # Prepare input records
+        records = iter(
+            [
+                ["Jawad", "mjawadtp@example.com"],
+                ["Aditya", "aditya@example.com"],
+                ["Tom", "cruise@example.com"],
+            ]
+        )
+
+        # Mock sub-operation for BulkApiDmlOperation
+        insert_step = mock.Mock(spec=BulkApiDmlOperation)
+        insert_step.start = mock.Mock()
+        insert_step.load_records = mock.Mock()
+        insert_step.end = mock.Mock()
+        insert_step.batch_ids = ["BATCH1"]
+        insert_step.bulk = mock.Mock()
+        insert_step.bulk.endpoint = "https://test"
+        insert_step.job_id = "JOB"
+
+        with mock.patch(
+            "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation",
+            return_value=insert_step,
+        ):
+            # Execute the select_records operation
+            step.start()
+            step.select_records(records)
+            step.end()
+
+        # Get the results and assert their properties
+        results = list(step.get_results())
+
+        assert len(results) == 3  # Expect 3 results (matching the input records count)
+        # Assert that all results have the expected ID, success, and created values
+        assert (
+            results.count(
+                DataOperationResult(
+                    id="003000000000001", success=True, error="", created=True
+                )
+            )
+            == 1
+        )
+        assert (
+            results.count(
+                DataOperationResult(
+                    id="003000000000002", success=True, error="", created=True
+                )
+            )
+            == 1
+        )
+        assert (
+            results.count(
+                DataOperationResult(
+                    id="003000000000003", success=True, error="", created=True
+                )
+            )
+            == 1
+        )
+
+    def test_batch(self):
+        context = mock.Mock()
+
+        step = BulkApiDmlOperation(
+            sobject="Contact",
+            operation=DataOperationType.INSERT,
+            api_options={"batch_size": 2},
+            context=context,
+            fields=["LastName"],
+        )
+
+        records = iter([["Test"], ["Test2"], ["Test3"]])
+        results = list(step._batch(records, n=2))
+
+        assert len(results) == 2
+        assert list(results[0]) == [
+            '"LastName"\r\n'.encode("utf-8"),
+            '"Test"\r\n'.encode("utf-8"),
+            '"Test2"\r\n'.encode("utf-8"),
+        ]
+        assert list(results[1]) == [
+            '"LastName"\r\n'.encode("utf-8"),
+            '"Test3"\r\n'.encode("utf-8"),
+        ]
+
+    def test_batch__character_limit(self):
+        context = mock.Mock()
+
+        step = BulkApiDmlOperation(
+            sobject="Contact",
+            operation=DataOperationType.INSERT,
+            api_options={"batch_size": 2},
+            context=context,
+            fields=["LastName"],
+        )
+
+        records = [["Test"], ["Test2"], ["Test3"]]
+
+        csv_rows = [step._serialize_csv_record(step.fields)]
+        for r in records:
+            csv_rows.append(step._serialize_csv_record(r))
+
+        char_limit = sum([len(r) for r in csv_rows]) - 1
+
+        # Ask for batches of three, but we
+        # should get batches of 2 back
+        results = list(step._batch(iter(records), n=3, char_limit=char_limit))
+
+        assert len(results) == 2
+        assert list(results[0]) == [
+            '"LastName"\r\n'.encode("utf-8"),
+            '"Test"\r\n'.encode("utf-8"),
+            '"Test2"\r\n'.encode("utf-8"),
+        ]
+        assert list(results[1]) == [
+            '"LastName"\r\n'.encode("utf-8"),
+
'"Test3"\r\n'.encode("utf-8"), + ] + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_get_results(self, download_mock): + context = mock.Mock() + context.bulk.endpoint = "https://test" + download_mock.side_effect = [ + io.StringIO( + """id,success,created,error +003000000000001,true,true, +003000000000002,true,true,""" + ), + io.StringIO( + """id,success,created,error +003000000000003,false,false,error""" + ), + ] + + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={}, + context=context, + fields=["LastName"], + ) + step.job_id = "JOB" + step.batch_ids = ["BATCH1", "BATCH2"] + + results = step.get_results() + + assert list(results) == [ + DataOperationResult("003000000000001", True, None, True), + DataOperationResult("003000000000002", True, None, True), + DataOperationResult(None, False, "error", False), + ] + download_mock.assert_has_calls( + [ + mock.call("https://test/job/JOB/batch/BATCH1/result", context.bulk), + mock.call("https://test/job/JOB/batch/BATCH2/result", context.bulk), + ] + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_get_results__failure(self, download_mock): + context = mock.Mock() + context.bulk.endpoint = "https://test" + download_mock.return_value.side_effect = Exception + + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={}, + context=context, + fields=["LastName"], + ) + step.job_id = "JOB" + step.batch_ids = ["BATCH1", "BATCH2"] + + with pytest.raises(BulkDataException): + list(step.get_results()) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_end_to_end(self, download_mock): + context = mock.Mock() + context.bulk.endpoint = "https://test" + context.bulk.create_job.return_value = "JOB" + context.bulk.post_batch.side_effect = ["BATCH1", "BATCH2"] + download_mock.return_value = io.StringIO( + """id,success,created,error +003000000000001,true,true, +003000000000002,true,true, +003000000000003,false,false,error""" + ) + + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={}, + context=context, + fields=["LastName"], + ) + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + step.start() + step.load_records(iter([["Test"], ["Test2"], ["Test3"]])) + step.end() + + assert step.job_result.status is DataOperationStatus.SUCCESS + results = step.get_results() + + assert list(results) == [ + DataOperationResult("003000000000001", True, None, True), + DataOperationResult("003000000000002", True, None, True), + DataOperationResult(None, False, "error", False), + ] + + +class TestRestApiQueryOperation: + def test_query(self): + context = mock.Mock() + context.sf.query.return_value = { + "totalSize": 2, + "done": True, + "records": [ + { + "Id": "003000000000001", + "LastName": "Narvaez", + "Email": "wayne@example.com", + }, + {"Id": "003000000000002", "LastName": "De Vries", "Email": None}, + ], + } + + query_op = RestApiQueryOperation( + sobject="Contact", + fields=["Id", "LastName", "Email"], + api_options={}, + context=context, + query="SELECT Id, LastName, Email FROM Contact", + ) + + query_op.query() + + assert query_op.job_result == DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 2, 0 + ) + assert list(query_op.get_results()) == [ + ["003000000000001", "Narvaez", "wayne@example.com"], + ["003000000000002", "De Vries", ""], + ] + + def 
test_query_batches(self): + context = mock.Mock() + context.sf.query.return_value = { + "totalSize": 2, + "done": False, + "records": [ + { + "Id": "003000000000001", + "LastName": "Narvaez", + "Email": "wayne@example.com", + } + ], + "nextRecordsUrl": "test", + } + + context.sf.query_more.return_value = { + "totalSize": 2, + "done": True, + "records": [ + {"Id": "003000000000002", "LastName": "De Vries", "Email": None} + ], + } + + query_op = RestApiQueryOperation( + sobject="Contact", + fields=["Id", "LastName", "Email"], + api_options={}, + context=context, + query="SELECT Id, LastName, Email FROM Contact", + ) + + query_op.query() + + assert query_op.job_result == DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 2, 0 + ) + assert list(query_op.get_results()) == [ + ["003000000000001", "Narvaez", "wayne@example.com"], + ["003000000000002", "De Vries", ""], + ] + + +class TestRestApiDmlOperation: + @responses.activate + def test_insert_dml_operation(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + + recs = [["Fred", "Narvaez"], [None, "De Vries"], ["Hiroko", "Aito"]] + + dml_op = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={"batch_size": 2}, + context=task, + fields=["FirstName", "LastName"], + ) + + dml_op.start() + dml_op.load_records(iter(recs)) + dml_op.end() + + assert dml_op.job_result == DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 3, 0 + ) + assert list(dml_op.get_results()) == [ + DataOperationResult("003000000000001", True, "", True), + DataOperationResult("003000000000002", True, "", True), + DataOperationResult("003000000000003", True, "", True), + ] + + @responses.activate + def test_get_prev_record_values(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = { + "records": [ + {"LastName": "Test1", "Id": "Id1"}, + {"LastName": "Test2", "Id": "Id2"}, + ] + } + expected_record_values = [["Test1", "Id1"], ["Test2", "Id2"]] + expected_relevant_fields = ("Id", "LastName") + step.sf.query = mock.Mock() + 
step.sf.query.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + prev_record_values, relevant_fields = step.get_prev_record_values(records) + + assert sorted(map(sorted, prev_record_values)) == sorted( + map(sorted, expected_record_values) + ) + assert set(relevant_fields) == set(expected_relevant_fields) + + @responses.activate + def test_select_records_standard_strategy_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": True, + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @responses.activate + def test_select_records_standard_strategy_success_pagination(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + ) + + # Set up pagination: First call returns done=False, second call returns done=True + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [{"Id": "003000000000001"}, {"Id": "003000000000002"}], + "done": False, # Pagination in progress + "nextRecordsUrl": "/services/data/vXX.X/query/next-records", + }, + ] + ) + + step.sf.query_more = mock.Mock( + side_effect=[ + {"records": [{"Id": "003000000000003"}], "done": True} # Final page + ] + ) + + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + 
step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) + + @responses.activate + def test_select_records_standard_strategy_failure__no_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + ) + + results = {"records": [], "done": True} + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." 
+ ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + + @responses.activate + def test_select_records_user_selection_filter_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName IN ("Sample Name")', + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": True, + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @responses.activate + def test_select_records_user_selection_filter_order_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="ORDER BY CreatedDate", + ) + + results = { + "records": [ + {"Id": "003000000000003"}, + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + ], + "done": True, + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results are in the order of user_query + assert results[0].id == "003000000000003" + assert results[1].id == "003000000000001" + assert results[2].id == 
"003000000000002" + + @responses.activate + def test_select_records_user_selection_filter_failure(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="MALFORMED FILTER", # Applying malformed filter + ) + + step.sf.restful = mock.Mock() + step.sf.restful.side_effect = Exception("MALFORMED QUERY") + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + with pytest.raises(Exception): + step.select_records(records) + + @responses.activate + def test_select_records_similarity_strategy_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + results_first_call = { + "records": [ + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + { + "Id": "003000000000002", + "Name": "Aditya", + "Email": "aditya@example.com", + }, + { + "Id": "003000000000003", + "Name": "Tom Cruise", + "Email": "tomcruise@example.com", + }, + ], + "done": True, + } + + # First call returns `results_first_call`, second call returns an empty list + step.sf.restful = mock.Mock( + side_effect=[results_first_call, {"records": [], "done": True}] + ) + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, 
error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) + + @responses.activate + def test_select_records_similarity_strategy_failure__no_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + results = {"records": [], "done": True} + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter( + [ + ["Id: 1", "Jawad", "mjawadtp@example.com"], + ["Id: 2", "Aditya", "aditya@example.com"], + ["Id: 2", "Tom", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." 
+ ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + + @responses.activate + def test_select_records_similarity_strategy_parent_level_records__polymorphic(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, ) - - query_op.query() - - assert query_op.job_result == DataOperationJobResult( - DataOperationStatus.SUCCESS, [], 2, 0 + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, ) - assert list(query_op.get_results()) == [ - ["003000000000001", "Narvaez", "wayne@example.com"], - ["003000000000002", "De Vries", ""], - ] - - def test_query_batches(self): - context = mock.Mock() - context.sf.query.return_value = { - "totalSize": 2, - "done": False, - "records": [ - { - "Id": "003000000000001", - "LastName": "Narvaez", - "Email": "wayne@example.com", - } + step = RestApiDmlOperation( + sobject="Event", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task, + fields=[ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + "WhoId", ], - "nextRecordsUrl": "test", - } + selection_strategy=SelectStrategy.SIMILARITY, + ) - context.sf.query_more.return_value = { - "totalSize": 2, - "done": True, - "records": [ - {"Id": "003000000000002", "LastName": "De Vries", "Email": None} - ], - } + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [ + { + "Id": "003000000000001", + "Subject": "Sample Event 1", + "Who": { + "attributes": {"type": "Contact"}, + "Id": "abcd1234", + "Name": "Sample Contact", + "Email": "contact@example.com", + }, + }, + { + "Id": "003000000000002", + "Subject": "Sample Event 2", + "Who": { + "attributes": {"type": "Lead"}, + "Id": "qwer1234", + "Name": "Sample Lead", + "Company": "Salesforce", + }, + }, + ], + "done": True, + }, + ] + ) - query_op = RestApiQueryOperation( - sobject="Contact", - fields=["Id", "LastName", "Email"], - api_options={}, - context=context, - query="SELECT Id, LastName, Email FROM Contact", + records = iter( + [ + [ + "Sample Event 1", + "Sample Contact", + "contact@example.com", + "", + "", + "poiu1234", + ], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce", "lkjh1234"], + ] ) + step.start() + step.select_records(records) + step.end() - query_op.query() + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) - assert query_op.job_result == DataOperationJobResult( - DataOperationStatus.SUCCESS, [], 2, 0 + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False ) - assert list(query_op.get_results()) == [ - ["003000000000001", "Narvaez", "wayne@example.com"], - ["003000000000002", "De Vries", ""], - ] - 
-class TestRestApiDmlOperation: @responses.activate - def test_insert_dml_operation(self): + def test_select_records_similarity_strategy_parent_level_records__non_polymorphic( + self, + ): mock_describe_calls() task = _make_task( LoadData, @@ -798,34 +2457,66 @@ def test_insert_dml_operation(self): json=[{"id": "003000000000003", "success": True}], status=200, ) - - recs = [["Fred", "Narvaez"], [None, "De Vries"], ["Hiroko", "Aito"]] - - dml_op = RestApiDmlOperation( + step = RestApiDmlOperation( sobject="Contact", - operation=DataOperationType.INSERT, - api_options={"batch_size": 2}, + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, context=task, - fields=["FirstName", "LastName"], + fields=["Name", "Account.Name", "Account.AccountNumber", "AccountId"], + selection_strategy=SelectStrategy.SIMILARITY, ) - dml_op.start() - dml_op.load_records(iter(recs)) - dml_op.end() + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [ + { + "Id": "003000000000001", + "Name": "Sample Contact 1", + "Account": { + "attributes": {"type": "Account"}, + "Id": "abcd1234", + "Name": "Sample Account", + "AccountNumber": 123456, + }, + }, + { + "Id": "003000000000002", + "Name": "Sample Contact 2", + "Account": None, + }, + ], + "done": True, + }, + ] + ) - assert dml_op.job_result == DataOperationJobResult( - DataOperationStatus.SUCCESS, [], 3, 0 + records = iter( + [ + ["Sample Contact 3", "Sample Account", "123456", "poiu1234"], + ["Sample Contact 4", "", "", ""], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False ) - assert list(dml_op.get_results()) == [ - DataOperationResult("003000000000001", True, "", True), - DataOperationResult("003000000000002", True, "", True), - DataOperationResult("003000000000003", True, "", True), - ] @responses.activate - def test_get_prev_record_values(self): + def test_select_records_similarity_strategy_priority_fields(self): mock_describe_calls() - task = _make_task( + task_1 = _make_task( LoadData, { "options": { @@ -834,8 +2525,20 @@ def test_get_prev_record_values(self): } }, ) - task.project_config.project__package__api_version = CURRENT_SF_API_VERSION - task._init_task() + task_1.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task_1._init_task() + + task_2 = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task_2.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task_2._init_task() responses.add( responses.POST, @@ -852,32 +2555,341 @@ def test_get_prev_record_values(self): json=[{"id": "003000000000003", "success": True}], status=200, ) + step_1 = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task_1, + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={"Name": "Name", "Email": "Email"}, + ) + + step_2 = RestApiDmlOperation( + sobject="Contact", + 
operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task_2, + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={ + "Account.Name": "Account.Name", + "Account.AccountNumber": "Account.AccountNumber", + }, + ) + + sample_response = [ + { + "records": [ + { + "Id": "003000000000001", + "Name": "Bob The Builder", + "Email": "bob@yahoo.org", + "Account": { + "attributes": {"type": "Account"}, + "Id": "abcd1234", + "Name": "Jawad TP", + "AccountNumber": 567890, + }, + }, + { + "Id": "003000000000002", + "Name": "Tom Cruise", + "Email": "tom@exmaple.com", + "Account": { + "attributes": {"type": "Account"}, + "Id": "qwer1234", + "Name": "Aditya B", + "AccountNumber": 123456, + }, + }, + ], + "done": True, + }, + ] + + step_1.sf.restful = mock.Mock(side_effect=sample_response) + step_2.sf.restful = mock.Mock(side_effect=sample_response) + + records = iter( + [ + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456", "poiu1234"], + ] + ) + records_1, records_2 = tee(records) + step_1.start() + step_1.select_records(records_1) + step_1.end() + + step_2.start() + step_2.select_records(records_2) + step_2.end() + + # Get the results and assert their properties + results_1 = list(step_1.get_results()) + results_2 = list(step_2.get_results()) + assert ( + len(results_1) == 1 + ) # Expect 1 results (matching the input records count) + assert ( + len(results_2) == 1 + ) # Expect 1 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + # Prioritizes Name and Email + assert results_1[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + # Prioritizes Account.Name and Account.AccountNumber + assert results_2[0] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @responses.activate + def test_process_insert_records_success(self): + # Mock describe calls + mock_describe_calls() + + # Create a task and mock project config + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Prepare inputs + insert_records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tomcruise@example.com"], + ] + ) + selected_records = [None, None, None] + + # Mock fields splitting + insert_fields = ["Name", "Email"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + # Mock the instance of RestApiDmlOperation + mock_rest_api_dml_operation = mock.create_autospec( + RestApiDmlOperation, instance=True + ) + mock_rest_api_dml_operation.results = [ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + {"id": "003000000000003", "success": True}, + ] + + with mock.patch( + "cumulusci.tasks.bulkdata.step.RestApiDmlOperation", + return_value=mock_rest_api_dml_operation, + ): + # Call the function + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + ) + step._process_insert_records(insert_records, selected_records) + + # Assert the mocked splitting is called + 
split_mock.assert_called_once_with(fields=["Name", "Email"]) + + # Validate that `selected_records` is updated correctly + assert selected_records == [ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + {"id": "003000000000003", "success": True}, + ] + + # Validate the operation sequence + mock_rest_api_dml_operation.start.assert_called_once() + mock_rest_api_dml_operation.load_records.assert_called_once_with( + insert_records + ) + mock_rest_api_dml_operation.end.assert_called_once() + + @responses.activate + def test_process_insert_records_failure(self): + # Mock describe calls + mock_describe_calls() + + # Create a task and mock project config + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Prepare inputs + insert_records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ] + ) + selected_records = [None, None] + + # Mock fields splitting + insert_fields = ["Name", "Email"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + # Mock the instance of RestApiDmlOperation + mock_rest_api_dml_operation = mock.create_autospec( + RestApiDmlOperation, instance=True + ) + mock_rest_api_dml_operation.results = ( + None # Simulate no results due to an exception + ) + + # Simulate an exception during processing results + mock_rest_api_dml_operation.load_records.side_effect = BulkDataException( + "Simulated failure" + ) + + with mock.patch( + "cumulusci.tasks.bulkdata.step.RestApiDmlOperation", + return_value=mock_rest_api_dml_operation, + ): + # Call the function and verify that it raises the expected exception + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + ) + with pytest.raises(BulkDataException): + step._process_insert_records(insert_records, selected_records) + + # Assert the mocked splitting is called + split_mock.assert_called_once_with(fields=["Name", "Email"]) + + # Validate that `selected_records` remains unchanged + assert selected_records == [None, None] + + # Validate the operation sequence + mock_rest_api_dml_operation.start.assert_called_once() + mock_rest_api_dml_operation.load_records.assert_called_once_with( + insert_records + ) + mock_rest_api_dml_operation.end.assert_not_called() + + @responses.activate + def test_select_records_similarity_strategy__insert_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + # Create step with threshold step = RestApiDmlOperation( sobject="Contact", operation=DataOperationType.UPSERT, - api_options={"batch_size": 10, "update_key": "LastName"}, + api_options={"batch_size": 10}, context=task, - fields=["LastName"], + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, ) - results = { + results_select_call = { "records": [ - {"LastName": "Test1", "Id": "Id1"}, - {"LastName": "Test2", "Id": "Id2"}, - ] + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + ], + "done": True, } - 
expected_record_values = [["Test1", "Id1"], ["Test2", "Id2"]] - expected_relevant_fields = ("Id", "LastName") - step.sf.query = mock.Mock() - step.sf.query.return_value = results - records = iter([["Test1"], ["Test2"], ["Test3"]]) - prev_record_values, relevant_fields = step.get_prev_record_values(records) - assert sorted(map(sorted, prev_record_values)) == sorted( - map(sorted, expected_record_values) + results_insert_call = [ + {"id": "003000000000002", "success": True, "created": True}, + {"id": "003000000000003", "success": True, "created": True}, + ] + + step.sf.restful = mock.Mock( + side_effect=[results_select_call, results_insert_call] + ) + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 ) - assert set(relevant_fields) == set(expected_relevant_fields) @responses.activate def test_insert_dml_operation__boolean_conversion(self): @@ -1355,6 +3367,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, api=DataApi.BULK, volume=1, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) assert op == bulk_dml.return_value @@ -1364,6 +3378,11 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): fields=["Name"], api_options={}, context=context, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, + selection_priority_fields=None, + content_type=None, + threshold=None, ) op = get_dml_operation( @@ -1374,6 +3393,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, api=DataApi.REST, volume=1, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) assert op == rest_dml.return_value @@ -1383,6 +3404,11 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): fields=["Name"], api_options={}, context=context, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, + selection_priority_fields=None, + content_type=None, + threshold=None, ) @mock.patch("cumulusci.tasks.bulkdata.step.BulkApiDmlOperation") @@ -1545,3 +3571,122 @@ def test_cleanup_date_strings__upsert_update(self, operation): "Name": "Bill", "attributes": {"type": "Test__c"}, }, json_out + + +@pytest.mark.parametrize( + "query_fields, expected", + [ + # Test with simple field names + (["Id", "Name", "Email"], ["Id", "Name", "Email"]), + # Test with TYPEOF fields (polymorphic fields) + ( + [ + "Subject", + { + "Who": [ + {"Contact": ["Name", "Email"]}, + {"Lead": ["Name", "Company"]}, + ] + }, + ], + [ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ], + ), + # Test with mixed simple and TYPEOF fields + ( + ["Subject", {"Who": [{"Contact": ["Email"]}]}, "Account.Name"], + ["Subject", "Who.Contact.Email", "Account.Name"], + ), + # Test with an empty list + ([], []), + ], +) +def 
test_extract_flattened_headers(query_fields, expected): + result = extract_flattened_headers(query_fields) + assert result == expected + + +@pytest.mark.parametrize( + "record, headers, expected", + [ + # Test with simple field matching + ( + {"Id": "001", "Name": "John Doe", "Email": "john@example.com"}, + ["Id", "Name", "Email"], + ["001", "John Doe", "john@example.com"], + ), + # Test with lookup fields and missing values + ( + { + "Who": { + "attributes": {"type": "Contact"}, + "Name": "Jane Doe", + "Email": "johndoe@org.com", + "Number": 10, + } + }, + ["Who.Contact.Name", "Who.Contact.Email", "Who.Contact.Number"], + ["Jane Doe", "johndoe@org.com", "10"], + ), + # Test with non-matching ref_obj type + ( + {"Who": {"attributes": {"type": "Contact"}, "Email": "jane@contact.com"}}, + ["Who.Lead.Email"], + [""], + ), + # Test with mixed fields and nested lookups + ( + { + "Who": {"attributes": {"type": "Lead"}, "Name": "John Doe"}, + "Email": "john@example.com", + }, + ["Who.Lead.Name", "Who.Lead.Company", "Email"], + ["John Doe", "", "john@example.com"], + ), + # Test with mixed fields and nested lookups + ( + { + "Who": {"attributes": {"type": "Lead"}, "Name": "John Doe"}, + "Email": "john@example.com", + }, + ["What.Account.Name"], + [""], + ), + # Test with empty record + ({}, ["Id", "Name"], ["", ""]), + ], +) +def test_flatten_record(record, headers, expected): + result = flatten_record(record, headers) + assert result == expected + + +@pytest.mark.parametrize( + "priority_fields, fields, expected", + [ + # Test with priority fields matching + ( + {"Id": "Id", "Name": "Name"}, + ["Id", "Name", "Email"], + [HIGH_PRIORITY_VALUE, HIGH_PRIORITY_VALUE, LOW_PRIORITY_VALUE], + ), + # Test with no priority fields provided + (None, ["Id", "Name", "Email"], [1, 1, 1]), + # Test with empty priority fields dictionary + ({}, ["Id", "Name", "Email"], [1, 1, 1]), + # Test with some fields not in priority_fields + ( + {"Id": "Id"}, + ["Id", "Name", "Email"], + [HIGH_PRIORITY_VALUE, LOW_PRIORITY_VALUE, LOW_PRIORITY_VALUE], + ), + ], +) +def test_assign_weights(priority_fields, fields, expected): + result = assign_weights(priority_fields, fields) + assert result == expected diff --git a/cumulusci/tasks/bulkdata/tests/utils.py b/cumulusci/tasks/bulkdata/tests/utils.py index 173f4c6122..c0db0f9515 100644 --- a/cumulusci/tasks/bulkdata/tests/utils.py +++ b/cumulusci/tasks/bulkdata/tests/utils.py @@ -98,6 +98,9 @@ def get_prev_record_values(self, records): def load_records(self, records): self.records.extend(records) + def select_records(self, records): + pass + def get_results(self): return iter(self.results) diff --git a/cumulusci/tasks/bulkdata/utils.py b/cumulusci/tasks/bulkdata/utils.py index 082277fb16..cee6a4ab66 100644 --- a/cumulusci/tasks/bulkdata/utils.py +++ b/cumulusci/tasks/bulkdata/utils.py @@ -5,15 +5,38 @@ from contextlib import contextmanager, nullcontext from pathlib import Path +from requests.structures import CaseInsensitiveDict as RequestsCaseInsensitiveDict from simple_salesforce import Salesforce from sqlalchemy import Boolean, Column, MetaData, Table, Unicode, inspect from sqlalchemy.engine.base import Connection from sqlalchemy.orm import Session, mapper +from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.utils.iterators import iterate_in_chunks +class DataApi(StrEnum): + """Enum defining requested Salesforce data API for an operation.""" + + BULK = "bulk" + REST = "rest" + SMART = "smart" + + +class 
CaseInsensitiveDict(RequestsCaseInsensitiveDict):
+    def __init__(self, *args, **kwargs):
+        self._canonical_keys = {}
+        super().__init__(*args, **kwargs)
+
+    def canonical_key(self, name):
+        return self._canonical_keys[name.lower()]
+
+    def __setitem__(self, key, value):
+        super().__setitem__(key, value)
+        self._canonical_keys[key.lower()] = key
+
+
class SqlAlchemyMixin:
    logger: logging.Logger
    metadata: MetaData
diff --git a/docs/data.md b/docs/data.md
index 063e3f33f5..9badb404e8 100644
--- a/docs/data.md
+++ b/docs/data.md
@@ -250,6 +250,131 @@ Insert Accounts:
Whenever `update_key` is supplied, the action must be `upsert` and vice versa.

+---
+
+### Selects
+
+The `select` functionality is designed to streamline the mapping process by enabling the selection of specific records directly from Salesforce for lookups. This feature is particularly useful when dealing with non-insertable Salesforce objects and ensures that pre-existing records are used rather than inserting new ones. The selection process is highly customizable with various strategies, filters, and additional capabilities that provide flexibility and precision in data mapping.
+
+```yaml
+Account:
+    sf_object: Account
+    fields:
+        - Name
+        - Description
+
+Contact:
+    sf_object: Contact
+    fields:
+        - LastName
+        - Email
+    lookups:
+        AccountId:
+            table: Account
+
+Lead:
+    sf_object: Lead
+    fields:
+        - LastName
+        - Company
+
+Event:
+    sf_object: Event
+    action: select
+    select_options:
+        strategy: similarity
+        filter: WHERE Subject LIKE 'Meeting%'
+        priority_fields:
+            - Subject
+            - WhoId
+        threshold: 0.3
+    fields:
+        - Subject
+        - DurationInMinutes
+        - ActivityDateTime
+    lookups:
+        WhoId:
+            table:
+                - Contact
+                - Lead
+        WhatId:
+            table: Account
+```
+
+---
+
+#### Selection Strategies
+
+The `strategy` parameter determines how records are selected from the target org. It is **optional**; if no strategy is specified, the `standard` strategy is applied by default.
+
+- **`standard` Strategy:**
+  The `standard` selection strategy retrieves records from the target org in the same order as they appear, applying any specified filters and sorting criteria. Records are selected without any prioritization based on similarity or randomness, offering a straightforward way to pull the desired data.
+
+- **`similarity` Strategy:**
+  The `similarity` strategy is used when you need to find records in the target org that closely resemble the records defined in your SQL file. This strategy performs a similarity match between the records in the SQL file and those in the target org. In addition to comparing the fields of the record itself, it includes the fields of parent records (up to one level) for a more granular and accurate match.
+
+- **`random` Strategy:**
+  The `random` selection strategy assigns each record a randomly picked record from the target org. This method is useful when the selection order does not matter and you simply need to fetch records in a randomized manner.
+
+---
+
+#### Selection Filters
+
+The selection `filter` provides a flexible way to refine the selected records using any functionality supported by SOQL. This includes filtering, sorting, and limiting records based on specific conditions, such as using the `WHERE` clause to filter records by field values, the `ORDER BY` clause to sort records in ascending or descending order, and the `LIMIT` clause to restrict the number of records returned. Essentially, any feature available in SOQL for record selection is supported here, allowing you to tailor the selection process to your precise needs and ensuring that only the relevant records are included in the mapping process.
+
+This parameter is **optional**; if not specified, no filter is applied.
+
+---
+
+#### Priority Fields
+
+The `priority_fields` feature enables you to specify a subset of fields in your mapping step that will have more weight during the similarity matching process. When similarity matching is performed, these priority fields are given greater importance than other fields, allowing for a more refined match.
+
+This parameter is **optional**; if not specified, all fields are considered with the same priority.
+
+This feature is particularly useful when certain fields are more critical in defining the identity or relevance of a record, ensuring that these fields have a stronger influence in the selection process.
+
+---
+
+#### Threshold
+
+This feature allows you to either select or insert records based on a similarity threshold. When using the `select` action with the `similarity` strategy, you can specify a `threshold` value between `0` and `1`, where `0` represents a perfect match and `1` signifies no similarity.
+
+- **Select Records:**
+  If a record from your SQL file has a similarity score below the threshold, it will be selected from the target org.
+
+- **Insert Records:**
+  If the similarity score exceeds the threshold, the record will be inserted into the target org instead of being selected.
+
+This parameter is **optional**; if not specified, no threshold is applied and all records are selected by default.
+
+This feature is particularly useful during version upgrades, where records that closely match can be selected, while those that do not match sufficiently can be inserted into the target org. An illustrative sketch of how priority weights and the threshold work together follows the example below.
+
+---
+
+#### Example
+
+To demonstrate the `select` functionality, consider the example of the `Event` entity, which utilizes the `similarity` strategy, a filter condition, and other advanced options to select matching records effectively, as shown in the YAML above.
+
+1. **Basic Object Configuration**:
+
+    - The `Account`, `Contact`, and `Lead` objects are configured for straightforward field mapping.
+    - A `lookup` is defined on the `Contact` object to map `AccountId` to the `Account` table.
+
+2. **Advanced `Event` Object Mapping**:
+    - **Action**: The `Event` object uses the `select` action, meaning records are selected rather than inserted.
+    - **Strategy**: The `similarity` strategy matches `Event` records in the target org that are similar to those defined in the SQL file.
+    - **Filter**: Only `Event` records with a `Subject` field starting with "Meeting" are considered.
+    - **Priority Fields**: The `Subject` and `WhoId` fields are given more weight during similarity matching.
+    - **Threshold**: A similarity score of 0.3 is used to determine whether records are selected or inserted.
+    - **Lookups**:
+        - The `WhoId` field looks up records from either the `Contact` or `Lead` objects.
+        - The `WhatId` field looks up records from the `Account` object.
+
+This example highlights how the `select` functionality can be applied in real-world scenarios, such as selecting `Event` records that meet specific criteria while considering similarity, filters, and priority fields.
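+The sketch below is purely illustrative and does not show CumulusCI's actual implementation. The helper names (`weight_fields`, `similarity_score`, `select_or_insert`), the weight values, the distance measure, and the behavior at exactly the threshold are all assumptions made for the example; it only conveys the general idea that priority fields contribute more to the similarity score, and that the score is then compared against `threshold` to decide between selecting an existing record and inserting a new one.
+
+```python
+from typing import Dict, List, Optional
+
+# Hypothetical weights -- the values used internally may differ.
+HIGH_PRIORITY_WEIGHT = 2.0
+LOW_PRIORITY_WEIGHT = 1.0
+
+
+def weight_fields(priority_fields: Optional[Dict[str, str]], fields: List[str]) -> List[float]:
+    """Give priority fields a larger weight than the remaining fields."""
+    if not priority_fields:
+        return [LOW_PRIORITY_WEIGHT] * len(fields)
+    return [
+        HIGH_PRIORITY_WEIGHT if field in priority_fields else LOW_PRIORITY_WEIGHT
+        for field in fields
+    ]
+
+
+def similarity_score(local_row: List[str], org_row: List[str], weights: List[float]) -> float:
+    """Toy weighted distance: 0 means the rows are identical, 1 means nothing matches."""
+    total = sum(weights)
+    mismatched = sum(
+        weight
+        for local, org, weight in zip(local_row, org_row, weights)
+        if local != org
+    )
+    return mismatched / total if total else 0.0
+
+
+def select_or_insert(score: float, threshold: Optional[float]) -> str:
+    """Select the existing record when it is close enough, otherwise insert."""
+    if threshold is None:
+        return "select"  # without a threshold, every record is selected
+    return "select" if score <= threshold else "insert"
+
+
+# With Subject and WhoId as priority fields and threshold 0.3, a mismatch on
+# the low-priority DurationInMinutes field still selects the existing record.
+fields = ["Subject", "WhoId", "DurationInMinutes"]
+weights = weight_fields({"Subject": "Subject", "WhoId": "WhoId"}, fields)
+score = similarity_score(
+    ["Meeting with ACME", "003000000000001", "60"],
+    ["Meeting with ACME", "003000000000001", "30"],
+    weights,
+)
+print(select_or_insert(score, 0.3))  # score is 0.2 -> "select"
+```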
+
+---
+
### Database Mapping

CumulusCI's definition format includes considerable flexibility for use
diff --git a/pyproject.toml b/pyproject.toml
index 585e7f4654..7dec9eedab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ classifiers = [
    "Programming Language :: Python :: 3.13",
]
dependencies = [
+    "annoy",
    "click>=8.1",
    "cryptography",
    "python-dateutil",
@@ -34,6 +35,8 @@ dependencies = [
    "defusedxml",
    "lxml",
    "MarkupSafe",
+    "numpy",
+    "pandas",
    "psutil",
    "pydantic<2",
    "PyJWT",
@@ -50,6 +53,7 @@ dependencies = [
    "rst2ansi>=0.1.5",
    "salesforce-bulk",
    "sarge",
+    "scikit-learn",
    "selenium<4",
    "simple-salesforce==1.11.4",
    "snowfakery>=4.0.0",