Below we provide code to solve the household problem, taking $r$ and $w$ as fixed.

### Primitives and operators

We will solve the household problem using value function iteration.
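
Recall that, given prices $(r, w)$, the household's value function $v$ satisfies the Bellman equation

$$
v(a, z) = \max_{0 \le a' \le (1+r)a + wz}
    \left\{ u((1+r)a + wz - a') + \beta \sum_{z'} v(a', z') \Pi(z, z') \right\}
$$

where $a$ is current assets, $z$ is the exogenous state, and $a'$ is the asset choice.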

First we set up a `NamedTuple` to store the parameters that define a household asset accumulation problem, as well as the grids used to solve it

```{code-cell} ipython3
class Household(NamedTuple):
    # Fields inferred from the unpacking order β, a_grid, z_grid, Π used below
    β: float              # Discount factor
    a_grid: jnp.ndarray   # Asset grid
    z_grid: jnp.ndarray   # Exogenous states
    Π: jnp.ndarray        # Markov matrix

def create_household(β=0.96,                      # Discount factor
                     Π=[[0.9, 0.1], [0.1, 0.9]],  # Markov chain
                     z_grid=[0.1, 1.0],           # Exogenous states
                     a_min=1e-10, a_max=12.5,     # Asset grid
                     a_size=100):
"""
Create a Household namedtuple with custom grids.
"""
    # Build the asset grid and convert inputs to JAX arrays
    a_grid = jnp.linspace(a_min, a_max, a_size)
    z_grid, Π = map(jnp.array, (z_grid, Π))
    return Household(β=β, a_grid=a_grid, z_grid=z_grid, Π=Π)
```

Given prices $(r, w)$, the function $B$ returns the right-hand side of the Bellman equation before maximization over $a'$:

$$
B(v)(a, z, a') = u((1+r)a + wz - a') + \beta \sum_{z'} v(a', z') \Pi(z, z')
$$
for all $(a, z, a')$.

```{code-cell} ipython3
def B(v, household, prices):
    # Unpack
    β, a_grid, z_grid, Π = household
Expand All @@ -303,125 +302,54 @@ def B(v, household, prices):
The next function computes greedy policies

```{code-cell} ipython3
def get_greedy(v, household, prices):
    """
    Computes a v-greedy policy σ, returned as a set of indices. If
    σ[i, j] equals ip, then a_grid[ip] is the maximizer at i, j.
    """
    # argmax over ap
    return jnp.argmax(B(v, household, prices), axis=-1)
```
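
Since `get_greedy` returns grid indices rather than asset levels, the corresponding savings choices are recovered by indexing into the asset grid. A minimal sketch, assuming some value function `v` is at hand:

```python
σ = get_greedy(v, household, prices)   # σ[i, j] is an index into a_grid
a_star = household.a_grid[σ]           # asset choices as values, a'(a_i, z_j)
```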

We define the Bellman operator $T$, which takes a value function $v$ and returns the updated function $Tv$, obtained by maximizing the right-hand side of the Bellman equation

```{code-cell} ipython3
def T(v, household, prices):
    """
    The Bellman operator. Takes a value function v and returns Tv.
    """
    return jnp.max(B(v, household, prices), axis=-1)
```
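
As a quick illustration, here is a sketch that applies $T$ once to a zero initial guess, assuming `household` and `prices` instances have been created as below:

```python
v0 = jnp.zeros((len(household.a_grid), len(household.z_grid)))
v1 = T(v0, household, prices)   # one application of the Bellman operator
```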

Here's value function iteration, which repeatedly applies the Bellman operator until convergence

```{code-cell} ipython3
@jax.jit
def value_function_iteration(household, prices, tol=1e-4, max_iter=10_000):
"""
Implements value function iteration using a compiled JAX loop.
"""
β, a_grid, z_grid, Π = household
a_size, z_size = len(a_grid), len(z_grid)

    def condition_function(loop_state):
        i, v, error = loop_state
        return jnp.logical_and(error > tol, i < max_iter)

    def update(loop_state):
        i, v, error = loop_state
        v_new = T(v, household, prices)
        error = jnp.max(jnp.abs(v_new - v))
        return i + 1, v_new, error

    # Initial loop state
    v_init = jnp.zeros((a_size, z_size))
    loop_state_init = (0, v_init, tol + 1)

    # Run the fixed point iteration
    i, v, error = jax.lax.while_loop(condition_function, update, loop_state_init)

    return get_greedy(v, household, prices)
```
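
The `(i, v, error)` tuple threaded through `jax.lax.while_loop` plays the role of the loop variables in an ordinary Python `while` loop, but in a form JAX can compile. Here is a minimal self-contained sketch of the same pattern, iterating the map $x \mapsto 0.5 x + 1$ to its fixed point $2$:

```python
import jax
import jax.numpy as jnp

def cond(state):
    i, x, error = state
    return jnp.logical_and(error > 1e-8, i < 1_000)

def body(state):
    i, x, error = state
    x_new = 0.5 * x + 1.0
    return i + 1, x_new, jnp.abs(x_new - x)

i, x, error = jax.lax.while_loop(cond, body, (0, 0.0, 1.0))
print(x)   # ≈ 2.0
```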

As a first example of what we can do, let's compute and plot an optimal accumulation policy at fixed prices

```{code-cell} ipython3
# Create a household instance and prices (assumes Prices has default values)
household = create_household()
prices = Prices()
r, w = prices
print(f"Interest rate: {r}, Wage: {w}")
```

```{code-cell} ipython3
with qe.Timer():
    σ_star = value_function_iteration(household, prices).block_until_ready()
```
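
Note that JAX dispatches computations asynchronously, so the call to `block_until_ready()` ensures the timer measures the computation itself rather than just its dispatch.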

The next plot shows asset accumulation policies at different values of the exogenous state

The following function computes $G(K)$, the aggregate capital supplied by households at the prices induced by aggregate capital $K$:

```{code-cell} ipython3
def G(K, firm, household):
    # Get prices r, w associated with K
    r = r_given_k(K, firm)
    w = r_to_w(r, firm)

    # Generate a household object with these prices, compute
    # aggregate capital.
    prices = Prices(r=r, w=w)
    σ_star = value_function_iteration(household, prices)
    return capital_supply(σ_star, household)
```
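
An equilibrium capital stock is a fixed point $K = G(K)$, so the sign of $G(K) - K$ indicates whether households supply more or less capital than $K$ at the prices that $K$ induces. A small sketch of such a check, assuming `firm` and `household` instances exist:

```python
K_trial = 5.0   # hypothetical trial value
excess_supply = G(K_trial, firm, household) - K_trial
print(excess_supply)   # positive means households supply more than K_trial
```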


The next function computes the capital stock supplied by households at a given interest rate $r$:

```{code-cell} ipython3
def prices_to_capital_stock(household, r, firm):
    """
    Map the interest rate to the induced capital stock.
    """
    # Compute the wage consistent with r
    w = r_to_w(r, firm)
    prices = Prices(r=r, w=w)

    # Compute the optimal policy
    σ_star = value_function_iteration(household, prices)

    # Compute capital supply
    return capital_supply(σ_star, household)

```

```{solution-end}
```

```{exercise-start}
:label: aiyagari_ex3
```

In this lecture, we used value function iteration to solve the household problem.

An alternative is Howard policy iteration (HPI), which is discussed in detail in [Dynamic Programming](https://dp.quantecon.org/).

HPI can be faster than VFI for some problems because it uses fewer but more computationally intensive iterations.

Your task is to implement Howard policy iteration and compare the results with value function iteration.

**Key concepts you'll need:**

Howard policy iteration requires computing the value $v_{\sigma}$ of a policy $\sigma$, defined as:

$$
v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
$$

where $r_{\sigma}$ is the reward vector under policy $\sigma$, and $P_{\sigma}$ is the transition matrix induced by $\sigma$.
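
This expression follows from rearranging the policy-value equation: $v_{\sigma} = r_{\sigma} + \beta P_{\sigma} v_{\sigma}$ implies $(I - \beta P_{\sigma}) v_{\sigma} = r_{\sigma}$, and inverting gives the formula above.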

To solve this, you'll need to:
1. Compute current rewards $r_{\sigma}(a, z) = u((1 + r)a + wz - \sigma(a, z))$
2. Set up the linear operator $R_{\sigma}$ where $(R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') \Pi(z, z')$
3. Solve $v_{\sigma} = R_{\sigma}^{-1} r_{\sigma}$ using `jax.scipy.sparse.linalg.bicgstab`

You can use the `get_greedy` function that's already defined in this lecture.
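
If you haven't used it before, `jax.scipy.sparse.linalg.bicgstab` solves a linear system $Ax = b$ where $A$ can be supplied as a matrix-free linear function, and returns a `(solution, info)` pair. A minimal sketch on a small hypothetical system:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

# Pass A as the linear map v -> A @ v rather than as a matrix
x, info = bicgstab(lambda v: A @ v, b)
print(x)   # ≈ [2., 3.]
```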

Implement the following Howard policy iteration routine:

```python
def howard_policy_iteration(household, prices,
                            tol=1e-4, max_iter=10_000, verbose=False):
    """
    Howard policy iteration routine.
    """
    # Your code here
    pass
```

Once implemented, compute the equilibrium capital stock using HPI and verify that it produces approximately the same result as VFI at the default parameter values.

```{exercise-end}
```

```{solution-start} aiyagari_ex3
:class: dropdown
```

First, we need to implement the helper functions for Howard policy iteration.

The following function computes the array $r_{\sigma}$ which gives current rewards given policy $\sigma$:

```{code-cell} ipython3
def compute_r_σ(σ, household, prices):
    """
    Compute current rewards at each i, j under policy σ. In particular,

        r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip])

    when ip = σ[i, j].
    """
    # Unpack
    β, a_grid, z_grid, Π = household
    a_size, z_size = len(a_grid), len(z_grid)
    r, w = prices

    # Compute r_σ[i, j]
    a = jnp.reshape(a_grid, (a_size, 1))
    z = jnp.reshape(z_grid, (1, z_size))
    ap = a_grid[σ]
    c = (1 + r) * a + w * z - ap
    r_σ = u(c)

    return r_σ
```

The linear operator $R_{\sigma} := I - \beta P_{\sigma}$ can be implemented matrix-free as follows:

```{code-cell} ipython3
def R_σ(v, σ, household):
    # Unpack
    β, a_grid, z_grid, Π = household
    a_size, z_size = len(a_grid), len(z_grid)

    # Set up the array v[σ[i, j], jp]
    zp_idx = jnp.arange(z_size)
    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
    σ = jnp.reshape(σ, (a_size, z_size, 1))
    V = v[σ, zp_idx]

    # Expand Π[j, jp] to Π[i, j, jp]
    Π = jnp.reshape(Π, (1, z_size, z_size))

    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp]
    return v - β * jnp.sum(V * Π, axis=-1)
```

The next function computes the lifetime value of a given policy:

```{code-cell} ipython3
def get_value(σ, household, prices):
    """
    Get the lifetime value of policy σ by computing

        v_σ = R_σ^{-1} r_σ
    """
    r_σ = compute_r_σ(σ, household, prices)

    # Reduce R_σ to a function in v
    _R_σ = lambda v: R_σ(v, σ, household)

    # Compute v_σ = R_σ^{-1} r_σ using an iterative routine;
    # bicgstab returns a (solution, info) pair, so take the first element.
    return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0]
```

Now we can implement Howard policy iteration:

```{code-cell} ipython3
@jax.jit
def howard_policy_iteration(household, prices, tol=1e-4, max_iter=10_000):
    """
    Howard policy iteration routine using a compiled JAX loop.
    """
    β, a_grid, z_grid, Π = household
    a_size, z_size = len(a_grid), len(z_grid)

    def condition_function(loop_state):
        i, σ, v_σ, error = loop_state
        return jnp.logical_and(error > tol, i < max_iter)

    def update(loop_state):
        i, σ, v_σ, error = loop_state
        σ_new = get_greedy(v_σ, household, prices)
        v_σ_new = get_value(σ_new, household, prices)
        error = jnp.max(jnp.abs(v_σ_new - v_σ))
        return i + 1, σ_new, v_σ_new, error

    # Initial loop state
    σ_init = jnp.zeros((a_size, z_size), dtype=int)
    v_σ_init = get_value(σ_init, household, prices)
    loop_state_init = (0, σ_init, v_σ_init, tol + 1)

    # Run the fixed point iteration
    i, σ, v_σ, error = jax.lax.while_loop(condition_function, update, loop_state_init)

    return σ
```
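
Before embedding HPI in the equilibrium computation, it is worth checking that the two solvers agree at the prices used earlier. A quick sketch, assuming the `household` and `prices` objects from the main lecture are still in scope:

```python
with qe.Timer():
    σ_vfi = value_function_iteration(household, prices).block_until_ready()
with qe.Timer():
    σ_hpi = howard_policy_iteration(household, prices).block_until_ready()

# The two policies should (essentially) coincide
print(jnp.max(jnp.abs(σ_vfi - σ_hpi)))
```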

Now let's create a modified version of the G function that uses HPI:

```{code-cell} ipython3
def G_hpi(K, firm, household):
    # Get prices r, w associated with K
    r = r_given_k(K, firm)
    w = r_to_w(r, firm)

    # Generate prices and compute aggregate capital using HPI.
    prices = Prices(r=r, w=w)
    σ_star = howard_policy_iteration(household, prices)
    return capital_supply(σ_star, household)
```

And compute the equilibrium using HPI:

```{code-cell} ipython3
def compute_equilibrium_bisect_hpi(firm, household, a=1.0, b=20.0):
    K = bisect(lambda k: k - G_hpi(k, firm, household), a, b, xtol=1e-4)
    return K

firm = Firm()
household = create_household()
print("\nComputing equilibrium capital stock using HPI")
with qe.Timer():
    K_star_hpi = compute_equilibrium_bisect_hpi(firm, household)
print(f"Computed equilibrium capital stock with HPI: {K_star_hpi:.5}")
print(f"Previous equilibrium capital stock with VFI: {K_star:.5}")
print(f"Difference: {abs(K_star_hpi - K_star):.6}")
```

The results show that both methods produce approximately the same equilibrium, confirming that HPI is a valid alternative to VFI.

```{solution-end}
```