From 4e2da3acdf81940bba9f11bf6ee7195a345c4e46 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 9 Aug 2025 18:36:25 +0200 Subject: [PATCH] Extend 3-arg `dot` to generic `HermOrSym` sparse matrices --- src/linalg.jl | 18 ++++++++++++------ test/linalg.jl | 5 +++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 9f813246..e3207cd4 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,7 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular, - RealHermSymComplexHerm, checksquare, sym_uplo, wrap + RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap using Random: rand! const tilebufsize = 10800 # Approximately 32k/3 @@ -1210,6 +1210,9 @@ function nzrangelo(A, i, excl=false) @inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2 end +dot(x::AbstractVector, A::HermOrSym{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) = + _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint) +# disambiguation dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::AbstractVector) = _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint) function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, rangefun::Function, diagop::Function, odiagop::Function) @@ -1242,9 +1245,12 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, end return r end -dot(x::SparseVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::SparseVector) = - _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real) -function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rangefun::Function, diagop::Function) +dot(x::AbstractSparseVector, A::HermOrSym{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractSparseVector) = + _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint) +# disambiguation +dot(x::AbstractSparseVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::AbstractSparseVector) = + _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint) +function _dot(x::AbstractSparseVector, A::AbstractSparseMatrixCSC, y::AbstractSparseVector, rangefun::Function, diagop::Function, odiagop::Function) m, n = size(A) length(x) == m && n == length(y) || throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))")) @@ -1275,7 +1281,7 @@ function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rang A_ptr_lo = first(rangefun(A, xi, true)) A_ptr_hi = last(rangefun(A, xi, true)) if A_ptr_lo <= A_ptr_hi - r += dot(xv, _spdot((a, y) -> a'y, A_ptr_lo, A_ptr_hi, Arowval, Anzval, + r += dot(xv, _spdot((a, y) -> odiagop(a)*y, A_ptr_lo, A_ptr_hi, Arowval, Anzval, 1, length(ynzind), ynzind, ynzval)) end end @@ -2241,7 +2247,7 @@ end # return F # end # end -function factorize(A::LinearAlgebra.RealHermSymComplexHerm{Float64,<:AbstractSparseMatrixCSC}) +function factorize(A::RealHermSymComplexHerm{Float64,<:AbstractSparseMatrixCSC}) F = cholesky(A; check = false) if LinearAlgebra.issuccess(F) return F diff --git a/test/linalg.jl b/test/linalg.jl index 68ea6809..a48d17b6 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -954,12 +954,13 @@ end @test dot(x, A, y) ≈ dot(x, Av, y) end - for (T, trans) in ((Float64, Symmetric), (ComplexF64, Symmetric), (ComplexF64, Hermitian)), uplo in (:U, :L) + for T in (Float64, ComplexF64, Quaternion{Float64}), trans in (Symmetric, Hermitian), uplo in (:U, :L) B = sprandn(T, 10, 10, 0.2) x = sprandn(T, 10, 0.4) + xd = Vector(x) S = trans(B'B, uplo) Sd = trans(Matrix(B'B), uplo) - @test dot(x, S, x) ≈ dot(x, Sd, x) ≈ dot(Vector(x), S, Vector(x)) ≈ dot(Vector(x), Sd, Vector(x)) + @test dot(x, S, x) ≈ dot(x, Sd, x) ≈ dot(xd, S, xd) ≈ dot(xd, Sd, xd) end end