fix: upsert with null values in join columns #2429
Changes from all commits: 1321c1b, 4df175f, 763042f, a32bb06, 9af1b3d, 6d772b9, cf3d68e
```diff
@@ -14,38 +14,61 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 import functools
 import operator
+from math import isnan
+from typing import Any
 
 import pyarrow as pa
 from pyarrow import Table as pyarrow_table
 from pyarrow import compute as pc
 
 from pyiceberg.expressions import (
     AlwaysFalse,
+    And,
     BooleanExpression,
     EqualTo,
     In,
+    IsNaN,
+    IsNull,
     Or,
 )
 
 
 def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
     unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+    filters: list[BooleanExpression] = []
 
     if len(join_cols) == 1:
-        return In(join_cols[0], unique_keys[0].to_pylist())
+        column = join_cols[0]
```
**Reviewer:**

> Is there a way we can simplify the logic here? I think the primary issue is that the …

**Author:**

> Yes, the … The following SQL is invalid:
>
> `WHERE x IN (1, 2, 3, NULL)`
>
> Instead it should be this:
>
> `WHERE x IN (1, 2, 3) OR x IS NULL`
>
> Testing for null requires … This is the reason for changing the … If there's a better way I'm open to changing it, but I believe the added complexity in building filter expressions with null is justified. When null is not involved we produce the same …
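The `IN`-with-`NULL` point above can be sketched in plain Python: split the value set into a null part, a NaN part, and a regular `IN` part, then OR them together. Strings stand in for pyiceberg's `IsNull`/`IsNaN`/`In` expression objects; the helper name is illustrative, not part of the library.

```python
from math import isnan


def sketch_in_filter(column, values):
    # Split a value set into NULL / NaN / plain-IN parts, mirroring the
    # single-column branch of create_match_filter. Strings stand in for
    # pyiceberg expression objects (illustrative only).
    values = set(values)
    parts = []
    if None in values:
        parts.append(f"{column} IS NULL")
        values.discard(None)
    nans = {v for v in values if isinstance(v, float) and isnan(v)}
    if nans:
        parts.append(f"isnan({column})")
        values -= nans
    if values:
        parts.append(f"{column} IN {sorted(values)}")
    return " OR ".join(parts)


print(sketch_in_filter("x", [1, 2, None, float("nan"), 3]))
# → x IS NULL OR isnan(x) OR x IN [1, 2, 3]
```

When no null or NaN is present, this degenerates to a plain `IN` clause, matching the author's point that the old behaviour is preserved in the common case.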
```diff
+        values = set(unique_keys[0].to_pylist())
+
+        if None in values:
+            filters.append(IsNull(column))
+            values.remove(None)
+
+        if nans := {v for v in values if isinstance(v, float) and isnan(v)}:
+            filters.append(IsNaN(column))
+            values -= nans
+
+        filters.append(In(column, values))
     else:
-        filters = [
-            functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
-        ]
-
-        if len(filters) == 0:
-            return AlwaysFalse()
-        elif len(filters) == 1:
-            return filters[0]
-        else:
-            return Or(*filters)
+        def equals(column: str, value: Any) -> BooleanExpression:
+            if value is None:
+                return IsNull(column)
+
+            if isinstance(value, float) and isnan(value):
+                return IsNaN(column)
+
+            return EqualTo(column, value)
+
+        filters = [And(*[equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]
+
+    if len(filters) == 0:
+        return AlwaysFalse()
+    elif len(filters) == 1:
+        return filters[0]
+    else:
+        return Or(*filters)
 
 
 def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
```
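The composite-key branch builds one conjunction per source row, with null and NaN swapped in for plain equality. A pure-Python sketch of the same idea, using SQL-ish strings as stand-ins for pyiceberg's expression objects (the helper names here are illustrative, not the library API):

```python
from math import isnan


def equals_sql(column, value):
    # String stand-ins for pyiceberg's IsNull / IsNaN / EqualTo expressions.
    if value is None:
        return f"{column} IS NULL"
    if isinstance(value, float) and isnan(value):
        return f"isnan({column})"
    return f"{column} = {value!r}"


def match_filter_sql(rows, join_cols):
    # One AND-clause per unique key row, OR-ed together; FALSE for no rows,
    # mirroring the AlwaysFalse() fallback in create_match_filter.
    clauses = [
        "(" + " AND ".join(equals_sql(c, row[c]) for c in join_cols) + ")"
        for row in rows
    ]
    return " OR ".join(clauses) if clauses else "FALSE"


rows = [{"k1": "a", "k2": 1}, {"k1": "b", "k2": None}]
print(match_filter_sql(rows, ["k1", "k2"]))
# → (k1 = 'a' AND k2 = 1) OR (k1 = 'b' AND k2 IS NULL)
```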
```diff
@@ -98,13 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
 
     # Step 3: Perform an inner join to find which rows from source exist in target
-    matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+    # PyArrow joins ignore null values, and we want null == null to hold, so we compute the join in Python.
+    # This is equivalent to:
+    #   matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+    source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
+    target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
+    matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]
 
     # Step 4: Compare all rows using Python
     to_update_indices = []
-    for source_idx, target_idx in zip(
-        matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
-    ):
+    for source_idx, target_idx in matching_indices:
         source_row = source_table.slice(source_idx, 1)
         target_row = target_table.slice(target_idx, 1)
```
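The dict-based matching above can be illustrated without PyArrow. Python tuples containing `None` hash and compare equal, so a null join key still matches its counterpart, unlike a PyArrow (or SQL) inner join, which drops null keys. A minimal sketch, with positional indices standing in for the index columns (`match_indices` is a hypothetical helper, not part of the PR):

```python
def match_indices(source_keys, target_keys):
    # source_keys / target_keys: lists of join-key tuples; the list position
    # plays the role of the SOURCE/TARGET index columns in the diff above.
    target_lookup = {key: i for i, key in enumerate(target_keys)}
    return [
        (s, t)
        for s, key in enumerate(source_keys)
        if (t := target_lookup.get(key)) is not None
    ]


source = [("a", 1), ("b", None), ("c", 2)]
target = [("b", None), ("c", 2), ("d", 3)]
print(match_indices(source, target))  # → [(1, 0), (2, 1)]
```

Note that `("b", None)` matches across the two tables, which is exactly the `null == null` behaviour the comment in the diff calls out as the reason for computing the join in Python.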
**Reviewer:**

> If I understand this correctly, this is creating a predicate to test whether a row might exist in the `pyarrow_table` (matching on `join_cols`). And since `Null == Any` should always return unknown in SQL, can we just filter out any rows from the `pyarrow_table` where the `join_cols` fields contain `None` (we treat `None` as SQL `Null`), and then build the match filter based on the filtered pyarrow table (using the existing logic for building the match filter)? This would be much simpler.
**Author:**

> Would that mean it's impossible to update rows with null in the join columns, since they are filtered out? If so, that's not what I was going for. I'd like the solution to pass this test: https://github.com/mdwint/iceberg-python/blob/f818016e5c198581b7d7b11dba2b9ebd414e19bc/tests/table/test_upsert.py#L784-L831
>
> This would be equivalent to the following Spark SQL (using the null-safe equality operator `<=>`): …
**Author:**

> Thinking through this some more, I ask myself: what should the semantics of `upsert` be? Should it use `=` or `<=>` to test equality? For my use case `<=>` is right, and I also find it most intuitive, but does that mean it should be the default?
>
> I see several options:
>
> 1. Make `<=>` the default. Users who don't want to update nulls can filter them out themselves before calling `upsert`. The status quo is crashing, so there are no existing users expecting a different behaviour.
> 2. Make `=` the default. This means I can't achieve my goal, and I'll need to reimplement `upsert` myself. It also means new rows will be inserted for every row containing null in the join columns. This is unintuitive to me, but who knows, someone might want it?
> 3. Add a parameter to `upsert` to select the comparison operator. Maximum flexibility, more work to implement.
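The difference between the two candidate semantics can be sketched in Python, with `None` standing in for SQL `NULL` (hypothetical helpers for illustration, not part of the library):

```python
def sql_eq(a, b):
    # SQL '=': any comparison involving NULL is unknown (modelled as None),
    # so a row with a null join key never matches and is inserted as new.
    if a is None or b is None:
        return None  # unknown
    return a == b


def null_safe_eq(a, b):
    # Spark SQL '<=>': NULL <=> NULL is true, NULL <=> x is false,
    # so rows with null join keys can still be matched and updated.
    if a is None or b is None:
        return a is None and b is None
    return a == b


print(sql_eq(None, None))        # → None (unknown: the row never matches)
print(null_safe_eq(None, None))  # → True
print(null_safe_eq(None, 1))     # → False
```

Under option 1, `upsert` behaves like `null_safe_eq`; under option 2, like `sql_eq`, where every null-keyed source row falls through to the insert path.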