diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py
index cefdd101a0..918e47b806 100644
--- a/pyiceberg/table/upsert_util.py
+++ b/pyiceberg/table/upsert_util.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import functools
-import operator
+from math import isnan
+from typing import Any
 
 import pyarrow as pa
 from pyarrow import Table as pyarrow_table
@@ -23,29 +23,52 @@
 
 from pyiceberg.expressions import (
     AlwaysFalse,
+    And,
     BooleanExpression,
     EqualTo,
     In,
+    IsNaN,
+    IsNull,
     Or,
 )
 
 
 def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
     unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
 
+    filters: list[BooleanExpression] = []
     if len(join_cols) == 1:
-        return In(join_cols[0], unique_keys[0].to_pylist())
+        column = join_cols[0]
+        values = set(unique_keys[0].to_pylist())
+
+        if None in values:
+            filters.append(IsNull(column))
+            values.remove(None)
+
+        if nans := {v for v in values if isinstance(v, float) and isnan(v)}:
+            filters.append(IsNaN(column))
+            values -= nans
+
+        filters.append(In(column, values))
     else:
-        filters = [
-            functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
-        ]
-        if len(filters) == 0:
-            return AlwaysFalse()
-        elif len(filters) == 1:
-            return filters[0]
-        else:
-            return Or(*filters)
+
+        def equals(column: str, value: Any) -> BooleanExpression:
+            if value is None:
+                return IsNull(column)
+
+            if isinstance(value, float) and isnan(value):
+                return IsNaN(column)
+
+            return EqualTo(column, value)
+
+        filters = [And(*[equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]
+
+    if len(filters) == 0:
+        return AlwaysFalse()
+    elif len(filters) == 1:
+        return filters[0]
+    else:
+        return Or(*filters)
 
 
 def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
@@ -98,13 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
 
     # Step 3: Perform an inner join to find which rows from source exist in target
-    matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+    # PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
+    # This is equivalent to:
+    # matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+    source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
+    target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
+    matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]
 
     # Step 4: Compare all rows using Python
     to_update_indices = []
 
-    for source_idx, target_idx in zip(
-        matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
-    ):
+    for source_idx, target_idx in matching_indices:
        source_row = source_table.slice(source_idx, 1)
        target_row = target_table.slice(target_idx, 1)
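Background on the join change above: PyArrow's hash join follows SQL semantics, under which a null key never equals another null key, so null-keyed rows silently drop out of an inner join. A minimal standalone sketch (hypothetical data, not part of the patch) showing the behaviour the Python-side join works around:

```python
import pyarrow as pa

source = pa.table({"k": ["a", None], "source_idx": [0, 1]})
target = pa.table({"k": ["a", None], "target_idx": [0, 1]})

# The null keys do not match each other: only the "a" row survives the join.
joined = source.join(target, keys=["k"], join_type="inner")
print(joined.to_pylist())  # [{'k': 'a', 'source_idx': 0, 'target_idx': 0}]
```

The dict-of-key-tuples replacement treats `None == None` as a match, so in the same situation both rows would pair up.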
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index 891d4bbac7..334c026233 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -23,8 +23,8 @@
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
-from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
-from pyiceberg.expressions.literals import LongLiteral
+from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference
+from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral
 from pyiceberg.io.pyarrow import schema_to_pyarrow
 from pyiceberg.schema import Schema
 from pyiceberg.table import UpsertResult
@@ -440,6 +440,82 @@ def test_create_match_filter_single_condition() -> None:
     )
 
 
+def test_create_match_filter_single_column_without_null() -> None:
+    data = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]
+
+    schema = pa.schema([pa.field("x", pa.float64())])
+    table = pa.Table.from_pylist(data, schema=schema)
+
+    expr = create_match_filter(table, join_cols=["x"])
+
+    assert expr == In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)})
+
+
+def test_create_match_filter_single_column_with_null() -> None:
+    data = [
+        {"x": 1.0},
+        {"x": 2.0},
+        {"x": None},
+        {"x": 4.0},
+        {"x": float("nan")},
+    ]
+    schema = pa.schema([pa.field("x", pa.float64())])
+    table = pa.Table.from_pylist(data, schema=schema)
+
+    expr = create_match_filter(table, join_cols=["x"])
+
+    assert expr == Or(
+        left=IsNull(term=Reference(name="x")),
+        right=Or(
+            left=IsNaN(term=Reference(name="x")),
+            right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}),
+        ),
+    )
+
+
+def test_create_match_filter_multi_column_with_null() -> None:
+    data = [
+        {"x": 1.0, "y": 9.0},
+        {"x": 2.0, "y": None},
+        {"x": None, "y": 7.0},
+        {"x": 4.0, "y": float("nan")},
+        {"x": float("nan"), "y": 0.0},
+    ]
+    schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())])
+    table = pa.Table.from_pylist(data, schema=schema)
+
+    expr = create_match_filter(table, join_cols=["x", "y"])
+
+    assert expr == Or(
+        left=Or(
+            left=And(
+                left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)),
+                right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)),
+            ),
+            right=And(
+                left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)),
+                right=IsNull(term=Reference(name="y")),
+            ),
+        ),
+        right=Or(
+            left=And(
+                left=IsNull(term=Reference(name="x")),
+                right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)),
+            ),
+            right=Or(
+                left=And(
+                    left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)),
+                    right=IsNaN(term=Reference(name="y")),
+                ),
+                right=And(
+                    left=IsNaN(term=Reference(name="x")),
+                    right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)),
+                ),
+            ),
+        ),
+    )
+
+
 def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
     identifier = "default.test_upsert_with_duplicate_rows_in_table"
 
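A note on the expected expressions in these tests: under IEEE 754, NaN compares unequal to everything, itself included, so an EqualTo or In predicate can never select NaN rows, and nulls behave analogously under three-valued logic. That is why the filters use the dedicated IsNaN and IsNull predicates. A quick illustration:

```python
from math import isnan, nan

print(nan == nan)  # False: equality can never match NaN
print(isnan(nan))  # True: the check create_match_filter relies on
```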
@@ -711,6 +787,56 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
     )
 
 
+def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_with_nulls_in_join_columns"
+    _drop_table(catalog, identifier)
+
+    schema = pa.schema(
+        [
+            ("foo", pa.string()),
+            ("bar", pa.int32()),
+            ("baz", pa.bool_()),
+        ]
+    )
+    table = catalog.create_table(identifier, schema)
+
+    # upsert table with null value
+    data_with_null = pa.Table.from_pylist(
+        [
+            {"foo": None, "bar": 1, "baz": False},
+        ],
+        schema=schema,
+    )
+    upd = table.upsert(data_with_null, join_cols=["foo"])
+    assert upd.rows_updated == 0
+    assert upd.rows_inserted == 1
+    assert table.scan().to_arrow() == pa.Table.from_pylist(
+        [
+            {"foo": None, "bar": 1, "baz": False},
+        ],
+        schema=schema,
+    )
+
+    # upsert table with null and non-null values, in two join columns
+    data_with_null = pa.Table.from_pylist(
+        [
+            {"foo": None, "bar": 1, "baz": True},
+            {"foo": "lemon", "bar": None, "baz": False},
+        ],
+        schema=schema,
+    )
+    upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
+    assert upd.rows_updated == 1
+    assert upd.rows_inserted == 1
+    assert table.scan().to_arrow() == pa.Table.from_pylist(
+        [
+            {"foo": "lemon", "bar": None, "baz": False},
+            {"foo": None, "bar": 1, "baz": True},
+        ],
+        schema=schema,
+    )
+
+
 def test_transaction(catalog: Catalog) -> None:
     """Test the upsert within a Transaction.
     Make sure that if something fails the entire Transaction is rolled back."""
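For reference, the key-tuple matching introduced in get_rows_to_update can be exercised on its own. A small sketch with made-up rows (the variable names mirror the patch, the data is hypothetical):

```python
# Index maps keyed on (foo, bar) tuples, as built in get_rows_to_update.
source_indices = {("lemon", None): 0, (None, 1): 1}
target_indices = {(None, 1): 7}

# Tuple equality treats None == None as a match, unlike the Arrow join.
matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]
print(matching_indices)  # [(1, 7)]
```

One caveat worth knowing: float("nan") keys only match when source and target hold the very same object, because dict lookup falls back to == after an identity check, and two distinct NaN values never compare equal.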