diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1ebdd8bcf1..78c9695250 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -685,6 +685,115 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records") + def upsert( + self, + df: pa.Table, + join_cols: Optional[List[str]] = None, + when_matched_update_all: bool = True, + when_not_matched_insert_all: bool = True, + case_sensitive: bool = True, + ) -> UpsertResult: + """Shorthand API for performing an upsert to an iceberg table. + + Args: + + df: The input dataframe to upsert with the table's data. + join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. + when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing + when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table + case_sensitive: Bool indicating if the match should be case-sensitive + + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it + + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it + + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) + + + Returns: + An UpsertResult class (contains details of rows updated and inserted) + """ + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import expression_to_pyarrow + from pyiceberg.table import upsert_util + + if join_cols is None: + join_cols = [] + for field_id in self.table_metadata.schema().identifier_field_ids: + col = self.table_metadata.schema().find_column_name(field_id) + if col is not None: + join_cols.append(col) + else: + raise ValueError(f"Field-ID could not be found: {join_cols}") + + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + + if not when_matched_update_all and not when_not_matched_insert_all: + raise ValueError("no upsert options selected...exiting") + + if upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed") + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + + # get list of rows that exist so we don't have to load the entire target table + matched_predicate = upsert_util.create_match_filter(df, join_cols) + matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() + + update_row_cnt = 0 + insert_row_cnt = 0 + + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) + + update_row_cnt = len(rows_to_update) + + if len(rows_to_update) > 0: + # build the match predicate filter + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + + self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) + + if when_not_matched_insert_all: + expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + rows_to_insert = df.filter(~expr_match_arrow) + + insert_row_cnt = len(rows_to_insert) + + if insert_row_cnt > 0: + self.append(rows_to_insert) + + return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True ) -> None: @@ -1149,73 +1258,14 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ - try: - import pyarrow as pa # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - from pyiceberg.io.pyarrow import expression_to_pyarrow - from pyiceberg.table import upsert_util - - if join_cols is None: - join_cols = [] - for field_id in self.schema().identifier_field_ids: - col = self.schema().find_column_name(field_id) - if col is not None: - join_cols.append(col) - else: - raise ValueError(f"Field-ID could not be found: {join_cols}") - - if len(join_cols) == 0: - raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") - - if not when_matched_update_all and not when_not_matched_insert_all: - raise ValueError("no upsert options selected...exiting") - - if upsert_util.has_duplicate_rows(df, join_cols): - raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed") - - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_pyarrow_schema_compatible( - self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - - # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() - - update_row_cnt = 0 - insert_row_cnt = 0 - with self.transaction() as tx: - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) - - update_row_cnt = len(rows_to_update) - - if len(rows_to_update) > 0: - # build the match predicate filter - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - - tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) - - if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - rows_to_insert = df.filter(~expr_match_arrow) - - insert_row_cnt = len(rows_to_insert) - - if insert_row_cnt > 0: - tx.append(rows_to_insert) - - return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + return tx.upsert( + df=df, + join_cols=join_cols, + when_matched_update_all=when_matched_update_all, + when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive, + ) def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: """ diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 5de4a61187..4c853803d9 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -23,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import And, EqualTo, Reference +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema @@ -511,6 +511,70 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None: tbl.upsert(df) +def test_transaction(catalog: Catalog) -> None: + """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is + rolled back.""" + identifier = "default.test_merge_source_dups" + _drop_table(catalog, identifier) + + ctx = SessionContext() + + table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier) + df_before_transaction = table.scan().to_arrow() + + source_df = gen_source_dataset(5, 15, False, True, ctx) + + with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"): + with table.transaction() as tx: + tx.delete(delete_filter=AlwaysTrue()) + tx.upsert(df=source_df, join_cols=["order_id"]) + + df = table.scan().to_arrow() + + assert df_before_transaction == df + + +@pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") +def test_transaction_multiple_upserts(catalog: Catalog) -> None: + identifier = "default.test_multi_upsert" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "name", StringType(), required=True), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + # Define exact schema: required int32 and required string + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ] + ) + + tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema)) + + df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) + + with tbl.transaction() as txn: + txn.append(df) + txn.delete(delete_filter="id = 1") + txn.append(df) + # This should read the uncommitted changes? + txn.upsert(df, join_cols=["id"]) + + # txn.upsert(df, join_cols=["id"]) + + result = tbl.scan().to_arrow().to_pylist() + assert sorted(result, key=lambda x: x["id"]) == [ + {"id": 1, "name": "Alicia"}, + {"id": 2, "name": "Bob"}, + ] + + def test_upsert_with_nulls(catalog: Catalog) -> None: identifier = "default.test_upsert_with_nulls" _drop_table(catalog, identifier)