Handle mixed input dtypes and empty size in numba lapack functions #1764
We should redo |
@jessegrabowski I didn't touch the QR stuff, could you give me a hand? I don't recall whether you were supporting complex inputs and, if so, whether the view thing is needed. Also, I was too lazy to think about the empty input case for it.
I can attack it over the weekend, yeah.
I'm not testing the empty case for factor/solve tridiagonal, because I don't know how to define a valid empty case for those Ops. |
* It did not handle complex values correctly
* It increased compile time with the nested function
Orthogonal to your help, the PR is ready for review @jessegrabowski. I understand if you want to take your time here, as it is kind of your baby.
Pull request overview
This PR systematically handles mixed input dtypes and empty arrays in numba LAPACK functions. The changes improve robustness by casting discrete inputs to floats, upcasting mixed inputs, handling empty arrays explicitly, and simplifying LAPACK wrapper code by removing unnecessary .view() calls.
Key Changes
- Added systematic dtype handling for discrete/complex inputs with fallback to obj mode for complex types
- Added explicit empty array handling to prevent LAPACK warnings/errors
- Simplified LAPACK wrapper code by removing the `.view(dtype).ctypes` pattern in favor of direct `.ctypes` access
- Updated Op `make_node` methods to properly infer output dtypes based on input dtypes
- Reorganized test classes for better structure
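
The dtype and empty-input handling listed above can be illustrated with a minimal sketch in plain NumPy. This is not the PR's Numba code: `solve_with_dtype_handling` is a hypothetical helper, and it assumes discrete inputs should be promoted to `float64` and that an empty system should return an empty result without reaching LAPACK.

```python
import numpy as np


def solve_with_dtype_handling(a, b):
    """Hypothetical helper (not PyTensor's API) mirroring the three ideas:
    cast discrete inputs to float, upcast mixed inputs, short-circuit empties."""
    a = np.asarray(a)
    b = np.asarray(b)

    # Pick one common dtype for both inputs (e.g. float32 + float64 -> float64).
    out_dtype = np.result_type(a.dtype, b.dtype)
    # Assumption: bool/int/uint inputs are promoted to float64.
    if out_dtype.kind in "biu":
        out_dtype = np.dtype(np.float64)

    a = a.astype(out_dtype, copy=False)
    b = b.astype(out_dtype, copy=False)

    # Empty inputs: return an empty result instead of calling LAPACK.
    if a.size == 0 or b.size == 0:
        return np.empty(b.shape, dtype=out_dtype)

    return np.linalg.solve(a, b)


# Integer inputs are solved in float64.
x = solve_with_dtype_handling(np.array([[2, 0], [0, 4]]), np.array([2, 4]))
print(x.dtype, x)
```

Resolving the output dtype once, before any LAPACK call, is also what lets an Op's `make_node` report the same dtype that the compiled function returns.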
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| tests/tensor/test_nlinalg.py | Fixed duplicate function definition and improved lstsq test assertions |
| tests/tensor/test_blockwise.py | Added test for eig blockwise operation with dtype verification |
| tests/link/numba/test_slinalg.py | Reorganized tests into classes and added empty array tests for solve/decomposition ops |
| tests/link/numba/test_nlinalg.py | Enhanced Eig test to handle multiple dtypes and verify correctness |
| pytensor/tensor/slinalg.py | Updated LU Op to infer correct output dtypes for discrete inputs |
| pytensor/tensor/nlinalg.py | Updated multiple Ops (MatrixPinv, MatrixInverse, Det, SLogDet, Lstsq) to infer output dtypes and remove unnecessary casts |
| pytensor/link/numba/dispatch/slinalg.py | Added dtype handling, empty array checks, and casting logic for Cholesky, LU, LUFactor, Solve, SolveTriangular, CholeskySolve |
| pytensor/link/numba/dispatch/nlinalg.py | Simplified dtype handling by removing int_to_float_fn and adding direct casting logic |
| pytensor/link/numba/dispatch/basic.py | Removed now-unused int_to_float_fn helper function |
| pytensor/link/numba/dispatch/linalg/utils.py | Refactored _check_scipy_linalg_matrix into more flexible _check_linalg_matrix with dtype matching |
| pytensor/link/numba/dispatch/linalg/solve/*.py | Removed .view() calls and updated to use _check_linalg_matrix |
| pytensor/link/numba/dispatch/linalg/decomposition/*.py | Removed .view() calls, updated checks, and enhanced cholesky to handle C-contiguous inputs |
Spinoff from #811
Major changes:
`view(inp).ctypes` -> `inp.ctypes`

Re: Change `view(inp).ctypes` -> `inp.ctypes`

This may have been needed when working with complex inputs? But we are not supporting them in most implementations, so it makes the code more complex and is a potential source of bugs when we fail to systematically upcast inputs (point 3. from above).

Say we have a float32 and a float64 input, and forget to upcast the first one. Calling `view` will raise for non-F-contiguous inputs (which we always need for these routines):

Even if it didn't raise, the meaning of the array would be nonsensical:
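
A minimal sketch of both failure modes, written in plain NumPy rather than Numba and not taken from the PR itself. The array names are illustrative. In NumPy a size-changing `view` requires the last axis to be contiguous, so in this sketch it is the Fortran-ordered 2-D array that raises, while a C-ordered array is silently reinterpreted; Numba's `view` may behave differently.

```python
import numpy as np

# A float32 matrix in Fortran order, as LAPACK routines expect,
# that we "forgot" to upcast to float64.
f32 = np.asfortranarray(np.arange(6, dtype=np.float32).reshape(2, 3))

try:
    f32.view(np.float64)
    raised = False
except ValueError as err:
    # NumPy refuses: the last axis is not contiguous in Fortran order.
    raised = True
    print("view failed:", err)

# With a C-ordered array the view succeeds, but pairs of float32 values
# are reinterpreted as single float64 values.
c32 = np.ones((2, 4), dtype=np.float32)
reinterpreted = c32.view(np.float64)
print(reinterpreted.shape)  # half as many columns as c32
print(reinterpreted)        # none of the entries equals 1.0
```

Accessing `.ctypes` directly on an input that has already been upcast avoids both outcomes, because no reinterpretation of the buffer takes place.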