Commit 6351066

kevinjqliu, mattmartin14, and Fokko authored

Add table upsert support (#1660)
Closes #402. This PR adds the `upsert` function to the `Table` class and supports the following upsert operations:

- when matched update all
- when not matched insert all

This PR is a remake of #1534 due to some infrastructure issues. For additional context, please refer to that PR.

---------

Co-authored-by: VAA7RQ <[email protected]>
Co-authored-by: VAA7RQ <[email protected]>
Co-authored-by: mattmartin14 <[email protected]>
Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 6d1c30c commit 6351066

File tree

5 files changed (+539 −7 lines changed)

poetry.lock

Lines changed: 27 additions & 7 deletions
Some generated files are not rendered by default.

pyiceberg/table/__init__.py

Lines changed: 80 additions & 0 deletions
@@ -153,6 +153,14 @@
 DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"


+@dataclass()
+class UpsertResult:
+    """Summary of the upsert operation."""
+
+    rows_updated: int = 0
+    rows_inserted: int = 0
+
+
 class TableProperties:
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
@@ -1092,6 +1100,78 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()

+    def upsert(
+        self, df: pa.Table, join_cols: list[str], when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True
+    ) -> UpsertResult:
+        """Shorthand API for performing an upsert to an Iceberg table.
+
+        Args:
+            df: The input dataframe to upsert with the table's data.
+            join_cols: The columns to join on. These are essentially analogous to primary keys.
+            when_matched_update_all: Bool indicating whether matched rows should be updated when a value in a non-key column has changed.
+            when_not_matched_insert_all: Bool indicating whether new rows that do not match any existing rows should be inserted.
+
+        Example Use Cases:
+            Case 1: Both Parameters = True (Full Upsert)
+                Existing row found → Update it
+                New row found → Insert it
+
+            Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
+                Existing row found → Do nothing (no updates)
+                New row found → Insert it
+
+            Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
+                Existing row found → Update it
+                New row found → Do nothing (no inserts)
+
+            Case 4: Both Parameters = False (No Merge Effect)
+                Existing row found → Do nothing
+                New row found → Do nothing
+                (Function effectively does nothing)
+
+        Returns:
+            An UpsertResult class (contains details of rows updated and inserted).
+        """
+        from pyiceberg.table import upsert_util
+
+        if not when_matched_update_all and not when_not_matched_insert_all:
+            raise ValueError("no upsert options selected...exiting")
+
+        if upsert_util.has_duplicate_rows(df, join_cols):
+            raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")
+
+        # get the list of rows that already exist so we don't have to load the entire target table
+        matched_predicate = upsert_util.create_match_filter(df, join_cols)
+        matched_iceberg_table = self.scan(row_filter=matched_predicate).to_arrow()
+
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        with self.transaction() as tx:
+            if when_matched_update_all:
+                # get_rows_to_update checks the non-key columns to see whether any values have actually changed;
+                # we don't want a blanket overwrite of matched rows when the non-key column data hasn't changed.
+                # This extra step avoids unnecessary IO and writes.
+                rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)
+
+                update_row_cnt = len(rows_to_update)
+
+                # build the match predicate filter
+                overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
+
+                tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
+
+            if when_not_matched_insert_all:
+                rows_to_insert = upsert_util.get_rows_to_insert(df, matched_iceberg_table, join_cols)
+
+                insert_row_cnt = len(rows_to_insert)
+
+                tx.append(rows_to_insert)
+
+        return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
+
     def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
         Shorthand API for appending a PyArrow table to the table.
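
For reference, a minimal end-to-end usage sketch of the new method. The catalog name, table identifier, and schema below are hypothetical and not part of this commit:

    import pyarrow as pa
    from pyiceberg.catalog import load_catalog

    # Hypothetical catalog/table, assumed to already exist with this schema.
    catalog = load_catalog("default")
    table = catalog.load_table("db.orders")

    df = pa.table({
        "order_id": [100, 101],
        "status": ["shipped", "new"],
    })

    # Full upsert (both flags default to True): matched keys with changed
    # non-key values are updated, unmatched keys are inserted.
    result = table.upsert(df, join_cols=["order_id"])
    print(f"updated={result.rows_updated}, inserted={result.rows_inserted}")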

pyiceberg/table/upsert_util.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import functools
+import operator
+
+import pyarrow as pa
+from pyarrow import Table as pyarrow_table
+from pyarrow import compute as pc
+
+from pyiceberg.expressions import (
+    And,
+    BooleanExpression,
+    EqualTo,
+    In,
+    Or,
+)
+
+
+def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    if len(join_cols) == 1:
+        return In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])
+
+
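A quick sketch of the predicates this helper produces, with made-up values:

    import pyarrow as pa
    from pyiceberg.table.upsert_util import create_match_filter

    df = pa.table({"id": [1, 2], "region": ["eu", "us"]})

    # A single key column collapses to one In(...) predicate over the unique keys.
    print(create_match_filter(df, ["id"]))

    # A composite key becomes Or(And(EqualTo(...), EqualTo(...)), ...),
    # one And(...) per unique key tuple.
    print(create_match_filter(df, ["id", "region"]))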
+def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
+    """Check for duplicate rows in a PyArrow table based on the join columns."""
+    return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0
+
+
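For illustration, a small check with a made-up table; this is the guard that makes `upsert` reject ambiguous source data:

    import pyarrow as pa
    from pyiceberg.table.upsert_util import has_duplicate_rows

    dupes = pa.table({"id": [1, 1, 2], "val": ["a", "b", "c"]})
    assert has_duplicate_rows(dupes, ["id"])  # id=1 appears twice
    assert not has_duplicate_rows(dupes, ["id", "val"])  # full key tuples are unique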
+def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
+    """
+    Return a table with rows that need to be updated in the target table based on the join columns.
+
+    When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
+    Only matched rows that have an actual change to a non-key column value will be returned in the final output.
+    """
+    all_columns = set(source_table.column_names)
+    join_cols_set = set(join_cols)
+
+    non_key_cols = list(all_columns - join_cols_set)
+
+    match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols])
+
+    matching_source_rows = source_table.filter(match_expr)
+
+    rows_to_update = []
+
+    for index in range(matching_source_rows.num_rows):
+        source_row = matching_source_rows.slice(index, 1)
+
+        target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols])
+
+        matching_target_row = target_table.filter(target_filter)
+
+        if matching_target_row.num_rows > 0:
+            needs_update = False
+
+            for non_key_col in non_key_cols:
+                source_value = source_row.column(non_key_col)[0].as_py()
+                target_value = matching_target_row.column(non_key_col)[0].as_py()
+
+                if source_value != target_value:
+                    needs_update = True
+                    break
+
+            if needs_update:
+                rows_to_update.append(source_row)
+
+    if rows_to_update:
+        rows_to_update_table = pa.concat_tables(rows_to_update)
+    else:
+        # no changed rows: return an empty table that keeps the source schema
+        rows_to_update_table = source_table.slice(0, 0)
+
+    common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
+    rows_to_update_table = rows_to_update_table.select(list(common_columns))
+
+    return rows_to_update_table
+
+
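A sketch of the changed-row detection with made-up data:

    import pyarrow as pa
    from pyiceberg.table.upsert_util import get_rows_to_update

    source = pa.table({"id": [1, 2, 3], "val": ["a", "B", "c"]})
    target = pa.table({"id": [1, 2], "val": ["a", "b"]})

    # id=1 matches with identical non-key data (skipped), id=2 matches with a
    # changed value (returned), id=3 has no match (left to get_rows_to_insert).
    print(get_rows_to_update(source, target, ["id"]).to_pydict())
    # expect only the id=2 row; column order may vary since a set drives select()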
+def get_rows_to_insert(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
+    # Start from an always-true literal and AND on a membership test per key
+    # column; rows whose key values all appear in the target count as matches.
+    source_filter_expr = pc.scalar(True)
+
+    for col in join_cols:
+        target_values = target_table.column(col).to_pylist()
+        source_filter_expr = source_filter_expr & pc.field(col).isin(target_values)
+
+    non_matching_expr = ~source_filter_expr
+
+    source_columns = set(source_table.column_names)
+    target_columns = set(target_table.column_names)
+
+    common_columns = source_columns.intersection(target_columns)
+
+    non_matching_rows = source_table.filter(non_matching_expr).select(common_columns)
+
+    return non_matching_rows
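
And the matching anti-join sketch, same made-up data as above:

    import pyarrow as pa
    from pyiceberg.table.upsert_util import get_rows_to_insert

    source = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]})
    target = pa.table({"id": [1, 2], "val": ["a", "b"]})

    # Only id=3 is absent from the target, so only that row is returned.
    print(get_rows_to_insert(source, target, ["id"]).to_pydict())
    # expect only the id=3 row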

pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -97,6 +97,7 @@ pytest-mock = "3.14.0"
 pyspark = "3.5.3"
 cython = "3.0.12"
 deptry = ">=0.14,<0.24"
+datafusion = "^44.0.0"
 docutils = "!=0.21.post1" # https://github.com/python-poetry/poetry/issues/9248#issuecomment-2026240520

 [tool.poetry.group.docs.dependencies]
@@ -504,5 +505,9 @@ ignore_missing_imports = true
 module = "polars.*"
 ignore_missing_imports = true

+[[tool.mypy.overrides]]
+module = "datafusion.*"
+ignore_missing_imports = true
+
 [tool.coverage.run]
 source = ['pyiceberg/']

0 commit comments