58 changes: 42 additions & 16 deletions pyiceberg/table/upsert_util.py
@@ -14,38 +14,61 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import operator
from math import isnan
from typing import Any

import pyarrow as pa
from pyarrow import Table as pyarrow_table
from pyarrow import compute as pc

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
IsNaN,
IsNull,
Or,
)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
@zhongyujiang (Contributor) commented on Sep 24, 2025:

If I understand this correctly, this is creating a predicate to test whether a row might exist in the pyarrow_table (matching on join_cols).
And since Null == Any always returns unknown in SQL, can we just filter out any rows from the pyarrow_table where the join_cols fields contain None (we treat None as SQL NULL), and then build the match filter from the filtered pyarrow table (using the existing logic for building the match filter)? That would be much simpler.
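
A minimal sketch of that suggested pre-filtering (the data and variable names below are illustrative, not from the PR): drop source rows whose join-key columns contain null, then build the match filter from the remaining rows with the existing logic.

import functools

import pyarrow as pa
import pyarrow.compute as pc

df = pa.table({"foo": ["a", None, "c"], "bar": [1, 2, None]})
join_cols = ["foo", "bar"]

# Keep only rows where every join column is non-null.
mask = functools.reduce(pc.and_, [pc.is_valid(df[col]) for col in join_cols])
df_without_null_keys = df.filter(mask)  # only the {"foo": "a", "bar": 1} row survives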

@mdwint (Author) replied on Sep 24, 2025:

Would that mean it's impossible to update rows with null in the join columns, since they are filtered out?
If so, that's not what I was going for. I'd like the solution to pass this test: https://github.com/mdwint/iceberg-python/blob/f818016e5c198581b7d7b11dba2b9ebd414e19bc/tests/table/test_upsert.py#L784-L831

This would be equivalent to the following Spark SQL (using the null-safe equality operator <=>):

MERGE INTO target_table AS t
USING source_table AS s
ON (t.foo <=> s.foo AND t.bar <=> s.bar)
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *

@mdwint (Author) added on Sep 24, 2025:

Thinking through this some more, I ask myself: What should the semantics of upsert be? Should it use = or <=> to test equality? For my use case <=> is right, and I also find it most intuitive, but does that mean it should be the default?

I see several options:

  • Make <=> the default. Users who don't want to update nulls can filter them out themselves before calling upsert. The status quo is a crash, so there are no existing users expecting a different behaviour.
  • Make = the default. This means I can't achieve my goal, and I'll need to reimplement upsert myself. It also means new rows will be inserted for every row containing null in the join columns. This is unintuitive to me, but who knows, someone might want it?
  • Add an argument to upsert to select the comparison operator. Maximum flexibility, but more work to implement.
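
For reference, a hypothetical pair of helpers (not from the PR; the names are made up) contrasting the two comparison semantics in the options above:

from typing import Any, Optional

def sql_eq(a: Any, b: Any) -> Optional[bool]:
    # SQL `=`: any comparison involving NULL is UNKNOWN, so a null key never matches.
    if a is None or b is None:
        return None
    return a == b

def null_safe_eq(a: Any, b: Any) -> bool:
    # SQL `<=>`: NULL <=> NULL is TRUE, so rows with null keys can match each other.
    if a is None or b is None:
        return a is None and b is None
    return a == b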

unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
filters: list[BooleanExpression] = []

if len(join_cols) == 1:
return In(join_cols[0], unique_keys[0].to_pylist())
column = join_cols[0]
A reviewer (Contributor) commented:

Is there a way we can simplify the logic here?

I think the primary issue is that the In operator cannot handle null, is that right?

@mdwint (Author) replied on Sep 24, 2025:

Yes, the In operator cannot handle null by design, and this goes for SQL as well.

The following SQL does not do what you might expect, because NULL in an IN list never matches:

WHERE x IN (1, 2, 3, NULL)

Instead it should be this:

WHERE x IN (1, 2, 3) OR x IS NULL

Testing for null requires IS NULL (or IS NOT NULL), and it's impossible with IN or =.

This is the reason for changing the create_match_filter function: we need to build more complex expressions if null is involved. Examples of such expressions are shown in the test cases.

If there's a better way, I'm open to changing it, but I believe the added complexity in building filter expressions with null is justified. When null is not involved, we produce the same In expression as before.

values = set(unique_keys[0].to_pylist())

if None in values:
filters.append(IsNull(column))
values.remove(None)

if nans := {v for v in values if isinstance(v, float) and isnan(v)}:
filters.append(IsNaN(column))
values -= nans

filters.append(In(column, values))
else:
filters = [
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
]

if len(filters) == 0:
return AlwaysFalse()
elif len(filters) == 1:
return filters[0]
else:
return Or(*filters)
def equals(column: str, value: Any) -> BooleanExpression:
if value is None:
return IsNull(column)

if isinstance(value, float) and isnan(value):
return IsNaN(column)

return EqualTo(column, value)

filters = [And(*[equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]

if len(filters) == 0:
return AlwaysFalse()
elif len(filters) == 1:
return filters[0]
else:
return Or(*filters)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
@@ -98,13 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))

# Step 3: Perform an inner join to find which rows from source exist in target
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
# PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
# This is equivalent to:
# matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]

# Step 4: Compare all rows using Python
to_update_indices = []
for source_idx, target_idx in zip(
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
):
for source_idx, target_idx in matching_indices:
source_row = source_table.slice(source_idx, 1)
target_row = target_table.slice(target_idx, 1)
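
A standalone illustration (not part of the PR; the tables below are made up) of why the join above is computed in Python: as the comment notes, PyArrow joins ignore null key values, so rows whose key is null never match each other and silently drop out of an inner join.

import pyarrow as pa

left = pa.table({"k": [1, None], "l": ["a", "b"]})
right = pa.table({"k": [1, None], "r": ["x", "y"]})
joined = left.join(right, keys="k", join_type="inner")
# Only the k == 1 rows pair up; the rows with a null key are absent from `joined`.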

130 changes: 128 additions & 2 deletions tests/table/test_upsert.py
@@ -23,8 +23,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference
from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import UpsertResult
@@ -440,6 +440,82 @@ def test_create_match_filter_single_condition() -> None:
)


def test_create_match_filter_single_column_without_null() -> None:
data = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]

schema = pa.schema([pa.field("x", pa.float64())])
table = pa.Table.from_pylist(data, schema=schema)

expr = create_match_filter(table, join_cols=["x"])

assert expr == In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)})


def test_create_match_filter_single_column_with_null() -> None:
data = [
{"x": 1.0},
{"x": 2.0},
{"x": None},
{"x": 4.0},
{"x": float("nan")},
]
schema = pa.schema([pa.field("x", pa.float64())])
table = pa.Table.from_pylist(data, schema=schema)

expr = create_match_filter(table, join_cols=["x"])

assert expr == Or(
left=IsNull(term=Reference(name="x")),
right=Or(
left=IsNaN(term=Reference(name="x")),
right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}),
),
)


def test_create_match_filter_multi_column_with_null() -> None:
data = [
{"x": 1.0, "y": 9.0},
{"x": 2.0, "y": None},
{"x": None, "y": 7.0},
{"x": 4.0, "y": float("nan")},
{"x": float("nan"), "y": 0.0},
]
schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())])
table = pa.Table.from_pylist(data, schema=schema)

expr = create_match_filter(table, join_cols=["x", "y"])

assert expr == Or(
left=Or(
left=And(
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)),
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)),
),
right=And(
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)),
right=IsNull(term=Reference(name="y")),
),
),
right=Or(
left=And(
left=IsNull(term=Reference(name="x")),
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)),
),
right=Or(
left=And(
left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)),
right=IsNaN(term=Reference(name="y")),
),
right=And(
left=IsNaN(term=Reference(name="x")),
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)),
),
),
),
)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

@@ -711,6 +787,56 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
)


def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_nulls_in_join_columns"
_drop_table(catalog, identifier)

schema = pa.schema(
[
("foo", pa.string()),
("bar", pa.int32()),
("baz", pa.bool_()),
]
)
table = catalog.create_table(identifier, schema)

# upsert table with null value
data_with_null = pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": False},
],
schema=schema,
)
upd = table.upsert(data_with_null, join_cols=["foo"])
assert upd.rows_updated == 0
assert upd.rows_inserted == 1
assert table.scan().to_arrow() == pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": False},
],
schema=schema,
)

# upsert table with null and non-null values, in two join columns
data_with_null = pa.Table.from_pylist(
[
{"foo": None, "bar": 1, "baz": True},
{"foo": "lemon", "bar": None, "baz": False},
],
schema=schema,
)
upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
assert upd.rows_updated == 1
assert upd.rows_inserted == 1
assert table.scan().to_arrow() == pa.Table.from_pylist(
[
{"foo": "lemon", "bar": None, "baz": False},
{"foo": None, "bar": 1, "baz": True},
],
schema=schema,
)


def test_transaction(catalog: Catalog) -> None:
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
rolled back."""