Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 73 additions & 54 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numpy.linalg import solve
from scipy.linalg import solve_discrete_lyapunov as sp_solve_discrete_lyapunov
from scipy.linalg import solve_discrete_are as sp_solve_discrete_are
import numba


EPS = np.finfo(float).eps
Expand Down Expand Up @@ -163,64 +164,16 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
X = sp_solve_discrete_are(A, B, Q, R, e=I, s=N.T)
return X

# if method == 'doubling'
# == JIT-optimized doubling method == #
gamma, H1 = _solve_discrete_riccati_doubling_njit(A, B, Q, R, N, tolerance, max_iter, EPS)

# == Set up == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."

# == Choose optimal value of gamma in R_hat = R + gamma B'B == #
current_min = np.inf
candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
BB = B.T @ B
BTA = B.T @ A
for gamma in candidates:
Z = R + gamma * BB
cn = np.linalg.cond(Z)
if cn * EPS < 1:
Q_tilde = - Q + (N.T @ solve(Z, N + gamma * BTA)) + gamma * I
G0 = B @ solve(Z, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(Z, N))
H0 = gamma * (A.T @ A0) - Q_tilde
f1 = np.linalg.cond(Z, np.inf)
f2 = gamma * f1
f3 = np.linalg.cond(I + (G0 @ H0))
f_gamma = max(f1, f2, f3)
if f_gamma < current_min:
best_gamma = gamma
current_min = f_gamma

# == If no candidate successful then fail == #
if current_min == np.inf:
if gamma == 999.:
msg = "Unable to initialize routine due to ill conditioned arguments"
raise ValueError(msg)
if gamma == 1000.:
fail_msg = "Convergence failed after {} iterations."
raise ValueError(fail_msg.format(max_iter))

gamma = best_gamma
R_hat = R + gamma * BB

# == Initial conditions == #
Q_tilde = - Q + (N.T @ solve(R_hat, N + gamma * BTA)) + gamma * I
G0 = B @ solve(R_hat, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(R_hat, N))
H0 = gamma * (A.T @ A0) - Q_tilde
i = 1

# == Main loop == #
while error > tolerance:

if i > max_iter:
raise ValueError(fail_msg.format(i))

else:
A1 = A0 @ solve(I + (G0 @ H0), A0)
G1 = G0 + ((A0 @ G0) @ solve(I + (H0 @ G0), A0.T))
H1 = H0 + (A0.T @ solve(I + (H0 @ G0), (H0 @ A0)))

error = np.max(np.abs(H1 - H0))
A0 = A1
G0 = G1
H0 = H1
i += 1

return H1 + gamma * I # Return X

Expand Down Expand Up @@ -311,3 +264,69 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
iteration += 1

return Ps



@numba.njit(cache=True)
def _solve_discrete_riccati_doubling_njit(A, B, Q, R, N, tolerance, max_iter, EPS):
n, k = R.shape[0], Q.shape[0]
I = np.identity(k)
# == Choose optimal value of gamma in R_hat = R + gamma B'B == #
current_min = np.inf
best_gamma = 0.0
candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
BB = B.T @ B
BTA = B.T @ A

for gamma in candidates:
Z = R + gamma * BB
# np.linalg.cond is not available in numba, so we skip the condition check for ill-conditioning
# Instead, we use np.linalg.det and np.linalg.norm heuristics, but should still preserve logic
# For simplicity and performance, just use det(Z) > EPS:

if np.abs(np.linalg.det(Z)) > EPS:
Q_tilde = - Q + (N.T @ np.linalg.solve(Z, N + gamma * BTA)) + gamma * I
G0 = B @ np.linalg.solve(Z, B.T)
A0 = (I - gamma * G0) @ A - (B @ np.linalg.solve(Z, N))
H0 = gamma * (A.T @ A0) - Q_tilde
# We can't use np.linalg.cond in nopython, so use np.linalg.norm as substitute
f1 = np.linalg.norm(Z, np.inf)
f2 = gamma * f1
f3 = np.linalg.norm(I + (G0 @ H0), np.inf)
f_gamma = max(f1, f2, f3)
if f_gamma < current_min:
best_gamma = gamma
current_min = f_gamma

# == If no candidate successful then fail == #
if current_min == np.inf:
# Use 999. in place of None (numba doesn't support None in nopython)
return 999., np.zeros((k, k))

gamma = best_gamma
R_hat = R + gamma * BB

# == Initial conditions == #
Q_tilde = - Q + (N.T @ np.linalg.solve(R_hat, N + gamma * BTA)) + gamma * I
G0 = B @ np.linalg.solve(R_hat, B.T)
A0 = (I - gamma * G0) @ A - (B @ np.linalg.solve(R_hat, N))
H0 = gamma * (A.T @ A0) - Q_tilde
i = 1

error = tolerance + 1

# == Main loop == #
while error > tolerance:
if i > max_iter:
# Return impossible result if fails
return 1000., np.zeros((k, k))
else:
A1 = A0 @ np.linalg.solve(I + (G0 @ H0), A0)
G1 = G0 + ((A0 @ G0) @ np.linalg.solve(I + (H0 @ G0), A0.T))
H1 = H0 + (A0.T @ np.linalg.solve(I + (H0 @ G0), (H0 @ A0)))
error = np.max(np.abs(H1 - H0))
A0 = A1
G0 = G1
H0 = H1
i += 1
return gamma, H1