Skip to content

Commit

Permalink
Fixed sim.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Jan 22, 2025
1 parent a05dac9 commit f3cf546
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 78 deletions.
38 changes: 17 additions & 21 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,23 @@ def transition_to_next_period(
savings_current_period,
choice,
params,
compute_exog_transition_vec,
exog_state_mapping,
compute_beginning_of_period_wealth,
compute_next_period_states,
model_funcs_sim,
sim_specific_keys,
):
n_agents = savings_current_period.shape[0]

exog_states_next_period = vmap(
realize_exog_process, in_axes=(0, 0, 0, None, None, None)
)(
exog_states_next_period = vmap(realize_exog_process, in_axes=(0, 0, 0, None, None))(
discrete_states_beginning_of_period,
choice,
sim_specific_keys[2:, :],
params,
compute_exog_transition_vec,
exog_state_mapping,
model_funcs_sim["processed_exog_funcs"],
)

discrete_endog_states_next_period = vmap(
update_discrete_states_for_one_agent, in_axes=(None, 0, 0, None) # choice
)(
compute_next_period_states["next_period_endogenous_state"],
model_funcs_sim["next_period_endogenous_state"],
discrete_states_beginning_of_period,
choice,
params,
Expand All @@ -167,15 +161,14 @@ def transition_to_next_period(
key=sim_specific_keys[1, :], num_agents=n_agents, mean=0, std=params["sigma"]
)

next_period_wealth = model_funcs_sim["compute_beginning_of_period_wealth"]
if continuous_state_beginning_of_period is not None:

continuous_state_next_period = calculate_second_continuous_state_for_all_agents(
discrete_states_beginning_of_period=discrete_states_next_period,
continuous_state_beginning_of_period=continuous_state_beginning_of_period,
params=params,
compute_continuous_state=compute_next_period_states[
"next_period_continuous_state"
],
compute_continuous_state=model_funcs_sim["next_period_continuous_state"],
)

wealth_beginning_of_next_period = (
Expand All @@ -185,7 +178,7 @@ def transition_to_next_period(
savings_end_of_previous_period=savings_current_period,
income_shocks_of_period=income_shocks_next_period,
params=params,
compute_beginning_of_period_wealth=compute_beginning_of_period_wealth,
compute_beginning_of_period_wealth=next_period_wealth,
)
)
else:
Expand All @@ -196,7 +189,7 @@ def transition_to_next_period(
savings_end_of_previous_period=savings_current_period,
income_shocks_of_period=income_shocks_next_period,
params=params,
compute_beginning_of_period_wealth=compute_beginning_of_period_wealth,
compute_beginning_of_period_wealth=next_period_wealth,
)

return (
Expand Down Expand Up @@ -263,12 +256,15 @@ def vectorized_utility(consumption_period, state, choice, params, compute_utilit
return utility


def realize_exog_process(state, choice, key, params, exog_func, exog_state_mapping):
transition_vec = exog_func(params=params, **state, choice=choice)
exog_proc_next_period = jax.random.choice(
key=key, a=transition_vec.shape[0], p=transition_vec
)
exog_states_next_period = exog_state_mapping(exog_proc_next_period)
def realize_exog_process(state, choice, key, params, processed_exog_funcs):
exog_states_next_period = {}
for exog_state_name in processed_exog_funcs.keys():
exog_state_vec = processed_exog_funcs[exog_state_name](
params=params, **state, choice=choice
)
exog_states_next_period[exog_state_name] = jax.random.choice(
key=key, a=exog_state_vec.shape[0], p=exog_state_vec
).astype(state[exog_state_name].dtype)
return exog_states_next_period


Expand Down
65 changes: 21 additions & 44 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def simulate_all_periods(
else None
)

model_structure_solution = model["model_structure"]
discrete_state_space = model_structure_solution["state_space_dict"]
model_structure_sol = model["model_structure"]
discrete_state_space = model_structure_sol["state_space_dict"]

# Set initial states to internal dtype
states_initial_dtype = {
Expand All @@ -57,7 +57,7 @@ def simulate_all_periods(
if key in discrete_state_space
}

if "dummy_exog" in model_structure_solution["exog_states_names"]:
if "dummy_exog" in model_structure_sol["exog_states_names"]:
states_initial_dtype["dummy_exog"] = np.zeros_like(
states_initial_dtype["period"]
)
Expand All @@ -77,30 +77,14 @@ def simulate_all_periods(
)
model_funcs_sim = model_sim["model_funcs"]

compute_next_period_states = {
"next_period_endogenous_state": model_funcs_sim["next_period_endogenous_state"],
"next_period_continuous_state": model_funcs_sim["next_period_continuous_state"],
}

simulate_body = partial(
simulate_single_period,
params=params,
discrete_states_names=model_structure_solution["discrete_states_names"],
endog_grid_solved=endog_grid_solved,
value_solved=value_solved,
policy_solved=policy_solved,
map_state_choice_to_index=jnp.asarray(
model_structure_solution["map_state_choice_to_index_with_proxy"]
),
choice_range=model_structure_solution["choice_range"],
compute_exog_transition_vec=model_funcs_sim["compute_exog_transition_vec"],
compute_utility=model_funcs_sim["compute_utility"],
compute_beginning_of_period_wealth=model_funcs_sim[
"compute_beginning_of_period_wealth"
],
exog_state_mapping=model_funcs_sim["exog_state_mapping"],
compute_next_period_states=compute_next_period_states,
shock_functions=model_funcs_sim["shock_functions"],
model_structure_sol=model_structure_sol,
model_funcs_sim=model_funcs_sim,
second_continuous_state_dict=second_continuous_state_dict,
)

Expand All @@ -120,9 +104,9 @@ def simulate_all_periods(
states_and_wealth_beginning_of_final_period,
sim_specific_keys=sim_specific_keys[-1],
params=params,
discrete_states_names=model_structure_solution["discrete_states_names"],
choice_range=model_structure_solution["choice_range"],
map_state_choice_to_index=model_structure_solution[
discrete_states_names=model_structure_sol["discrete_states_names"],
choice_range=model_structure_sol["choice_range"],
map_state_choice_to_index=model_structure_sol[
"map_state_choice_to_index_with_proxy"
],
compute_utility_final_period=model_funcs_sim["compute_utility_final"],
Expand All @@ -133,7 +117,7 @@ def simulate_all_periods(
key: np.vstack([sim_dict[key], final_period_dict[key]])
for key in sim_dict.keys()
}
if "dummy_exog" in model_structure_solution["exog_states_names"]:
if "dummy_exog" in model_structure_sol["exog_states_names"]:
if "dummy_exog" not in model_sim["model_structure"]["exog_states_names"]:
result.pop("dummy_exog")

Expand All @@ -144,18 +128,11 @@ def simulate_single_period(
states_and_wealth_beginning_of_period,
sim_specific_keys,
params,
discrete_states_names,
endog_grid_solved,
value_solved,
policy_solved,
map_state_choice_to_index,
choice_range,
compute_exog_transition_vec,
compute_utility,
compute_beginning_of_period_wealth,
exog_state_mapping,
compute_next_period_states,
shock_functions,
model_structure_sol,
model_funcs_sim,
second_continuous_state_dict=None,
):
(
Expand All @@ -180,6 +157,7 @@ def simulate_single_period(
continuous_state_beginning_of_period = None
continuous_grid = None

choice_range = model_structure_sol["choice_range"]
# Interpolate policy and value function for all agents.
policy, values_pre_taste_shock = interpolate_policy_and_value_for_all_agents(
discrete_states_beginning_of_period=discrete_states_beginning_of_period,
Expand All @@ -188,11 +166,13 @@ def simulate_single_period(
value_solved=value_solved,
policy_solved=policy_solved,
endog_grid_solved=endog_grid_solved,
map_state_choice_to_index=map_state_choice_to_index,
choice_range=choice_range,
map_state_choice_to_index=jnp.asarray(
model_structure_sol["map_state_choice_to_index_with_proxy"]
),
choice_range=model_structure_sol["choice_range"],
params=params,
discrete_states_names=discrete_states_names,
compute_utility=compute_utility,
discrete_states_names=model_structure_sol["discrete_states_names"],
compute_utility=model_funcs_sim["compute_utility"],
continuous_grid=continuous_grid,
)

Expand All @@ -202,7 +182,7 @@ def simulate_single_period(
n_choices=len(choice_range),
states=states_beginning_of_period,
params=params,
shock_functions=shock_functions,
shock_functions=model_funcs_sim["shock_functions"],
key=sim_specific_keys[0, :],
)
values_across_choices = values_pre_taste_shock + taste_shocks
Expand All @@ -222,7 +202,7 @@ def simulate_single_period(
states_beginning_of_period,
choice,
params,
compute_utility,
model_funcs_sim["compute_utility"],
)
savings_current_period = wealth_beginning_of_period - consumption

Expand All @@ -237,10 +217,7 @@ def simulate_single_period(
savings_current_period=savings_current_period,
choice=choice,
params=params,
compute_exog_transition_vec=compute_exog_transition_vec,
exog_state_mapping=exog_state_mapping,
compute_beginning_of_period_wealth=compute_beginning_of_period_wealth,
compute_next_period_states=compute_next_period_states,
model_funcs_sim=model_funcs_sim,
sim_specific_keys=sim_specific_keys,
)

Expand Down
15 changes: 2 additions & 13 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,11 @@ def test_simulate_lax_scan(model_setup):
simulate_body = partial(
simulate_single_period,
params=params,
discrete_states_names=discrete_states_names,
endog_grid_solved=endog_grid,
value_solved=value,
policy_solved=policy,
map_state_choice_to_index=jnp.array(map_state_choice_to_index),
choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=np.uint8),
compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"],
compute_utility=model_funcs["compute_utility"],
compute_beginning_of_period_wealth=model_funcs[
"compute_beginning_of_period_wealth"
],
exog_state_mapping=exog_state_mapping,
shock_functions=model_funcs["shock_functions"],
compute_next_period_states={
"next_period_endogenous_state": next_period_endogenous_state
},
model_funcs_sim=model_funcs,
model_structure_sol=model_structure,
)

# a) lax.scan
Expand Down

0 comments on commit f3cf546

Please sign in to comment.