diff --git a/econpizza/parser/build_generic_functions.py b/econpizza/parser/build_generic_functions.py index e0512c4..020fbe2 100644 --- a/econpizza/parser/build_generic_functions.py +++ b/econpizza/parser/build_generic_functions.py @@ -77,14 +77,15 @@ def func_stst_rep_agent(y, func_pre_stst, func_eqns): def func_stst_het_agent(y, func_pre_stst, find_stat_wf, func_forw_stst, func_eqns): x, par = func_pre_stst(y) - x = x[..., None] wf, decisions_output, cnt_backw = find_stat_wf( x, par) dist, cnt_forw = func_forw_stst(decisions_output) - # TODO: for more than one dist this should be a loop... + # add time dimension for every object + x = x[..., None] decisions_output_array = (do[..., None] for do in decisions_output) + # TODO: for more than one dist this should be a loop... dist_array = dist[..., None] aux = (wf, decisions_output, cnt_backw), (dist, cnt_forw) diff --git a/econpizza/parser/checks.py b/econpizza/parser/checks.py index 9bb7ae5..9e1b5d1 100644 --- a/econpizza/parser/checks.py +++ b/econpizza/parser/checks.py @@ -64,7 +64,6 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre # run func_pre_stst to translate init values into vars & pars init_vals, par = func_pre_stst( d2jnp(init_guesses), d2jnp(fixed_values), pre_stst_mapping) - init_vals = init_vals[..., None] # collect some information needed later model['context']['init_run'] = {} @@ -91,6 +90,7 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre model['context']['init_run']['dists'] = dists_init # final test of main function + init_vals = init_vals[..., None] test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, jnp.zeros( len(shocks)), par, jnp.array(dists_init)[..., None], (doi[...,None] for doi in decisions_output_init))