diff --git a/quantecon/markov/ddp.py b/quantecon/markov/ddp.py
index 25a99850..b49885fc 100644
--- a/quantecon/markov/ddp.py
+++ b/quantecon/markov/ddp.py
@@ -119,6 +119,7 @@
     _fill_dense_Q, _s_wise_max_argmax, _s_wise_max, _find_indices,
     _has_sorted_sa_indices, _generate_a_indptr
 )
+from numba import njit
 
 
 class DiscreteDP:
@@ -367,25 +368,9 @@ def __init__(self, R, Q, beta, s_indices=None, a_indices=None):
             self.s_indices = _s_indices
 
             # Define state-wise maximization
-            def s_wise_max(vals, out=None, out_argmax=None):
-                """
-                Return the vector max_a vals(s, a), where vals is represented
-                by a 1-dimensional ndarray of shape (self.num_sa_pairs,).
-                out and out_argmax must be of length self.num_states; dtype of
-                out_argmax must be int.
-
-                """
-                if out is None:
-                    out = np.empty(self.num_states)
-                if out_argmax is None:
-                    _s_wise_max(self.a_indices, self.a_indptr, vals,
-                                out_max=out)
-                else:
-                    _s_wise_max_argmax(self.a_indices, self.a_indptr, vals,
-                                       out_max=out, out_argmax=out_argmax)
-                return out
-
-            self.s_wise_max = s_wise_max
+            self.s_wise_max = _create_sa_s_wise_max(
+                self.a_indices, self.a_indptr, self.num_states
+            )
 
         else:  # Not self._sa_pair
             if self.R.ndim != 2:
@@ -400,25 +385,7 @@ def s_wise_max(vals, out=None, out_argmax=None):
             self.num_sa_pairs = (self.R > -np.inf).sum()
 
             # Define state-wise maximization
-            def s_wise_max(vals, out=None, out_argmax=None):
-                """
-                Return the vector max_a vals(s, a), where vals is represented
-                by a 2-dimensional ndarray of shape (n, m). Stored in out,
-                which must be of length self.num_states.
-                out and out_argmax must be of length self.num_states; dtype of
-                out_argmax must be int.
-
-                """
-                if out is None:
-                    out = np.empty(self.num_states)
-                if out_argmax is None:
-                    vals.max(axis=1, out=out)
-                else:
-                    vals.argmax(axis=1, out=out_argmax)
-                    out[:] = vals[np.arange(self.num_states), out_argmax]
-                return out
-
-            self.s_wise_max = s_wise_max
+            self.s_wise_max = _create_dense_s_wise_max(self.num_states)
 
         # Check that for every state, at least one action is feasible
         self._check_action_feasibility()
@@ -964,6 +931,29 @@ def controlled_mc(self, sigma):
 
         return MarkovChain(Q_sigma)
 
+    def _check_action_feasibility(self):
+        """
+        Check that every state has at least one feasible action;
+        raise a ValueError otherwise.
+
+        """
+        if self._sa_pair:
+            # a_indptr has shape (num_states+1,): the feasible actions
+            # of state i occupy the half-open range
+            # [a_indptr[i], a_indptr[i+1]), which is empty exactly
+            # when the two bounds coincide.
+            if np.any(self.a_indptr[1:] == self.a_indptr[:-1]):
+                raise ValueError(
+                    "at least one state has no feasible actions"
+                )
+        else:
+            # Each state needs at least one action with a finite reward
+            if not np.all((self.R > -np.inf).any(axis=1)):
+                raise ValueError(
+                    "at least one state has no feasible actions"
+                )
+
+
 class DPSolveResult(dict):
     """
     Contain the information about the dynamic programming result.
@@ -1078,3 +1068,80 @@ def backward_induction(ddp, T, v_term=None):
         ddp.bellman_operator(vs[t, :], Tv=vs[t-1, :], sigma=sigmas[t-1, :])
 
     return vs, sigmas
+
+
+# fastmath is deliberately not enabled here: vals may contain -inf
+# entries for infeasible actions, and fastmath lets LLVM assume all
+# values are finite.
+@njit(cache=True)
+def _dense_s_wise_max_impl(vals, out, out_argmax):
+    """
+    Numba-compiled kernel: fill out[s] with max_a vals[s, a] and
+    out_argmax[s] with the maximizing action, for each row s of the
+    2-dimensional ndarray vals.
+ """ + n, m = vals.shape + for s in range(n): + best_val = vals[s, 0] + best_idx = 0 + for a in range(1, m): + v = vals[s, a] + if v > best_val: + best_val = v + best_idx = a + out[s] = best_val + if out_argmax is not None: + out_argmax[s] = best_idx + +def _create_dense_s_wise_max(num_states): + """ + Returns a function for dense state-wise maximization for use in DiscreteDP. + """ + def s_wise_max(vals, out=None, out_argmax=None): + """ + Return the vector max_a vals(s, a), where vals is represented + by a 2-dimensional ndarray of shape (n, m). Stored in out, + which must be of length self.num_states. + out and out_argmax must be of length self.num_states; dtype of + out_argmax must be int. + """ + if out is None: + out = np.empty(num_states) + if out_argmax is None: + # Fully vectorized in numpy as before for performance + vals.max(axis=1, out=out) + else: + if out_argmax.dtype != np.int_: + raise ValueError("out_argmax must be int dtype") + _dense_s_wise_max_impl(vals, out, out_argmax) + return out + return s_wise_max + +def _create_sa_s_wise_max(a_indices, a_indptr, num_states): + """ + Returns a function for (s, a)-pair state-wise maximization for use in DiscreteDP. + """ + def s_wise_max(vals, out=None, out_argmax=None): + """ + Return the vector max_a vals(s, a), where vals is represented + by a 1-dimensional ndarray of shape (self.num_sa_pairs,). + out and out_argmax must be of length self.num_states; dtype of + out_argmax must be int. + """ + if out is None: + out = np.empty(num_states) + if out_argmax is None: + _s_wise_max(a_indices, a_indptr, vals, out_max=out) + else: + _s_wise_max_argmax(a_indices, a_indptr, vals, + out_max=out, out_argmax=out_argmax) + return out + return s_wise_max