@numba.njit(cache=True)
def _solve_discrete_riccati_doubling_njit(A, B, Q, R, N, tolerance, max_iter,
                                          EPS):
    """
    JIT-compiled core of the doubling method for the discrete-time
    algebraic Riccati equation.

    The routine first picks, from a fixed list of candidates, the value
    `gamma` for which ``R_hat = R + gamma * B'B`` (and the associated
    matrices) are best conditioned, and then runs the doubling iteration
    until successive iterates of `H` differ by at most `tolerance` in
    max-abs norm.

    Parameters
    ----------
    A, B, Q, R, N : ndarray(float, ndim=2)
        Problem matrices.  NOTE(review): the code only requires that the
        matrix products below are conformable and that ``Q.shape[0]`` is
        the size of the identity used; presumably these are the
        state/control matrices of the LQ problem -- confirm with caller.
    tolerance : float
        Stop once ``max(abs(H1 - H0)) <= tolerance``.
    max_iter : int
        Maximum number of doubling iterations allowed.
    EPS : float
        Threshold used in the conditioning test ``cond(Z) * EPS < 1``.

    Returns
    -------
    gamma : float
        The selected gamma on success.  Because exceptions with formatted
        messages are awkward in nopython mode, failures are signalled by
        sentinel values instead: ``999.`` if no candidate gamma passes the
        conditioning test, ``1000.`` if the iteration does not converge
        within `max_iter` iterations.  Neither sentinel is one of the
        candidate gammas.
    H1 : ndarray(float, ndim=2)
        The final iterate on success (the caller is expected to add
        ``gamma * I`` to obtain the solution), or a ``(k, k)`` array of
        zeros when a sentinel is returned.

    """
    # Sentinel status codes; the values are part of the contract with the
    # calling function, which compares `gamma` against them.
    ill_conditioned = 999.
    not_converged = 1000.

    k = Q.shape[0]
    I = np.identity(k)

    # == Choose optimal value of gamma in R_hat = R + gamma B'B == #
    current_min = np.inf
    best_gamma = 0.0  # Typed placeholder; only used if a candidate passes
    candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
    BB = B.T @ B
    BTA = B.T @ A

    for gamma in candidates:
        Z = R + gamma * BB
        # np.linalg.cond is available in nopython mode for non-string `p`,
        # so use true condition numbers here.  A determinant or plain norm
        # is scale dependent and is not a measure of conditioning.
        cn = np.linalg.cond(Z)
        if cn * EPS < 1:
            Q_tilde = - Q + (N.T @ np.linalg.solve(Z, N + gamma * BTA)) \
                + gamma * I
            G0 = B @ np.linalg.solve(Z, B.T)
            A0 = (I - gamma * G0) @ A - (B @ np.linalg.solve(Z, N))
            H0 = gamma * (A.T @ A0) - Q_tilde
            f1 = np.linalg.cond(Z, np.inf)
            f2 = gamma * f1
            f3 = np.linalg.cond(I + (G0 @ H0))
            f_gamma = max(f1, f2, f3)
            if f_gamma < current_min:
                best_gamma = gamma
                current_min = f_gamma

    # == If no candidate successful then fail == #
    if current_min == np.inf:
        return ill_conditioned, np.zeros((k, k))

    gamma = best_gamma
    R_hat = R + gamma * BB

    # == Initial conditions == #
    Q_tilde = - Q + (N.T @ np.linalg.solve(R_hat, N + gamma * BTA)) \
        + gamma * I
    G0 = B @ np.linalg.solve(R_hat, B.T)
    A0 = (I - gamma * G0) @ A - (B @ np.linalg.solve(R_hat, N))
    H0 = gamma * (A.T @ A0) - Q_tilde
    H1 = H0  # Keeps H1 defined even if the loop body never executes
    i = 1

    error = tolerance + 1

    # == Main loop == #
    while error > tolerance:
        if i > max_iter:
            return not_converged, np.zeros((k, k))

        A1 = A0 @ np.linalg.solve(I + (G0 @ H0), A0)
        G1 = G0 + ((A0 @ G0) @ np.linalg.solve(I + (H0 @ G0), A0.T))
        H1 = H0 + (A0.T @ np.linalg.solve(I + (H0 @ G0), (H0 @ A0)))

        error = np.max(np.abs(H1 - H0))
        A0 = A1
        G0 = G1
        H0 = H1
        i += 1

    return gamma, H1