131 changes: 93 additions & 38 deletions quantecon/markov/ddp.py
@@ -119,6 +119,7 @@
_fill_dense_Q, _s_wise_max_argmax, _s_wise_max, _find_indices,
_has_sorted_sa_indices, _generate_a_indptr
)
from numba import njit


class DiscreteDP:
@@ -367,25 +368,8 @@ def __init__(self, R, Q, beta, s_indices=None, a_indices=None):
self.s_indices = _s_indices

# Define state-wise maximization
def s_wise_max(vals, out=None, out_argmax=None):
"""
Return the vector max_a vals(s, a), where vals is represented
by a 1-dimensional ndarray of shape (self.num_sa_pairs,).
out and out_argmax must be of length self.num_states; dtype of
out_argmax must be int.

"""
if out is None:
out = np.empty(self.num_states)
if out_argmax is None:
_s_wise_max(self.a_indices, self.a_indptr, vals,
out_max=out)
else:
_s_wise_max_argmax(self.a_indices, self.a_indptr, vals,
out_max=out, out_argmax=out_argmax)
return out

self.s_wise_max = s_wise_max
self.s_wise_max = _create_sa_s_wise_max(self.a_indices, self.a_indptr, self.num_states)


else: # Not self._sa_pair
if self.R.ndim != 2:
@@ -400,25 +384,9 @@ def s_wise_max(vals, out=None, out_argmax=None):
self.num_sa_pairs = (self.R > -np.inf).sum()

# Define state-wise maximization
def s_wise_max(vals, out=None, out_argmax=None):
"""
Return the vector max_a vals(s, a), where vals is represented
by a 2-dimensional ndarray of shape (n, m). Stored in out,
which must be of length self.num_states.
out and out_argmax must be of length self.num_states; dtype of
out_argmax must be int.

"""
if out is None:
out = np.empty(self.num_states)
if out_argmax is None:
vals.max(axis=1, out=out)
else:
vals.argmax(axis=1, out=out_argmax)
out[:] = vals[np.arange(self.num_states), out_argmax]
return out

self.s_wise_max = s_wise_max
self.s_wise_max = _create_dense_s_wise_max(self.num_states)

        # Check that for every state, at least one action is feasible
self._check_action_feasibility()
@@ -964,6 +932,30 @@ def controlled_mc(self, sigma):
return MarkovChain(Q_sigma)


    def _check_action_feasibility(self):
        """
        Check that every state has at least one feasible action; raise
        a ValueError otherwise. The logic mirrors the original inline
        check and is cheap enough that no numba compilation is needed.
        """
        if self._sa_pair:
            # self.a_indptr has shape (num_states+1,): the feasible
            # actions of state i are a_indices[a_indptr[i]:a_indptr[i+1]],
            # so equal consecutive entries mean that state i has no
            # feasible action.
            if np.any(self.a_indptr[1:] == self.a_indptr[:-1]):
                raise ValueError(
                    "at least one state has no feasible actions"
                )
        else:
            # In the dense formulation an action is feasible iff its
            # reward is finite; (self.R > -np.inf).any(axis=1) has
            # shape (n,).
            if not np.all((self.R > -np.inf).any(axis=1)):
                raise ValueError(
                    "at least one state has no feasible actions"
                )


class DPSolveResult(dict):
"""
Contain the information about the dynamic programming result.
@@ -1078,3 +1070,66 @@ def backward_induction(ddp, T, v_term=None):
ddp.bellman_operator(vs[t, :], Tv=vs[t-1, :], sigma=sigmas[t-1, :])

return vs, sigmas

@njit(cache=True)  # fastmath omitted so NaN comparisons keep NumPy semantics
def _dense_s_wise_max_impl(vals, out, out_argmax):
    """
    Numba-compiled kernel: for each row s of the 2-dimensional array
    vals, store max_a vals[s, a] in out[s] and the maximizing action
    index in out_argmax[s].
    """
    n, m = vals.shape
    for s in range(n):
        best_val = vals[s, 0]
        best_idx = 0
        for a in range(1, m):
            v = vals[s, a]
            if v > best_val:
                best_val = v
                best_idx = a
        out[s] = best_val
        out_argmax[s] = best_idx
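
A quick consistency check of the kernel against the plain NumPy reductions; the shapes and values below are arbitrary and purely illustrative:

import numpy as np

vals = np.random.standard_normal((5, 3))
out = np.empty(5)
out_argmax = np.empty(5, dtype=np.int_)

_dense_s_wise_max_impl(vals, out, out_argmax)

# The loop kernel should reproduce the vectorized reductions exactly.
assert np.array_equal(out, vals.max(axis=1))
assert np.array_equal(out_argmax, vals.argmax(axis=1))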

def _create_dense_s_wise_max(num_states):
    """
    Return a state-wise maximization function for the dense (n x m)
    formulation, for use in DiscreteDP.
    """
    def s_wise_max(vals, out=None, out_argmax=None):
        """
        Return the vector max_a vals(s, a), where vals is a
        2-dimensional ndarray of shape (n, m). out and out_argmax, if
        given, must be of length num_states; out_argmax must have an
        integer dtype.
        """
        if out is None:
            out = np.empty(num_states)
        if out_argmax is None:
            # Max only: the vectorized NumPy reduction is fast enough
            vals.max(axis=1, out=out)
        else:
            if not np.issubdtype(out_argmax.dtype, np.integer):
                raise ValueError("out_argmax must have an integer dtype")
            _dense_s_wise_max_impl(vals, out, out_argmax)
        return out
    return s_wise_max
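
The factory is used the same way as the inline closure it replaces; a small usage sketch with illustrative shapes:

import numpy as np

num_states, num_actions = 4, 3
s_wise_max = _create_dense_s_wise_max(num_states)

vals = np.random.standard_normal((num_states, num_actions))
sigma = np.empty(num_states, dtype=np.int_)
v = s_wise_max(vals, out_argmax=sigma)  # v[s] = max_a vals[s, a], sigma[s] = argmax_a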

def _create_sa_s_wise_max(a_indices, a_indptr, num_states):
    """
    Return a state-wise maximization function for the (s, a)-pair
    formulation, for use in DiscreteDP.
    """
    def s_wise_max(vals, out=None, out_argmax=None):
        """
        Return the vector max_a vals(s, a), where vals is a
        1-dimensional ndarray with one entry per feasible (s, a) pair.
        out and out_argmax, if given, must be of length num_states;
        out_argmax must have an integer dtype.
        """
        if out is None:
            out = np.empty(num_states)
        if out_argmax is None:
            _s_wise_max(a_indices, a_indptr, vals, out_max=out)
        else:
            _s_wise_max_argmax(a_indices, a_indptr, vals,
                               out_max=out, out_argmax=out_argmax)
        return out
    return s_wise_max
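
And the (s, a)-pair counterpart, with a made-up feasibility structure for illustration:

import numpy as np

# Hypothetical layout: state 0 has actions {0, 1}, state 1 has action {0}.
a_indices = np.array([0, 1, 0])
a_indptr = np.array([0, 2, 3])
num_states = 2

s_wise_max = _create_sa_s_wise_max(a_indices, a_indptr, num_states)

vals = np.array([1.0, 3.0, 2.0])   # one value per feasible (s, a) pair
sigma = np.empty(num_states, dtype=np.int_)
v = s_wise_max(vals, out_argmax=sigma)
# v should be [3.0, 2.0]; sigma should hold each state's maximizing
# action, here [1, 0].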