Add table upsert support (#1660)
Closes #402
This PR adds the `upsert` function to the `Table` class and supports the
following upsert operations:
- when matched update all
- when not matched insert all
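
For illustration, a minimal usage sketch of the new API (the catalog and table names are made up, and the table is assumed to already exist with a schema compatible with the source data):

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table; adjust names to your environment.
catalog = load_catalog("default")
table = catalog.load_table("db.orders")

source = pa.table({"order_id": [1, 2, 3], "status": ["shipped", "pending", "new"]})

# Full upsert: update matched rows, insert unmatched rows (both flags default to True).
result = table.upsert(source, join_cols=["order_id"])
print(result.rows_updated, result.rows_inserted)
```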

This PR is a remake of #1534 due to some infrastructure issues. For
additional context, please refer to that PR.

---------

Co-authored-by: VAA7RQ <[email protected]>
Co-authored-by: VAA7RQ <[email protected]>
Co-authored-by: mattmartin14 <[email protected]>
Co-authored-by: Fokko Driesprong <[email protected]>
5 people authored Feb 13, 2025
1 parent 6d1c30c commit 6351066
Showing 5 changed files with 539 additions and 7 deletions.
34 changes: 27 additions & 7 deletions poetry.lock

Some generated files are not rendered by default.

80 changes: 80 additions & 0 deletions pyiceberg/table/__init__.py
@@ -153,6 +153,14 @@
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"


@dataclass()
class UpsertResult:
"""Summary the upsert operation."""

rows_updated: int = 0
rows_inserted: int = 0


class TableProperties:
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
@@ -1092,6 +1100,78 @@ def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
return self.metadata.name_mapping()

def upsert(
self, df: pa.Table, join_cols: list[str], when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True
) -> UpsertResult:
"""Shorthand API for performing an upsert to an iceberg table.
Args:
df: The input dataframe to upsert with the table's data.
join_cols: The columns to join on. These are essentially analogous to primary keys.
when_matched_update_all: If True, update matched target rows whose non-key column values differ from the source.
when_not_matched_insert_all: If True, insert source rows that do not match any existing row in the table.
Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
Existing row found → Update it
New row found → Insert it
Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
Existing row found → Do nothing (no updates)
New row found → Insert it
Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
Existing row found → Update it
New row found → Do nothing (no inserts)
Case 4: Both Parameters = False (No Merge Effect)
Existing row found → Do nothing
New row found → Do nothing
(Function effectively does nothing)
Returns:
An UpsertResult class (contains details of rows updated and inserted)
"""
from pyiceberg.table import upsert_util

if not when_matched_update_all and not when_not_matched_insert_all:
raise ValueError("no upsert options selected...exiting")

if upsert_util.has_duplicate_rows(df, join_cols):
raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)
matched_iceberg_table = self.scan(row_filter=matched_predicate).to_arrow()

update_row_cnt = 0
insert_row_cnt = 0

with self.transaction() as tx:
if when_matched_update_all:
# function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
# we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
# this extra step avoids unnecessary IO and writes
rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)

update_row_cnt = len(rows_to_update)

# build the match predicate filter
overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)

tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

if when_not_matched_insert_all:
rows_to_insert = upsert_util.get_rows_to_insert(df, matched_iceberg_table, join_cols)

insert_row_cnt = len(rows_to_insert)

tx.append(rows_to_insert)

return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for appending a PyArrow table to the table.
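The four cases described in the docstring map onto the two boolean flags. A quick sketch of the non-default combinations, reusing the hypothetical `table` and `source` from the earlier example:

```python
# Insert-only: leave matched rows untouched, append rows with new keys.
table.upsert(source, join_cols=["order_id"], when_matched_update_all=False)

# Update-only: rewrite matched rows whose non-key values changed, skip new keys.
table.upsert(source, join_cols=["order_id"], when_not_matched_insert_all=False)

# Setting both flags to False is rejected up front with a ValueError.
```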
118 changes: 118 additions & 0 deletions pyiceberg/table/upsert_util.py
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import operator

import pyarrow as pa
from pyarrow import Table as pyarrow_table
from pyarrow import compute as pc

from pyiceberg.expressions import (
And,
BooleanExpression,
EqualTo,
In,
Or,
)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

if len(join_cols) == 1:
return In(join_cols[0], unique_keys[0].to_pylist())
else:
return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
"""Check for duplicate rows in a PyArrow table based on the join columns."""
return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0


def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
"""
Return a table with rows that need to be updated in the target table based on the join columns.
When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
Only matched rows that have an actual change to a non-key column value will be returned in the final output.
"""
all_columns = set(source_table.column_names)
join_cols_set = set(join_cols)

non_key_cols = list(all_columns - join_cols_set)

match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols])

matching_source_rows = source_table.filter(match_expr)

rows_to_update = []

for index in range(matching_source_rows.num_rows):
source_row = matching_source_rows.slice(index, 1)

target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols])

matching_target_row = target_table.filter(target_filter)

if matching_target_row.num_rows > 0:
needs_update = False

for non_key_col in non_key_cols:
source_value = source_row.column(non_key_col)[0].as_py()
target_value = matching_target_row.column(non_key_col)[0].as_py()

if source_value != target_value:
needs_update = True
break

if needs_update:
rows_to_update.append(source_row)

if rows_to_update:
rows_to_update_table = pa.concat_tables(rows_to_update)
else:
# nothing changed; return an empty table with the source schema
# (pa.Table.from_arrays([], names=...) raises when names is non-empty)
rows_to_update_table = source_table.schema.empty_table()

common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
rows_to_update_table = rows_to_update_table.select(list(common_columns))

return rows_to_update_table


def get_rows_to_insert(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
"""Return the source rows whose join-column values do not match any existing row in the target table."""
# build a predicate matching source rows whose key values already exist in the target
source_filter_expr = pc.scalar(True)

for col in join_cols:
target_values = target_table.column(col).to_pylist()
source_filter_expr = source_filter_expr & pc.field(col).isin(target_values)

non_matching_expr = ~source_filter_expr

source_columns = set(source_table.column_names)
target_columns = set(target_table.column_names)

common_columns = source_columns.intersection(target_columns)

non_matching_rows = source_table.filter(non_matching_expr).select(common_columns)

return non_matching_rows
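
For context, a small sketch of how these helpers behave on toy PyArrow tables (the data is made up, and `upsert_util` is an internal module rather than a public API):

```python
import pyarrow as pa
from pyiceberg.table import upsert_util

source = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]})
target = pa.table({"id": [1, 2], "val": ["a", "B"]})

# Single join column -> an In() predicate over the unique source keys.
predicate = upsert_util.create_match_filter(source, ["id"])

# Row id=2 differs in a non-key column, so it is the only row returned for update.
to_update = upsert_util.get_rows_to_update(source, target, ["id"])

# Row id=3 has no match in the target, so it is returned for insert.
to_insert = upsert_util.get_rows_to_insert(source, target, ["id"])
```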
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -97,6 +97,7 @@ pytest-mock = "3.14.0"
pyspark = "3.5.3"
cython = "3.0.12"
deptry = ">=0.14,<0.24"
datafusion = "^44.0.0"
docutils = "!=0.21.post1" # https://github.com/python-poetry/poetry/issues/9248#issuecomment-2026240520

[tool.poetry.group.docs.dependencies]
@@ -504,5 +505,9 @@ ignore_missing_imports = true
module = "polars.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "datafusion.*"
ignore_missing_imports = true

[tool.coverage.run]
source = ['pyiceberg/']
