Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
4d6488c
Initial implementations of candidate vs rewrite shuttle
knassre-bodo Oct 9, 2025
5369379
Initial implementation of predicate server integration working on cry…
knassre-bodo Oct 9, 2025
36cab6e
WIP adding to lookup table
knassre-bodo Oct 9, 2025
ed6650c
Rewriting the rest of the filter count queries
knassre-bodo Oct 9, 2025
cc2bbed
Moving server address into mask server info setup
knassre-bodo Oct 9, 2025
a6d4b29
[RUN ALL]
knassre-bodo Oct 9, 2025
beadb15
Adding more tests
knassre-bodo Oct 10, 2025
1b4bcac
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 14, 2025
5ea82f1
Switching up relational shuttle handling for simplification
knassre-bodo Oct 15, 2025
f0f512c
Minor adjustments to file placement
knassre-bodo Oct 15, 2025
54ecef1
Moved some logic from rewrite shuttle to candidate visitor
knassre-bodo Oct 15, 2025
557aaeb
Added more tests
knassre-bodo Oct 15, 2025
6b109d9
Added rewrite shuttle docstrings/comments
knassre-bodo Oct 16, 2025
1377916
Adding remaining documentation
knassre-bodo Oct 16, 2025
891c472
Removing dead rule
knassre-bodo Oct 16, 2025
7d7580b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 16, 2025
62db4bf
[RUN ALL]
knassre-bodo Oct 16, 2025
c9f6a59
[RUN ALL]
knassre-bodo Oct 16, 2025
7c37110
Adding logging to keep track of the batch requests sent
knassre-bodo Oct 26, 2025
127244f
Ensuring non-predicate sub-expressions are not sent to the server [RU…
knassre-bodo Oct 26, 2025
1f2dc6d
Ensuring non-predicate sub-expressions are not sent to the server [RU…
knassre-bodo Oct 26, 2025
2864e4a
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 26, 2025
b278f9b
Adding date/datetime/timestamp literal handling tests [RUN CI]
knassre-bodo Oct 26, 2025
dcbb69c
Added new operators support, need to add new tests for datetime, quar…
knassre-bodo Oct 30, 2025
feabd8a
Added more tests, handled predicate pushdown bug with least/greatest,…
knassre-bodo Oct 30, 2025
940dd16
Added remaining tests [RUN CI]
knassre-bodo Oct 31, 2025
a6f6a37
Predicate server revisions with new API
knassre-bodo Nov 5, 2025
af10c5b
JSON request/response reformatting WIP
knassre-bodo Nov 16, 2025
0371ec5
Adding four-phase algorithm, need to implement step #3
knassre-bodo Nov 19, 2025
3996ced
Updating rewrite handling, need to add DP algorithm
knassre-bodo Nov 19, 2025
29e0e3f
Finishing implementation of min cover set
knassre-bodo Nov 21, 2025
f9c05b2
Added edge case tests for selection algorithm
knassre-bodo Nov 21, 2025
4f274fd
Minor test adjustment
knassre-bodo Nov 21, 2025
18379ef
Minor test adjustment
knassre-bodo Nov 21, 2025
f512f8b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Nov 24, 2025
90f0671
Resolving conflicts [RUN ALL]
knassre-bodo Nov 24, 2025
f6a571b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Nov 26, 2025
b728348
Added the FQN slash handling
knassre-bodo Nov 26, 2025
8e03b04
Revisions, QUOTE operator handling, docstrings/documentation [RUN ALL]
knassre-bodo Dec 2, 2025
a3c79cf
Fixing mask server tests [RUN ALL]
knassre-bodo Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pydough/configs/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
existing state.
"""

from typing import TYPE_CHECKING, Union

from pydough.database_connectors import (
DatabaseContext,
DatabaseDialect,
Expand All @@ -30,6 +32,9 @@

from .pydough_configs import PyDoughConfigs

if TYPE_CHECKING:
from pydough.mask_server import MaskServerInfo


class PyDoughSession:
"""
Expand All @@ -50,6 +55,7 @@ def __init__(self) -> None:
connection=empty_connection, dialect=DatabaseDialect.ANSI
)
self._error_builder: PyDoughErrorBuilder = PyDoughErrorBuilder()
self._mask_server: MaskServerInfo | None = None

@property
def metadata(self) -> GraphMetadata | None:
Expand Down Expand Up @@ -131,6 +137,26 @@ def error_builder(self, builder: PyDoughErrorBuilder) -> None:
"""
self._error_builder = builder

@property
def mask_server(self) -> Union["MaskServerInfo", None]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a special reason to use Union["MaskServerInfo", None] instead of MaskServerInfo | None with from __future__ import annotations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it is imported via if TYPE_CHECKING:, so MaskServerInfo won't always be imported, but the type checker will recognize "MaskServerInfo" (which can't be done with | None). This is how we avoid circular imports.

"""
Get the active mask server information.

Returns:
The active mask server information.
"""
return self._mask_server

@mask_server.setter
def mask_server(self, server_info: Union["MaskServerInfo", None]) -> None:
    """
    Set the active mask server information.

    Args:
        `server_info`: The mask server information to store on the session,
        or None to leave the session without a mask server.
    """
    self._mask_server = server_info

def connect_database(self, database_name: str, **kwargs) -> DatabaseContext:
"""
Create a new DatabaseContext and register it in the session. This returns
Expand Down
4 changes: 2 additions & 2 deletions pydough/conversion/masking_shuttles.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rewrite_masked_literal_comparison(
# literal in a call to MASK by toggling is_unmask to False.
masked_literal = CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
call_arg.op.masking_metadata, call_arg.op.table_path, False
),
call_arg.data_type,
[literal_arg],
Expand All @@ -83,7 +83,7 @@ def rewrite_masked_literal_comparison(
[
CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
call_arg.op.masking_metadata, call_arg.op.table_path, False
),
call_arg.data_type,
[LiteralExpression(v, inner_type)],
Expand Down
40 changes: 32 additions & 8 deletions pydough/conversion/relational_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import pydough.pydough_operators as pydop
from pydough.configs import PyDoughSession
from pydough.mask_server.mask_server_candidate_visitor import MaskServerCandidateVisitor
from pydough.mask_server.mask_server_rewrite_shuttle import MaskServerRewriteShuttle
from pydough.metadata import (
CartesianProductMetadata,
GeneralJoinMetadata,
Expand Down Expand Up @@ -45,7 +47,10 @@
LiteralExpression,
Project,
RelationalExpression,
RelationalExpressionDispatcher,
RelationalExpressionShuttle,
RelationalExpressionShuttleDispatcher,
RelationalExpressionVisitor,
RelationalNode,
RelationalRoot,
Scan,
Expand Down Expand Up @@ -861,7 +866,9 @@ def build_simple_table_scan(
)
unmask_columns[name] = CallExpression(
pydop.MaskedExpressionFunctionOperator(
hybrid_expr.column.column_property, True
hybrid_expr.column.column_property,
node.collection.collection.table_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the reason why we need to use the full table path in metadata?

Copy link
Contributor Author

@knassre-bodo knassre-bodo Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EXACTLY (plus it's a good idea in general)

True,
),
hybrid_expr.column.column_property.unprotected_data_type,
[ColumnReference(name, hybrid_expr.typ)],
Expand Down Expand Up @@ -1561,7 +1568,9 @@ def confirm_root(node: RelationalNode) -> RelationalRoot:
def optimize_relational_tree(
root: RelationalRoot,
session: PyDoughSession,
additional_shuttles: list[RelationalExpressionShuttle],
additional_shuttles: list[
RelationalExpressionShuttle | RelationalExpressionVisitor
],
) -> RelationalRoot:
"""
Runs optimize on the relational tree, including pushing down filters and
Expand All @@ -1570,8 +1579,8 @@ def optimize_relational_tree(
Args:
`root`: the relational root to optimize.
`session`: PyDough session used during optimization.
`additional_shuttles`: additional relational expression shuttles to use
for expression simplification.
`additional_shuttles`: additional relational expression shuttles or
visitors to use for expression simplification.

Returns:
The optimized relational root.
Expand Down Expand Up @@ -1633,7 +1642,7 @@ def optimize_relational_tree(

# Run the following pipeline twice:
# A: projection pullup
# B: expression simplification
# B: expression simplification (followed by additional shuttles)
# C: filter pushdown
# D: join-aggregate transpose
# E: projection pullup again
Expand All @@ -1647,7 +1656,13 @@ def optimize_relational_tree(
# pullup and pushdown and so on.
for _ in range(2):
root = confirm_root(pullup_projections(root))
simplify_expressions(root, session, additional_shuttles)
simplify_expressions(root, session)
# Run all of the other shuttles/visitors over the entire tree.
for shuttle_or_visitor in additional_shuttles:
if isinstance(shuttle_or_visitor, RelationalExpressionShuttle):
root.accept(RelationalExpressionShuttleDispatcher(shuttle_or_visitor))
else:
root.accept(RelationalExpressionDispatcher(shuttle_or_visitor, True))
root = confirm_root(push_filters(root, session))
root = confirm_root(pull_aggregates_above_joins(root))
root = confirm_root(pullup_projections(root))
Expand Down Expand Up @@ -1716,10 +1731,19 @@ def convert_ast_to_relational(
raw_result: RelationalRoot = postprocess_root(node, columns, hybrid, output)

# Invoke the optimization procedures on the result to clean up the tree.
additional_shuttles: list[RelationalExpressionShuttle] = []
additional_shuttles: list[
RelationalExpressionShuttle | RelationalExpressionVisitor
] = []
# Add the mask literal comparison shuttle if the environment variable
# PYDOUGH_ENABLE_MASK_REWRITES is set to 1.
# PYDOUGH_ENABLE_MASK_REWRITES is set to 1. If a masking rewrite server has
# been attached to the session, include the shuttles for that as well.
if os.getenv("PYDOUGH_ENABLE_MASK_REWRITES") == "1":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why PYDOUGH_ENABLE_MASK_REWRITES is not in PyDoughConfigs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we wanted an environment variable as a "switch"

if session.mask_server is not None:
candidate_shuttle: MaskServerCandidateVisitor = MaskServerCandidateVisitor()
additional_shuttles.append(candidate_shuttle)
additional_shuttles.append(
MaskServerRewriteShuttle(session.mask_server, candidate_shuttle)
)
additional_shuttles.append(MaskLiteralComparisonShuttle())
optimized_result: RelationalRoot = optimize_relational_tree(
raw_result, session, additional_shuttles
Expand Down
30 changes: 2 additions & 28 deletions pydough/conversion/relational_simplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,22 +1479,13 @@ class SimplificationVisitor(RelationalVisitor):
the current node are placed on the stack.
"""

def __init__(
self,
session: PyDoughSession,
additional_shuttles: list[RelationalExpressionShuttle],
):
def __init__(self, session: PyDoughSession):
self.stack: list[dict[RelationalExpression, PredicateSet]] = []
self.shuttle: SimplificationShuttle = SimplificationShuttle(session)
self.additional_shuttles: list[RelationalExpressionShuttle] = (
additional_shuttles
)

def reset(self):
self.stack.clear()
self.shuttle.reset()
for shuttle in self.additional_shuttles:
shuttle.reset()

def get_input_predicates(
self, node: RelationalNode
Expand Down Expand Up @@ -1559,8 +1550,6 @@ def generic_visit(
ref_expr = ColumnReference(name, expr.data_type)
expr = expr.accept_shuttle(self.shuttle)
output_predicates[ref_expr] = self.shuttle.stack.pop()
for shuttle in self.additional_shuttles:
expr = expr.accept_shuttle(shuttle)
node.columns[name] = expr
return output_predicates

Expand Down Expand Up @@ -1645,8 +1634,6 @@ def visit_filter(self, node: Filter) -> None:
# Transform the filter condition in-place.
node._condition = node.condition.accept_shuttle(self.shuttle)
self.shuttle.stack.pop()
for shuttle in self.additional_shuttles:
node._condition = node.condition.accept_shuttle(shuttle)
self.infer_null_predicates_from_condition(
output_predicates,
node.condition,
Expand All @@ -1661,8 +1648,6 @@ def visit_join(self, node: Join) -> None:
# Transform the join condition in-place.
node._condition = node.condition.accept_shuttle(self.shuttle)
self.shuttle.stack.pop()
for shuttle in self.additional_shuttles:
node._condition = node.condition.accept_shuttle(shuttle)
# If the join is not an inner join, remove any not-null predicates
# from the RHS of the join.
if node.join_type != JoinType.INNER:
Expand All @@ -1689,8 +1674,6 @@ def visit_limit(self, node: Limit) -> None:
for ordering_expr in node.orderings:
ordering_expr.expr = ordering_expr.expr.accept_shuttle(self.shuttle)
self.shuttle.stack.pop()
for shuttle in self.additional_shuttles:
ordering_expr.expr = ordering_expr.expr.accept_shuttle(shuttle)
self.stack.append(output_predicates)

def visit_root(self, node: RelationalRoot) -> None:
Expand All @@ -1704,8 +1687,6 @@ def visit_root(self, node: RelationalRoot) -> None:
for ordering_expr in node.orderings:
ordering_expr.expr = ordering_expr.expr.accept_shuttle(self.shuttle)
self.shuttle.stack.pop()
for shuttle in self.additional_shuttles:
ordering_expr.expr = ordering_expr.expr.accept_shuttle(shuttle)
self.stack.append(output_predicates)

def visit_aggregate(self, node: Aggregate) -> None:
Expand All @@ -1725,7 +1706,6 @@ def visit_aggregate(self, node: Aggregate) -> None:
def simplify_expressions(
node: RelationalNode,
session: PyDoughSession,
additional_shuttles: list[RelationalExpressionShuttle],
) -> None:
"""
Transforms the current node and all of its descendants in-place to simplify
Expand All @@ -1734,12 +1714,6 @@ def simplify_expressions(
Args:
`node`: The relational node to perform simplification on.
`session`: The PyDough session used during the simplification.
`additional_shuttles`: A list of additional shuttles to apply to the
expressions of the node and its descendants. These shuttles are applied
after the simplification shuttle, and can be used to perform additional
transformations on the expressions.
"""
simplifier: SimplificationVisitor = SimplificationVisitor(
session, additional_shuttles
)
simplifier: SimplificationVisitor = SimplificationVisitor(session)
node.accept(simplifier)
3 changes: 2 additions & 1 deletion pydough/errors/error_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def __init__(self):
"sql_keyword": "must have a SQL name that is not a reserved word",
}

def _split_identifier(self, name: str) -> list[str]:
@staticmethod
def _split_identifier(name: str) -> list[str]:
"""
Split a potentially qualified SQL identifier into parts.

Expand Down
11 changes: 9 additions & 2 deletions pydough/evaluation/evaluate_unqualified.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydough.errors import (
PyDoughSessionException,
)
from pydough.mask_server import MaskServerInfo
from pydough.metadata import GraphMetadata
from pydough.qdag import PyDoughCollectionQDAG, PyDoughQDAG
from pydough.relational import RelationalRoot
Expand All @@ -32,8 +33,8 @@ def _load_session_info(**kwargs) -> PyDoughSession:
Load the session information from the active session unless it is found
in the keyword arguments. The following variants are accepted:
- If `session` is found, it is used directly.
- If `metadata`, `config` and/or `database` are found, they are used to
construct a new session.
- If `metadata`, `config`, `mask_server`, and/or `database` are found, they
are used to construct a new session.
- If none of these are found, the active session is used.

Args:
Expand Down Expand Up @@ -88,13 +89,19 @@ def _load_session_info(**kwargs) -> PyDoughSession:
database = kwargs.pop("database")
else:
database = pydough.active_session.database
mask_server: MaskServerInfo | None
if "mask_server" in kwargs:
mask_server = kwargs.pop("mask_server")
else:
mask_server = pydough.active_session.mask_server
assert not kwargs, f"Unexpected keyword arguments: {kwargs}"

# Construct the new session
new_session: PyDoughSession = PyDoughSession()
new_session._metadata = metadata
new_session._config = config
new_session._database = database
new_session._mask_server = mask_server
return new_session


Expand Down
4 changes: 4 additions & 0 deletions pydough/mask_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

__all__ = [
"MaskServerCandidateVisitor",
"MaskServerInfo",
"MaskServerInput",
"MaskServerOutput",
"MaskServerResponse",
"MaskServerRewriteShuttle",
"RequestMethod",
"ServerConnection",
"ServerRequest",
Expand All @@ -18,6 +20,8 @@
MaskServerOutput,
MaskServerResponse,
)
from .mask_server_candidate_visitor import MaskServerCandidateVisitor
from .mask_server_rewrite_shuttle import MaskServerRewriteShuttle
from .server_connection import (
RequestMethod,
ServerConnection,
Expand Down
Loading