@@ -37,26 +37,259 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3737 return expr .op .as_expr (larg , rarg )
3838
3939
class LowerAddRule(op_lowering.OpLoweringRule):
    """Rewrite ``AddOp`` expressions according to their operand types.

    Boolean operands are cast to integers before adding, two string-like
    operands become a string concatenation, and a date operand paired with
    a timedelta operand is cast to datetime first.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.AddOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for an addition.

        Args:
            expr: Op expression whose ``op`` is an ``AddOp``. Its first two
                children are taken as the left and right operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.AddOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        def _as_datetime(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(node)

        left, right = expr.children[0], expr.children[1]
        left_type = left.output_type
        right_type = right.output_type

        # bool + bool: add as integers, then cast the sum back to bool.
        if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE:
            summed = expr.op.as_expr(_as_int(left), _as_int(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(summed)

        # string + string is expressed as a concatenation op.
        if dtypes.is_string_like(left_type) and dtypes.is_string_like(right_type):
            return ops.strconcat_op.as_expr(left, right)

        # A lone bool operand is cast to int; a date operand paired with a
        # timedelta is cast to datetime. An operand cannot be both bool and
        # date, so each side needs at most one cast.
        if left_type == dtypes.BOOL_DTYPE:
            left = _as_int(left)
        elif (
            left_type == dtypes.DATE_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            left = _as_datetime(left)

        if right_type == dtypes.BOOL_DTYPE:
            right = _as_int(right)
        elif (
            right_type == dtypes.DATE_DTYPE
            and left_type == dtypes.TIMEDELTA_DTYPE
        ):
            right = _as_datetime(right)

        return expr.op.as_expr(left, right)
82+
83+
class LowerSubRule(op_lowering.OpLoweringRule):
    """Rewrite ``SubOp`` expressions according to their operand types.

    Boolean operands are cast to integers before subtracting, and a date
    minuend paired with a timedelta subtrahend is cast to datetime first.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.SubOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for a subtraction.

        Args:
            expr: Op expression whose ``op`` is a ``SubOp``. Its first two
                children are taken as the left and right operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.SubOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        left, right = expr.children[0], expr.children[1]
        left_type = left.output_type
        right_type = right.output_type

        # bool - bool: subtract as integers, then cast the result back to bool.
        if left_type == dtypes.BOOL_DTYPE and right_type == dtypes.BOOL_DTYPE:
            difference = expr.op.as_expr(_as_int(left), _as_int(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(difference)

        # A lone bool operand is cast to int; date - timedelta has the date
        # cast to datetime before the subtraction.
        if left_type == dtypes.BOOL_DTYPE:
            left = _as_int(left)
        elif (
            left_type == dtypes.DATE_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            left = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(left)

        if right_type == dtypes.BOOL_DTYPE:
            right = _as_int(right)

        return expr.op.as_expr(left, right)
115+
116+
@dataclasses.dataclass
class LowerMulRule(op_lowering.OpLoweringRule):
    """Rewrite ``MulOp`` expressions so boolean operands are multiplied as ints."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.MulOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for a multiplication.

        Args:
            expr: Op expression whose ``op`` is a ``MulOp``. Its first two
                children are taken as the left and right operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.MulOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        left, right = expr.children[0], expr.children[1]
        left_is_bool = left.output_type == dtypes.BOOL_DTYPE
        right_is_bool = right.output_type == dtypes.BOOL_DTYPE

        # bool * bool: multiply as integers, then cast the product back to bool.
        if left_is_bool and right_is_bool:
            product = expr.op.as_expr(_as_int(left), _as_int(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(product)

        # From here at most one operand is bool; cast that one to int.
        if left_is_bool:
            left = _as_int(left)
        elif right_is_bool:
            right = _as_int(right)

        return expr.op.as_expr(left, right)
149+
150+
class LowerDivRule(op_lowering.OpLoweringRule):
    """Rewrite ``DivOp`` expressions according to their operand types."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.DivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for a division.

        Args:
            expr: Op expression whose ``op`` is a ``DivOp``. Its first two
                children are taken as the dividend and the divisor.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.DivOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        numerator = expr.children[0]
        denominator = expr.children[1]
        numerator_type = numerator.output_type
        denominator_type = denominator.output_type

        if numerator_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(
            denominator_type
        ):
            # timedelta / numeric uses the same construction as the floordiv
            # lowering: floor-divide the integer form, cast to int, then cast
            # back to timedelta.
            quotient = ops.floordiv_op.as_expr(_as_int(numerator), denominator)
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(
                _as_int(quotient)
            )

        if (
            numerator_type == dtypes.BOOL_DTYPE
            and denominator_type == dtypes.BOOL_DTYPE
        ):
            # bool / bool: divide as integers, then cast the result to bool.
            quotient = expr.op.as_expr(_as_int(numerator), _as_int(denominator))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(quotient)

        # polars divide doesn't like bools, so bools are always cast to int;
        # NUMERIC / BIGNUMERIC dividends are always cast to float.
        if numerator_type == dtypes.BOOL_DTYPE:
            numerator = _as_int(numerator)
        elif numerator_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE):
            numerator = ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr(numerator)

        if denominator_type == dtypes.BOOL_DTYPE:
            denominator = _as_int(denominator)

        return numeric_ops.div_op.as_expr(numerator, denominator)
192+
193+
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Rewrite ``FloorDivOp`` expressions according to their operand types."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for a floor division.

        Args:
            expr: Op expression whose ``op`` is a ``FloorDivOp``. Its first
                two children are taken as the dividend and the divisor.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.FloorDivOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        numerator = expr.children[0]
        denominator = expr.children[1]
        numerator_type = numerator.output_type
        denominator_type = denominator.output_type

        if numerator_type == dtypes.TIMEDELTA_DTYPE:
            if denominator_type == dtypes.TIMEDELTA_DTYPE:
                # timedelta // timedelta: floor-divide the integer forms and
                # return that result without casting back.
                return expr.op.as_expr(_as_int(numerator), _as_int(denominator))
            if dtypes.is_numeric(denominator_type):
                # this is pretty fragile as zero will break it, and the
                # result must fit back into int
                quotient = expr.op.as_expr(_as_int(numerator), denominator)
                return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(
                    _as_int(quotient)
                )

        if numerator_type == dtypes.BOOL_DTYPE:
            numerator = _as_int(numerator)
        if denominator_type == dtypes.BOOL_DTYPE:
            denominator = _as_int(denominator)

        if expr.output_type != dtypes.FLOAT_DTYPE:
            # need to guard against a zero divisor;
            # the dividend is multiplied by zero in this case to propagate nulls
            zero_result = ops.mul_op.as_expr(numerator, expression.const(0))
            divisor_is_zero = ops.eq_op.as_expr(denominator, expression.const(0))
            guarded_quotient = numeric_ops.floordiv_op.as_expr(
                numerator, denominator
            )
            return ops.where_op.as_expr(
                zero_result, divisor_is_zero, guarded_quotient
            )

        return expr.op.as_expr(numerator, denominator)
239+
240+
class LowerModRule(op_lowering.OpLoweringRule):
    """Rewrite ``ModOp`` expressions according to their operand types."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class whose expressions this rule handles."""
        return numeric_ops.ModOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the lowered expression for a modulo operation.

        Args:
            expr: Op expression whose ``op`` is a ``ModOp``. Its first two
                children are taken as the left and right operands.

        Returns:
            The rewritten expression.
        """
        assert isinstance(expr.op, numeric_ops.ModOp)

        def _as_int(node: expression.Expression) -> expression.Expression:
            return ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(node)

        left, right = expr.children[0], expr.children[1]
        left_type = left.output_type
        right_type = right.output_type

        if (
            left_type == dtypes.TIMEDELTA_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            # timedelta % timedelta: take the modulo of the integer forms,
            # wrap it in a where() keyed on "divisor != 0" with
            # "divisor * 0" as the alternative, then cast back to timedelta.
            left_int = _as_int(left)
            right_int = _as_int(right)
            remainder = expr.op.as_expr(left_int, right_int)
            guarded = ops.where_op.as_expr(
                remainder,
                ops.ne_op.as_expr(right_int, expression.const(0)),
                ops.mul_op.as_expr(right_int, expression.const(0)),
            )
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(guarded)

        if left_type == dtypes.BOOL_DTYPE:
            left = _as_int(left)
        if right_type == dtypes.BOOL_DTYPE:
            right = _as_int(right)

        remainder = expr.op.as_expr(left, right)

        # Only integer-typed results get the zero-divisor where() wrapper.
        if expr.output_type == dtypes.INT_DTYPE:
            divisor_nonzero = ops.ne_op.as_expr(right, expression.const(0))
            zero_fallback = ops.mul_op.as_expr(right, expression.const(0))
            return ops.where_op.as_expr(remainder, divisor_nonzero, zero_fallback)

        return remainder
57279
58280
59- def _coerce_comparables (expr1 : expression .Expression , expr2 : expression .Expression ):
281+ def _coerce_comparables (
282+ expr1 : expression .Expression ,
283+ expr2 : expression .Expression ,
284+ * ,
285+ bools_only : bool = False
286+ ):
287+ if bools_only :
288+ if (
289+ expr1 .output_type != dtypes .BOOL_DTYPE
290+ and expr2 .output_type != dtypes .BOOL_DTYPE
291+ ):
292+ return expr1 , expr2
60293
61294 target_type = dtypes .coerce_to_common (expr1 .output_type , expr2 .output_type )
62295 if expr1 .output_type != target_type :
@@ -90,7 +323,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90323
# Lowering rules for the polars backend: the comparison rules come first,
# followed by one rule instance per arithmetic op.
POLARS_LOWERING_RULES = tuple(LOWER_COMPARISONS) + (
    LowerAddRule(),
    LowerSubRule(),
    LowerMulRule(),
    LowerDivRule(),
    LowerFloorDivRule(),
    LowerModRule(),
)
95333
96334
0 commit comments