Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 74 additions & 28 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numpy.linalg import solve
from scipy.linalg import solve_discrete_lyapunov as sp_solve_discrete_lyapunov
from scipy.linalg import solve_discrete_are as sp_solve_discrete_are
from numba import njit


EPS = np.finfo(float).eps
Expand Down Expand Up @@ -273,41 +274,86 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
"""
m = Qs.shape[0]
k, n = Qs.shape[1], Rs.shape[1]
# Numba expects C-contiguous np.arrays, ensure type/copy as needed
Π = np.ascontiguousarray(Π, dtype=np.float64)
As = np.ascontiguousarray(As, dtype=np.float64)
Bs = np.ascontiguousarray(Bs, dtype=np.float64)
Qs = np.ascontiguousarray(Qs, dtype=np.float64)
Rs = np.ascontiguousarray(Rs, dtype=np.float64)
Ns = np.ascontiguousarray(Ns, dtype=np.float64)

Ps = _riccati_core(Π, As, Bs, Qs, Rs, Ns, beta, tolerance, max_iter, m, k, n)
if Ps is None:
fail_msg = "Convergence failed after {} iterations."
raise ValueError(fail_msg.format(max_iter))

return Ps


@njit(cache=True)
def _riccati_core(Π, As, Bs, Qs, Rs, Ns, beta, tolerance, max_iter, m, k, n):
    """
    Jitted fixed-point iteration on a system of coupled Riccati equations.

    Starting from identity matrices, every ``Ps[i]`` is repeatedly replaced
    by ::

        Rs[i] + sum_j beta * Π[i, j] * As[i].T @ Ps[j] @ As[i]
              - sum_j Π[i, j] * bgain_ij @ solve(Mgain_ij, rgain_ij)

    where ::

        bgain_ij = beta * As[i].T @ Ps[j] @ Bs[i] + Ns[i].T
        Mgain_ij = Qs[i] + beta * Bs[i].T @ Ps[j] @ Bs[i]
        rgain_ij = beta * Bs[i].T @ Ps[j] @ As[i] + Ns[i]

    The iteration stops once the summed (over ``i``) largest absolute
    element-wise change of the ``Ps[i]`` is no larger than `tolerance`.

    Parameters
    ----------
    Π : ndarray(float, ndim=2)
        Weights ``Π[i, j]`` coupling equation ``i`` to ``Ps[j]``; indexed
        with ``i, j`` in ``range(m)``.
    As, Bs, Qs, Rs, Ns : ndarray(float, ndim=3)
        Stacked coefficient matrices; the first axis is indexed with
        ``range(m)``. NOTE(review): the caller passes C-contiguous float64
        arrays -- keep it that way so the compiled signature stays stable.
    beta : scalar(float)
        Scalar multiplying every term that involves ``Ps[j]``.
    tolerance : scalar(float)
        Convergence threshold for the summed max-abs change.
    max_iter : scalar(int)
        Iteration budget. The check is ``iteration > max_iter``, so up to
        ``max_iter + 1`` sweeps are performed before giving up.
    m : scalar(int)
        Number of coupled equations (size of the first axis).
    k : scalar(int)
        Not used by the iteration; kept so the call signature matches the
        caller.
    n : scalar(int)
        Each ``Ps[i]`` is ``n x n``.

    Returns
    -------
    Ps : ndarray(float, ndim=3) or None
        Array of shape ``(m, n, n)`` holding the converged matrices, or
        ``None`` if the iteration budget was exhausted. The failure is
        signalled with ``None`` rather than an exception so that the caller
        can raise the ``ValueError`` with a formatted message outside of
        the jitted code.

    """
    # Start every P_i at the identity matrix
    Ps = np.empty((m, n, n))
    for i in range(m):
        Ps[i] = np.eye(n)
    Ps1 = np.copy(Ps)

    # Scratch accumulators, reused for every equation i
    sum1 = np.empty((n, n))
    sum2 = np.empty((n, n))

    iteration = 0
    error = tolerance + 1  # guarantees at least one sweep

    while error > tolerance:
        if iteration > max_iter:
            # Signal non-convergence; the caller turns this into ValueError
            return None

        error = 0.0
        for i in range(m):
            sum1[:, :] = 0.0
            sum2[:, :] = 0.0

            for j in range(m):
                weight = beta * Π[i, j]
                sum1 += weight * (As[i].T @ Ps[j] @ As[i])

                bgain = beta * As[i].T @ Ps[j] @ Bs[i] + Ns[i].T
                Mgain = Qs[i] + beta * Bs[i].T @ Ps[j] @ Bs[i]
                rgain = beta * Bs[i].T @ Ps[j] @ As[i] + Ns[i]
                sum2 += Π[i, j] * (bgain @ solve(Mgain, rgain))

            Ps1[i] = Rs[i] + sum1 - sum2
            error += np.max(np.abs(Ps1[i] - Ps[i]))

        # All Ps1[i] were computed from the previous Ps, so only now
        # overwrite Ps with the new iterate
        Ps[:, :, :] = Ps1
        iteration += 1

    return Ps