
Move implementation of upsert from Table to Transaction #1817

Open · wants to merge 7 commits into main
182 changes: 116 additions & 66 deletions pyiceberg/table/__init__.py
@@ -685,6 +685,115 @@ def delete(
        if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed:
            warnings.warn("Delete operation did not match any records")

    def upsert(
        self,
        df: pa.Table,
        join_cols: Optional[List[str]] = None,
        when_matched_update_all: bool = True,
        when_not_matched_insert_all: bool = True,
        case_sensitive: bool = True,
    ) -> UpsertResult:
        """Shorthand API for performing an upsert to an Iceberg table.

        Args:
            df: The input dataframe to upsert with the table's data.
            join_cols: Columns to join on; if not provided, the table's identifier-field-ids are used.
            when_matched_update_all: If True, update matched rows whose non-key column values have changed.
            when_not_matched_insert_all: If True, insert rows from df that do not match any existing row in the table.
            case_sensitive: If True, the column matching is case-sensitive.

        To learn more about identifier-field-ids, see https://iceberg.apache.org/spec/#identifier-field-ids

        Example Use Cases:
            Case 1: Both parameters = True (full upsert)
                Existing row found → update it
                New row found → insert it

            Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
                Existing row found → do nothing (no updates)
                New row found → insert it

            Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
                Existing row found → update it
                New row found → do nothing (no inserts)

            Case 4: Both parameters = False
                Rejected with a ValueError, since the call could never change the table

        Returns:
            An UpsertResult class (contains details of rows updated and inserted)
        """
        try:
            import pyarrow as pa  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

        from pyiceberg.io.pyarrow import expression_to_pyarrow
        from pyiceberg.table import upsert_util

        if join_cols is None:
            # default to the schema's identifier fields when no join columns are given
            join_cols = []
            for field_id in self.table_metadata.schema().identifier_field_ids:
                col = self.table_metadata.schema().find_column_name(field_id)
                if col is not None:
                    join_cols.append(col)
                else:
                    raise ValueError(f"Field-ID could not be found: {field_id}")

        if len(join_cols) == 0:
            raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.")

        if not when_matched_update_all and not when_not_matched_insert_all:
            raise ValueError("No upsert options selected: when_matched_update_all and when_not_matched_insert_all are both False")

        if upsert_util.has_duplicate_rows(df, join_cols):
            raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_pyarrow_schema_compatible(
            self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )

        # get the list of rows that already exist, so we don't have to load the entire target table;
        # note that this scans self._table, i.e. the last committed snapshot, so changes staged
        # earlier in this transaction are not visible here
        matched_predicate = upsert_util.create_match_filter(df, join_cols)
        matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

        update_row_cnt = 0
        insert_row_cnt = 0

        if when_matched_update_all:
            # get_rows_to_update compares the non-key columns to see if any values have actually changed;
            # we don't want a blanket overwrite of matched rows when the non-key column data is unchanged,
            # and this extra step avoids unnecessary IO and writes
            rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)

            update_row_cnt = len(rows_to_update)

            if len(rows_to_update) > 0:
                # build the match predicate filter
                overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)

                self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

        if when_not_matched_insert_all:
            # keep only the source rows whose keys were not found in the target table
            expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
            expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
            expr_match_arrow = expression_to_pyarrow(expr_match_bound)
            rows_to_insert = df.filter(~expr_match_arrow)

            insert_row_cnt = len(rows_to_insert)

            if insert_row_cnt > 0:
                self.append(rows_to_insert)

        return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)

    def add_files(
        self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
    ) -> None:
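A minimal sketch of how the new Transaction-level upsert could be exercised, assuming a catalog named "default" and a table "default.orders" whose schema declares order_id among its identifier-field-ids (both names are hypothetical):

import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
table = catalog.load_table("default.orders")

# one row matches an existing key (update), one has an unseen key (insert)
df = pa.Table.from_pylist(
    [
        {"order_id": 1, "status": "shipped"},
        {"order_id": 99, "status": "created"},
    ]
)

with table.transaction() as tx:
    res = tx.upsert(df, join_cols=["order_id"])

print(f"updated={res.rows_updated}, inserted={res.rows_inserted}")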
@@ -1149,73 +1258,14 @@ def upsert(
        Returns:
            An UpsertResult class (contains details of rows updated and inserted)
        """
        try:
            import pyarrow as pa  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

        from pyiceberg.io.pyarrow import expression_to_pyarrow
        from pyiceberg.table import upsert_util

        if join_cols is None:
            join_cols = []
            for field_id in self.schema().identifier_field_ids:
                col = self.schema().find_column_name(field_id)
                if col is not None:
                    join_cols.append(col)
                else:
                    raise ValueError(f"Field-ID could not be found: {join_cols}")

        if len(join_cols) == 0:
            raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.")

        if not when_matched_update_all and not when_not_matched_insert_all:
            raise ValueError("no upsert options selected...exiting")

        if upsert_util.has_duplicate_rows(df, join_cols):
            raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_pyarrow_schema_compatible(
            self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )

        # get list of rows that exist so we don't have to load the entire target table
        matched_predicate = upsert_util.create_match_filter(df, join_cols)
        matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

        update_row_cnt = 0
        insert_row_cnt = 0

        with self.transaction() as tx:
            if when_matched_update_all:
                # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
                # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
                # this extra step avoids unnecessary IO and writes
                rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)

                update_row_cnt = len(rows_to_update)

                if len(rows_to_update) > 0:
                    # build the match predicate filter
                    overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)

                    tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

            if when_not_matched_insert_all:
                expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
                expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive)
                expr_match_arrow = expression_to_pyarrow(expr_match_bound)
                rows_to_insert = df.filter(~expr_match_arrow)

                insert_row_cnt = len(rows_to_insert)

                if insert_row_cnt > 0:
                    tx.append(rows_to_insert)

        return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
            return tx.upsert(
                df=df,
                join_cols=join_cols,
                when_matched_update_all=when_matched_update_all,
                when_not_matched_insert_all=when_not_matched_insert_all,
                case_sensitive=case_sensitive,
            )

    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
        """
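Since Table.upsert is now a thin wrapper that opens a transaction and delegates, the flag combinations documented in the docstring behave identically through either entry point. A sketch of an insert-only merge (Case 2), reusing the hypothetical table from the previous example:

# leave matched rows untouched and append only rows with unseen keys
res = table.upsert(
    df,
    join_cols=["order_id"],
    when_matched_update_all=False,
    when_not_matched_insert_all=True,
)
assert res.rows_updated == 0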
66 changes: 65 additions & 1 deletion tests/table/test_upsert.py
@@ -23,7 +23,7 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import And, EqualTo, Reference
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
@@ -511,6 +511,70 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
        tbl.upsert(df)


def test_transaction(catalog: Catalog) -> None:
    """Test upsert within a Transaction, making sure that if something fails the entire
    Transaction is rolled back."""
    identifier = "default.test_merge_source_dups"
    _drop_table(catalog, identifier)

    ctx = SessionContext()

    table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier)
    df_before_transaction = table.scan().to_arrow()

    source_df = gen_source_dataset(5, 15, False, True, ctx)

    with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"):
        with table.transaction() as tx:
            tx.delete(delete_filter=AlwaysTrue())
            tx.upsert(df=source_df, join_cols=["order_id"])

    df = table.scan().to_arrow()

    assert df_before_transaction == df


@pytest.mark.skip("This test is just for reference. Multiple upserts, or a delete followed by an upsert, do not work in a transaction")
def test_transaction_multiple_upserts(catalog: Catalog) -> None:
    identifier = "default.test_multi_upsert"
    _drop_table(catalog, identifier)

    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=True),
        identifier_field_ids=[1],
    )

    tbl = catalog.create_table(identifier, schema=schema)

    # Define the exact schema: required int32 and required string
    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int32(), nullable=False),
            pa.field("name", pa.string(), nullable=False),
        ]
    )

    tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema))

    df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema)

    with tbl.transaction() as txn:
        txn.append(df)
        txn.delete(delete_filter="id = 1")
        txn.append(df)
        # This should read the uncommitted changes? (the upsert's scan currently only sees the committed snapshot)
        txn.upsert(df, join_cols=["id"])

    # txn.upsert(df, join_cols=["id"])

    result = tbl.scan().to_arrow().to_pylist()
    assert sorted(result, key=lambda x: x["id"]) == [
        {"id": 1, "name": "Alicia"},
        {"id": 2, "name": "Bob"},
    ]


def test_upsert_with_nulls(catalog: Catalog) -> None:
    identifier = "default.test_upsert_with_nulls"
    _drop_table(catalog, identifier)
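For context on what test_transaction asserts: because the upsert now runs inside the caller's transaction, a failure anywhere in the block discards every staged change, including a preceding delete. A minimal sketch of that guarantee, reusing the hypothetical table from the earlier examples and a source_df assumed to contain duplicate keys:

from pyiceberg.expressions import AlwaysTrue

snapshot_before = table.current_snapshot()

try:
    with table.transaction() as tx:
        tx.delete(delete_filter=AlwaysTrue())         # staged, not committed
        tx.upsert(source_df, join_cols=["order_id"])  # raises on duplicate keys
except ValueError:
    pass

# nothing was committed: the staged delete is rolled back along with the upsert
assert table.current_snapshot() == snapshot_before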