@@ -30,13 +30,13 @@ contains
3030 module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
3131 !> Input matrix A(n, n).
3232 ${rt}$, intent(in) :: A(:, :)
33- !> [optional] Order of the Pade approximation.
33+ !> Exponential of the input matrix E = exp(A).
34+ ${rt}$, intent(out) :: E(:, :)
35+ !> [optional] Order of the Pade approximation.
3436 integer(ilp), optional, intent(in) :: order
3537 !> [optional] State return flag.
3638 type(linalg_state_type), optional, intent(out) :: err
37- !> Exponential of the input matrix E = exp(A).
38- ${rt}$, intent(out) :: E(:, :)
39-
39+
4040 type(linalg_state_type) :: err0
4141 integer(ilp) :: lda, n, lde, ne
4242
@@ -68,7 +68,7 @@ contains
6868 type(linalg_state_type), optional, intent(out) :: err
6969
7070 ! Internal variables.
71- ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
71+ ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :), X_tmp(:, :)
7272 real(${rk}$) :: a_norm, c
7373 integer(ilp) :: m, n, ee, k, s, order_, i, j
7474 logical(lk) :: p
@@ -105,32 +105,29 @@ contains
105105 enddo
106106
107107 ! Iteratively compute the Pade approximation.
108- block
109- ${rt}$, allocatable :: X_tmp(:, :)
110- p = .true.
111- do k = 2, order_
112- c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
113- X_tmp = X
114- #:if rt.startswith('complex')
115- call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
116- #:else
117- call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
118- #:endif
108+ p = .true.
109+ do k = 2, order_
110+ c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+ X_tmp = X
112+ #:if rt.startswith('complex')
113+ call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+ #:else
115+ call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+ #:endif
117+ do concurrent(i=1:n, j=1:n)
118+ A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
119+ enddo
120+ if (p) then
119121 do concurrent(i=1:n, j=1:n)
120- A (i, j) = A (i, j) + c*X(i, j) ! E = E + c*X
122+ Q (i, j) = Q (i, j) + c*X(i, j) ! Q = Q + c*X
121123 enddo
122- if (p) then
123- do concurrent(i=1:n, j=1:n)
124- Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
125- enddo
126- else
127- do concurrent(i=1:n, j=1:n)
128- Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
129- enddo
130- endif
131- p = .not. p
132- enddo
133- end block
124+ else
125+ do concurrent(i=1:n, j=1:n)
126+ Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+ enddo
128+ endif
129+ p = .not. p
130+ enddo
134131
135132 block
136133 integer(ilp) :: ipiv(n), info
@@ -139,17 +136,14 @@ contains
139136 end block
140137
141138 ! Matrix squaring.
142- block
143- ${rt}$, allocatable :: E_tmp(:, :)
144- do k = 1, s
145- E_tmp = A
146- #:if rt.startswith('complex')
147- call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
148- #:else
149- call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
150- #:endif
151- enddo
152- end block
139+ do k = 1, s
140+ X = A ! Re-use X to minimize allocations.
141+ #:if rt.startswith('complex')
142+ call gemm("N", "N", n, n, n, one_c${rk}$, X, n, X, n, zero_c${rk}$, A, n)
143+ #:else
144+ call gemm("N", "N", n, n, n, one_${rk}$, X, n, X, n, zero_${rk}$, A, n)
145+ #:endif
146+ enddo
153147 endif
154148
155149 call linalg_error_handling(err0, err)
0 commit comments