Skip to content

Commit 54e7c57

Browse files
committed
fypp best practice + minor optimizations + stack matrices
1 parent 3192a16 commit 54e7c57

File tree

1 file changed

+34
-47
lines changed

1 file changed

+34
-47
lines changed
Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,39 @@
11
#:include "common.fypp"
2-
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
2+
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX, REAL_INIT))
3+
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX, CMPLX_INIT))
4+
#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES
35
submodule (stdlib_linalg) stdlib_linalg_matrix_functions
46
use stdlib_constants
57
use stdlib_linalg_constants
68
use stdlib_linalg_blas, only: gemm
7-
use stdlib_linalg_lapack, only: gesv
9+
use stdlib_linalg_lapack, only: gesv, lacpy
810
use stdlib_linalg_lapack_aux, only: handle_gesv_info
911
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
1012
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
11-
implicit none
13+
implicit none(type, external)
1214

1315
character(len=*), parameter :: this = "matrix_exponential"
1416

1517
contains
1618

17-
#:for rk,rt,ri in RC_KINDS_TYPES
18-
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
19+
#:for k,t,s, i in RC_KINDS_TYPES
20+
module function stdlib_linalg_${i}$_expm_fun(A, order) result(E)
1921
!> Input matrix A(n, n).
20-
${rt}$, intent(in) :: A(:, :)
22+
${t}$, intent(in) :: A(:, :)
2123
!> [optional] Order of the Pade approximation.
2224
integer(ilp), optional, intent(in) :: order
2325
!> Exponential of the input matrix E = exp(A).
24-
${rt}$, allocatable :: E(:, :)
26+
${t}$, allocatable :: E(:, :)
2527

2628
E = A
27-
call stdlib_linalg_${ri}$_expm_inplace(E, order)
28-
end function
29+
call stdlib_linalg_${i}$_expm_inplace(E, order)
30+
end function stdlib_linalg_${i}$_expm_fun
2931

30-
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
32+
module subroutine stdlib_linalg_${i}$_expm(A, E, order, err)
3133
!> Input matrix A(n, n).
32-
${rt}$, intent(in) :: A(:, :)
34+
${t}$, intent(in) :: A(:, :)
3335
!> Exponential of the input matrix E = exp(A).
34-
${rt}$, intent(out) :: E(:, :)
36+
${t}$, intent(out) :: E(:, :)
3537
!> [optional] Order of the Pade approximation.
3638
integer(ilp), optional, intent(in) :: order
3739
!> [optional] State return flag.
@@ -49,27 +51,28 @@ contains
4951
'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', &
5052
' E must be square (lde=', lde, ', ne=', ne, ')')
5153
else
52-
E(:n, :n) = A(:n, :n)
53-
call stdlib_linalg_${ri}$_expm_inplace(E, order, err0)
54+
call lacpy("n", n, n, A, n, E, n) ! E = A
55+
call stdlib_linalg_${i}$_expm_inplace(E, order, err0)
5456
endif
5557

5658
! Process output and return
5759
call linalg_error_handling(err0,err)
5860

5961
return
60-
end subroutine stdlib_linalg_${ri}$_expm
62+
end subroutine stdlib_linalg_${i}$_expm
6163

62-
module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
64+
module subroutine stdlib_linalg_${i}$_expm_inplace(A, order, err)
6365
!> Input matrix A(n, n) / Output matrix exponential.
64-
${rt}$, intent(inout) :: A(:, :)
66+
${t}$, intent(inout) :: A(:, :)
6567
!> [optional] Order of the Pade approximation.
6668
integer(ilp), optional, intent(in) :: order
6769
!> [optional] State return flag.
6870
type(linalg_state_type), optional, intent(out) :: err
6971

7072
! Internal variables.
71-
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :), X_tmp(:, :)
72-
real(${rk}$) :: a_norm, c
73+
${t}$ :: A2(size(A, 1), size(A, 2)), Q(size(A, 1), size(A, 2))
74+
${t}$ :: X(size(A, 1), size(A, 2)), X_tmp(size(A, 1), size(A, 2))
75+
real(${k}$) :: a_norm, c
7376
integer(ilp) :: m, n, ee, k, s, order_, i, j
7477
logical(lk) :: p
7578
type(linalg_state_type) :: err0
@@ -90,42 +93,30 @@ contains
9093
a_norm = mnorm(A, "inf")
9194

9295
! Determine scaling factor for the matrix.
93-
ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
96+
ee = int(log(a_norm) / log2_${k}$, kind=ilp) + 1
9497
s = max(0, ee+1)
9598

9699
! Scale the input matrix & initialize polynomial.
97-
A2 = A/2.0_${rk}$**s ; X = A2
100+
A2 = A/2.0_${k}$**s
101+
call lacpy("n", n, n, A2, n, X, n) ! X = A2
98102

99103
! First step of the Pade approximation.
100-
c = 0.5_${rk}$
101-
allocate (Q, source=A2) ; A = A2
104+
c = 0.5_${k}$
102105
do concurrent(i=1:n, j=1:n)
103-
A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j)
104-
Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
106+
A(i, j) = merge(1.0_${k}$ + c*A2(i, j), c*A2(i, j), i == j)
107+
Q(i, j) = merge(1.0_${k}$ - c*A2(i, j), -c*A2(i, j), i == j)
105108
enddo
106109

107110
! Iteratively compute the Pade approximation.
108111
p = .true.
109112
do k = 2, order_
110113
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111-
X_tmp = X
112-
#:if rt.startswith('complex')
113-
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114-
#:else
115-
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116-
#:endif
114+
call lacpy("n", n, n, X, n, X_tmp, n) ! X_tmp = X
115+
call gemm("N", "N", n, n, n, one_${s}$, A2, n, X_tmp, n, zero_${s}$, X, n)
117116
do concurrent(i=1:n, j=1:n)
118117
A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
118+
Q(i, j) = merge(Q(i, j) + c*X(i, j), Q(i, j) - c*X(i, j), p)
119119
enddo
120-
if (p) then
121-
do concurrent(i=1:n, j=1:n)
122-
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
123-
enddo
124-
else
125-
do concurrent(i=1:n, j=1:n)
126-
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127-
enddo
128-
endif
129120
p = .not. p
130121
enddo
131122

@@ -137,19 +128,15 @@ contains
137128

138129
! Matrix squaring.
139130
do k = 1, s
140-
X = A ! Re-use X to minimize allocations.
141-
#:if rt.startswith('complex')
142-
call gemm("N", "N", n, n, n, one_c${rk}$, X, n, X, n, zero_c${rk}$, A, n)
143-
#:else
144-
call gemm("N", "N", n, n, n, one_${rk}$, X, n, X, n, zero_${rk}$, A, n)
145-
#:endif
131+
call lacpy("n", n, n, A, n, X, n) ! X = A
132+
call gemm("N", "N", n, n, n, one_${s}$, X, n, X, n, zero_${s}$, A, n)
146133
enddo
147134
endif
148135

149136
call linalg_error_handling(err0, err)
150137

151138
return
152-
end subroutine stdlib_linalg_${ri}$_expm_inplace
139+
end subroutine stdlib_linalg_${i}$_expm_inplace
153140
#:endfor
154141

155142
end submodule stdlib_linalg_matrix_functions

0 commit comments

Comments
 (0)