Skip to content

Commit

Permalink
allow endogenous grids & transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Aug 27, 2024
1 parent 7d6fb42 commit 0104bc8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
24 changes: 20 additions & 4 deletions econpizza/parser/build_generic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,23 @@


def func_forw_generic(distributions, decisions_outputs, grids, transition, indices):
# prototype for one distribution
# should be a for-loop for more than one distribution
"""prototype for one distribution. should be a for-loop for more than one distribution
"""
(dist, ) = distributions
# use objects provided in decisions_output if available
if isinstance(transition, int):
transition = decisions_outputs[transition]
if isinstance(grids[0], int):
grids[0] = decisions_outputs[grids[0]]
# use as inputs to forward distribution
endog_inds0, endog_probs0 = interp.interpolate_coord_robust(
grids[0], decisions_outputs[indices[0]])
if len(indices) == 1:
grid = grids[0]
dist = transition.T @ dists.forward_policy_1d(
dist, endog_inds0, endog_probs0)
elif len(indices) == 2:
if isinstance(grids[0], int):
grids[1] = decisions_outputs[grids[1]]
endog_inds1, endog_probs1 = interp.interpolate_coord_robust(
grids[1], decisions_outputs[indices[1]])
forwarded_dist = dists.forward_policy_2d(
Expand All @@ -29,13 +36,22 @@ def func_forw_generic(distributions, decisions_outputs, grids, transition, indic


def func_forw_stst_generic(decisions_outputs, tol, maxit, grids, transition, indices):
# prototype for one distribution, as with _func_forw
"""prototype for one distribution, as with _func_forw
"""
# use objects provided in decisions_output if available
if isinstance(transition, int):
transition = decisions_outputs[transition]
if isinstance(grids[0], int):
grids[0] = decisions_outputs[grids[0]]
# use as inputs to find stationary distribution
endog_inds0, endog_probs0 = interp.interpolate_coord_robust(
grids[0], decisions_outputs[indices[0]])
if len(indices) == 1:
dist, dist_cnt = dists.stationary_distribution_forward_policy_1d(
endog_inds0, endog_probs0, transition, tol, maxit)
elif len(indices) == 2:
if isinstance(grids[0], int):
grids[1] = decisions_outputs[grids[1]]
endog_inds1, endog_probs1 = interp.interpolate_coord_robust(
grids[1], decisions_outputs[indices[1]])
dist, dist_cnt = dists.stationary_distribution_forward_policy_2d(
Expand Down
13 changes: 11 additions & 2 deletions econpizza/parser/compile_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,17 @@ def get_forw_funcs(model):
raise NotImplementedError(
f"Grid(s) of type {str(*other)} not implemented.")

transition = model['context'][dist[exog[0]]['transition_name']]
grids = [model['context'][dist[i]['grid_name']] for i in endo]
# for each object, check if it is provided in decisions_outputs
try:
transition = model['decisions']['outputs'].index(dist[exog[0]]['transition_name'])
except ValueError:
transition = model['context'][dist[exog[0]]['transition_name']]
grids = []
for i in endo:
try:
grids.append(model['decisions']['outputs'].index(dist[i]['grid_name']))
except ValueError:
grids.append(model['context'][dist[i]['grid_name']])
indices = [model['decisions']['outputs'].index(i) for i in endo]

model['context']['func_forw'] = jax.tree_util.Partial(
Expand Down

0 comments on commit 0104bc8

Please sign in to comment.