From f3cf5464cf1e5a4b865e720bc03d40f069f4ecac Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Wed, 22 Jan 2025 12:37:38 +0100 Subject: [PATCH] Fixed sim. --- src/dcegm/simulation/sim_utils.py | 38 ++++++++---------- src/dcegm/simulation/simulate.py | 65 ++++++++++--------------------- tests/test_simulate.py | 15 +------ 3 files changed, 40 insertions(+), 78 deletions(-) diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index 86447d55..5b27a6c7 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -127,29 +127,23 @@ def transition_to_next_period( savings_current_period, choice, params, - compute_exog_transition_vec, - exog_state_mapping, - compute_beginning_of_period_wealth, - compute_next_period_states, + model_funcs_sim, sim_specific_keys, ): n_agents = savings_current_period.shape[0] - exog_states_next_period = vmap( - realize_exog_process, in_axes=(0, 0, 0, None, None, None) - )( + exog_states_next_period = vmap(realize_exog_process, in_axes=(0, 0, 0, None, None))( discrete_states_beginning_of_period, choice, sim_specific_keys[2:, :], params, - compute_exog_transition_vec, - exog_state_mapping, + model_funcs_sim["processed_exog_funcs"], ) discrete_endog_states_next_period = vmap( update_discrete_states_for_one_agent, in_axes=(None, 0, 0, None) # choice )( - compute_next_period_states["next_period_endogenous_state"], + model_funcs_sim["next_period_endogenous_state"], discrete_states_beginning_of_period, choice, params, @@ -167,15 +161,14 @@ def transition_to_next_period( key=sim_specific_keys[1, :], num_agents=n_agents, mean=0, std=params["sigma"] ) + next_period_wealth = model_funcs_sim["compute_beginning_of_period_wealth"] if continuous_state_beginning_of_period is not None: continuous_state_next_period = calculate_second_continuous_state_for_all_agents( discrete_states_beginning_of_period=discrete_states_next_period, continuous_state_beginning_of_period=continuous_state_beginning_of_period, params=params, - compute_continuous_state=compute_next_period_states[ - "next_period_continuous_state" - ], + compute_continuous_state=model_funcs_sim["next_period_continuous_state"], ) wealth_beginning_of_next_period = ( @@ -185,7 +178,7 @@ def transition_to_next_period( savings_end_of_previous_period=savings_current_period, income_shocks_of_period=income_shocks_next_period, params=params, - compute_beginning_of_period_wealth=compute_beginning_of_period_wealth, + compute_beginning_of_period_wealth=next_period_wealth, ) ) else: @@ -196,7 +189,7 @@ def transition_to_next_period( savings_end_of_previous_period=savings_current_period, income_shocks_of_period=income_shocks_next_period, params=params, - compute_beginning_of_period_wealth=compute_beginning_of_period_wealth, + compute_beginning_of_period_wealth=next_period_wealth, ) return ( @@ -263,12 +256,15 @@ def vectorized_utility(consumption_period, state, choice, params, compute_utilit return utility -def realize_exog_process(state, choice, key, params, exog_func, exog_state_mapping): - transition_vec = exog_func(params=params, **state, choice=choice) - exog_proc_next_period = jax.random.choice( - key=key, a=transition_vec.shape[0], p=transition_vec - ) - exog_states_next_period = exog_state_mapping(exog_proc_next_period) +def realize_exog_process(state, choice, key, params, processed_exog_funcs): + exog_states_next_period = {} + for exog_state_name in processed_exog_funcs.keys(): + exog_state_vec = processed_exog_funcs[exog_state_name]( + params=params, **state, choice=choice + ) + exog_states_next_period[exog_state_name] = jax.random.choice( + key=key, a=exog_state_vec.shape[0], p=exog_state_vec + ).astype(state[exog_state_name].dtype) return exog_states_next_period diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index d8f19b48..13ea6763 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -47,8 +47,8 @@ def simulate_all_periods( else None ) - model_structure_solution = model["model_structure"] - discrete_state_space = model_structure_solution["state_space_dict"] + model_structure_sol = model["model_structure"] + discrete_state_space = model_structure_sol["state_space_dict"] # Set initial states to internal dtype states_initial_dtype = { @@ -57,7 +57,7 @@ def simulate_all_periods( if key in discrete_state_space } - if "dummy_exog" in model_structure_solution["exog_states_names"]: + if "dummy_exog" in model_structure_sol["exog_states_names"]: states_initial_dtype["dummy_exog"] = np.zeros_like( states_initial_dtype["period"] ) @@ -77,30 +77,14 @@ def simulate_all_periods( ) model_funcs_sim = model_sim["model_funcs"] - compute_next_period_states = { - "next_period_endogenous_state": model_funcs_sim["next_period_endogenous_state"], - "next_period_continuous_state": model_funcs_sim["next_period_continuous_state"], - } - simulate_body = partial( simulate_single_period, params=params, - discrete_states_names=model_structure_solution["discrete_states_names"], endog_grid_solved=endog_grid_solved, value_solved=value_solved, policy_solved=policy_solved, - map_state_choice_to_index=jnp.asarray( - model_structure_solution["map_state_choice_to_index_with_proxy"] - ), - choice_range=model_structure_solution["choice_range"], - compute_exog_transition_vec=model_funcs_sim["compute_exog_transition_vec"], - compute_utility=model_funcs_sim["compute_utility"], - compute_beginning_of_period_wealth=model_funcs_sim[ - "compute_beginning_of_period_wealth" - ], - exog_state_mapping=model_funcs_sim["exog_state_mapping"], - compute_next_period_states=compute_next_period_states, - shock_functions=model_funcs_sim["shock_functions"], + model_structure_sol=model_structure_sol, + model_funcs_sim=model_funcs_sim, second_continuous_state_dict=second_continuous_state_dict, ) @@ -120,9 +104,9 @@ def simulate_all_periods( states_and_wealth_beginning_of_final_period, sim_specific_keys=sim_specific_keys[-1], params=params, - discrete_states_names=model_structure_solution["discrete_states_names"], - choice_range=model_structure_solution["choice_range"], - map_state_choice_to_index=model_structure_solution[ + discrete_states_names=model_structure_sol["discrete_states_names"], + choice_range=model_structure_sol["choice_range"], + map_state_choice_to_index=model_structure_sol[ "map_state_choice_to_index_with_proxy" ], compute_utility_final_period=model_funcs_sim["compute_utility_final"], @@ -133,7 +117,7 @@ def simulate_all_periods( key: np.vstack([sim_dict[key], final_period_dict[key]]) for key in sim_dict.keys() } - if "dummy_exog" in model_structure_solution["exog_states_names"]: + if "dummy_exog" in model_structure_sol["exog_states_names"]: if "dummy_exog" not in model_sim["model_structure"]["exog_states_names"]: result.pop("dummy_exog") @@ -144,18 +128,11 @@ def simulate_single_period( states_and_wealth_beginning_of_period, sim_specific_keys, params, - discrete_states_names, endog_grid_solved, value_solved, policy_solved, - map_state_choice_to_index, - choice_range, - compute_exog_transition_vec, - compute_utility, - compute_beginning_of_period_wealth, - exog_state_mapping, - compute_next_period_states, - shock_functions, + model_structure_sol, + model_funcs_sim, second_continuous_state_dict=None, ): ( @@ -180,6 +157,7 @@ def simulate_single_period( continuous_state_beginning_of_period = None continuous_grid = None + choice_range = model_structure_sol["choice_range"] # Interpolate policy and value function for all agents. policy, values_pre_taste_shock = interpolate_policy_and_value_for_all_agents( discrete_states_beginning_of_period=discrete_states_beginning_of_period, @@ -188,11 +166,13 @@ def simulate_single_period( value_solved=value_solved, policy_solved=policy_solved, endog_grid_solved=endog_grid_solved, - map_state_choice_to_index=map_state_choice_to_index, - choice_range=choice_range, + map_state_choice_to_index=jnp.asarray( + model_structure_sol["map_state_choice_to_index_with_proxy"] + ), + choice_range=model_structure_sol["choice_range"], params=params, - discrete_states_names=discrete_states_names, - compute_utility=compute_utility, + discrete_states_names=model_structure_sol["discrete_states_names"], + compute_utility=model_funcs_sim["compute_utility"], continuous_grid=continuous_grid, ) @@ -202,7 +182,7 @@ def simulate_single_period( n_choices=len(choice_range), states=states_beginning_of_period, params=params, - shock_functions=shock_functions, + shock_functions=model_funcs_sim["shock_functions"], key=sim_specific_keys[0, :], ) values_across_choices = values_pre_taste_shock + taste_shocks @@ -222,7 +202,7 @@ def simulate_single_period( states_beginning_of_period, choice, params, - compute_utility, + model_funcs_sim["compute_utility"], ) savings_current_period = wealth_beginning_of_period - consumption @@ -237,10 +217,7 @@ def simulate_single_period( savings_current_period=savings_current_period, choice=choice, params=params, - compute_exog_transition_vec=compute_exog_transition_vec, - exog_state_mapping=exog_state_mapping, - compute_beginning_of_period_wealth=compute_beginning_of_period_wealth, - compute_next_period_states=compute_next_period_states, + model_funcs_sim=model_funcs_sim, sim_specific_keys=sim_specific_keys, ) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 32bb8122..3ef890a8 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -108,22 +108,11 @@ def test_simulate_lax_scan(model_setup): simulate_body = partial( simulate_single_period, params=params, - discrete_states_names=discrete_states_names, endog_grid_solved=endog_grid, value_solved=value, policy_solved=policy, - map_state_choice_to_index=jnp.array(map_state_choice_to_index), - choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=np.uint8), - compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"], - compute_utility=model_funcs["compute_utility"], - compute_beginning_of_period_wealth=model_funcs[ - "compute_beginning_of_period_wealth" - ], - exog_state_mapping=exog_state_mapping, - shock_functions=model_funcs["shock_functions"], - compute_next_period_states={ - "next_period_endogenous_state": next_period_endogenous_state - }, + model_funcs_sim=model_funcs, + model_structure_sol=model_structure, ) # a) lax.scan