Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@

from __future__ import annotations

import typing

import sqlglot.expressions as sge

from bigframes.core import window_spec
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
from bigframes.operations import aggregations as agg_ops
Expand All @@ -29,9 +26,35 @@
def compile(
    op: agg_ops.WindowOp,
    column: typed_expr.TypedExpr,
    *,
    order_by: tuple[sge.Expression, ...],
) -> sge.Expression:
    """Compile an ordered unary aggregation by dispatching on ``op``.

    Looks up the implementation registered for ``op`` in
    ``ORDERED_UNARY_OP_REGISTRATION`` and invokes it.

    Args:
        op: The aggregation operation to compile.
        column: The typed input expression being aggregated.
        order_by: Ordering expressions forwarded to the registered
            implementation. Keyword-only; an immutable tuple is used so no
            mutable default is shared between calls.

    Returns:
        The sqlglot expression produced by the registered implementation.
    """
    return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by)


@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.ArrayAggOp)
def _(
    op: agg_ops.ArrayAggOp,
    column: typed_expr.TypedExpr,
    *,
    order_by: tuple[sge.Expression, ...],
) -> sge.Expression:
    """Compile ``ArrayAggOp`` into an ``ARRAY_AGG`` that ignores NULL inputs.

    Args:
        op: The array aggregation operation (not otherwise inspected here).
        column: The typed input expression whose values are collected.
        order_by: Ordering expressions applied inside the aggregate; an empty
            tuple leaves the aggregate unordered.

    Returns:
        An ``sge.IgnoreNulls`` node wrapping ``sge.ArrayAgg`` over the
        (optionally ordered) input expression.
    """
    expr = column.expr
    if order_by:
        # The ordering wraps the aggregated operand so it is emitted inside
        # the ARRAY_AGG call rather than on the enclosing query.
        expr = sge.Order(this=expr, expressions=list(order_by))
    return sge.IgnoreNulls(this=sge.ArrayAgg(this=expr))


@ORDERED_UNARY_OP_REGISTRATION.register(agg_ops.StringAggOp)
def _(
    op: agg_ops.StringAggOp,
    column: typed_expr.TypedExpr,
    *,
    order_by: tuple[sge.Expression, ...],
) -> sge.Expression:
    """Compile ``StringAggOp`` into a separator-joined string aggregate.

    Args:
        op: The string aggregation operation; ``op.sep`` supplies the separator.
        column: The typed input expression whose values are concatenated.
        order_by: Ordering expressions applied inside the aggregate; an empty
            tuple leaves the aggregate unordered.

    Returns:
        A ``COALESCE`` call that yields the ``sge.GroupConcat`` result, or the
        empty string literal when that result is NULL.
    """
    operand = (
        sge.Order(this=column.expr, expressions=list(order_by))
        if order_by
        else column.expr
    )
    joined = sge.GroupConcat(this=operand, separator=sge.convert(op.sep))
    return sge.func("COALESCE", joined, sge.convert(""))
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Ungrouped ARRAY_AGG over `int64_col`, ordered by that same column.
-- NOTE(review): presumably the expected snapshot for the array-agg compiler test; confirm before editing by hand.
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
-- IGNORE NULLS drops NULL inputs; the `IS NULL ASC` key orders non-NULL values before NULLs, then by value ascending.
ARRAY_AGG(`bfcol_0` IGNORE NULLS ORDER BY `bfcol_0` IS NULL ASC, `bfcol_0` ASC) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `int64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Ungrouped STRING_AGG over `string_col` with a ',' separator, ordered by that same column.
-- NOTE(review): presumably the expected snapshot for the string-agg compiler test; confirm before editing by hand.
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
-- COALESCE substitutes '' when STRING_AGG evaluates to NULL.
-- The `IS NULL ASC` key orders non-NULL values before NULLs, then by value ascending.
COALESCE(STRING_AGG(`bfcol_0`, ','
ORDER BY
`bfcol_0` IS NULL ASC,
`bfcol_0` ASC), '') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import typing

import pytest

from bigframes.core import agg_expressions as agg_exprs
from bigframes.core import array_value, identifiers, nodes, ordering
from bigframes.operations import aggregations as agg_ops
import bigframes.pandas as bpd

pytest.importorskip("pytest_snapshot")


def _apply_ordered_unary_agg_ops(
    obj: bpd.DataFrame,
    ops_list: typing.Sequence[agg_exprs.UnaryAggregation],
    new_names: typing.Sequence[str],
    ordering_args: typing.Sequence[str],
) -> str:
    """Build an ordered, ungrouped aggregation over ``obj`` and return its SQL.

    Args:
        obj: DataFrame whose underlying expression node is aggregated.
        ops_list: Aggregation expressions to apply.
        new_names: Output column names, paired positionally with ``ops_list``.
        ordering_args: Column names to order by, each ascending.

    Returns:
        The SQL text produced by the session executor with caching disabled.
    """
    sort_keys = tuple(ordering.ascending_over(col) for col in ordering_args)
    named_aggs = tuple(
        (agg, identifiers.ColumnId(out_name))
        for agg, out_name in zip(ops_list, new_names)
    )

    # An empty ``by_column_ids`` means no grouping columns are supplied.
    aggregate = nodes.AggregateNode(
        obj._block.expr.node,
        aggregations=named_aggs,
        by_column_ids=(),
        order_by=sort_keys,
    )
    value = array_value.ArrayValue(aggregate)

    return value.session._executor.to_sql(value, enable_cache=False)


def test_array_agg(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ArrayAggOp`` ordered by its input column."""
    # TODO: Verify the "NULLS LAST" syntax issue on Python < 3.12.
    running_old_python = sys.version_info < (3, 12)
    if running_old_python:
        pytest.skip(
            "Skipping test due to inconsistent SQL formatting on Python < 3.12.",
        )

    col_name = "int64_col"
    single_col_df = scalar_types_df[[col_name]]
    array_agg = agg_ops.ArrayAggOp().as_expr(col_name)

    # The same column is the aggregation input, output name and ordering key.
    generated_sql = _apply_ordered_unary_agg_ops(
        single_col_df,
        [array_agg],
        [col_name],
        ordering_args=[col_name],
    )

    snapshot.assert_match(generated_sql, "out.sql")


def test_string_agg(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StringAggOp(sep=",")`` ordered by its input column."""
    # TODO: Verify the "NULLS LAST" syntax issue on Python < 3.12.
    running_old_python = sys.version_info < (3, 12)
    if running_old_python:
        pytest.skip(
            "Skipping test due to inconsistent SQL formatting on Python < 3.12.",
        )

    col_name = "string_col"
    single_col_df = scalar_types_df[[col_name]]
    string_agg = agg_ops.StringAggOp(sep=",").as_expr(col_name)

    # The same column is the aggregation input, output name and ordering key.
    generated_sql = _apply_ordered_unary_agg_ops(
        single_col_df,
        [string_agg],
        [col_name],
        ordering_args=[col_name],
    )

    snapshot.assert_match(generated_sql, "out.sql")