Skip to content

Commit 4301f2d

Browse files
Split simulations into chunks (#1938)
* Add progress
* Add progress
* Progress
* Add progress in adding parallelisation!
* Add final fixes
* Add two workers
* Add error handling
* Fix bug
* Add error handling strength
* Fix bug
* Turn down memory to 32gb
* Fix bug
* Fix bugs
* Download microdata first
1 parent 4546ddb commit 4301f2d

File tree

11 files changed

+193
-18
lines changed

11 files changed

+193
-18
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ debug:
77
test:
88
MAX_HOUSEHOLDS=1000 pytest tests
99

10+
microdata:
11+
python policyengine_api/download_microdata.py
12+
1013
debug-test:
1114
MAX_HOUSEHOLDS=1000 FLASK_DEBUG=1 pytest -vv --durations=0 tests
1215

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Chunking and baseline/reform parallelisation.

gcp/policyengine_api/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ ADD . /app
1414
# Make start.sh executable
1515
RUN chmod +x /app/start.sh
1616

17-
RUN cd /app && make install && make test
17+
RUN cd /app && make install && make microdata && make test
1818

1919
# Use full path to start.sh
2020
CMD ["/bin/sh", "/app/start.sh"]

gcp/policyengine_api/app.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
runtime: custom
22
env: flex
33
resources:
4-
cpu: 24
5-
memory_gb: 128
6-
disk_size_gb: 128
4+
cpu: 16
5+
memory_gb: 32
6+
disk_size_gb: 64
77
automatic_scaling:
88
min_num_instances: 1
99
max_num_instances: 1

gcp/policyengine_api/start.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@ gunicorn -b :$PORT policyengine_api.api --timeout 300 --workers 5 &
33
# Start the redis server
44
redis-server &
55
# Start the worker
6+
python3 policyengine_api/worker.py &
7+
python3 policyengine_api/worker.py &
8+
python3 policyengine_api/worker.py &
9+
python3 policyengine_api/worker.py &
610
python3 policyengine_api/worker.py

policyengine_api/data/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,6 @@ def initialize(self):
175175
if os.environ.get("FLASK_DEBUG") == "1":
176176
database = PolicyEngineDatabase(local=True, initialize=False)
177177
else:
178-
database = PolicyEngineDatabase(local=False, initialize=False)
178+
database = PolicyEngineDatabase(local=True, initialize=False)
179179

180180
local_database = PolicyEngineDatabase(local=True, initialize=False)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from policyengine_us_data import EnhancedCPS_2024, CPS_2024
2+
from policyengine_uk_data import EnhancedFRS_2022_23
3+
4+
DATASETS = [EnhancedCPS_2024, CPS_2024, EnhancedFRS_2022_23]
5+
6+
7+
def download_microdata(datasets=None):
    """Download every microdata dataset that is not yet available.

    Each entry is called with no arguments to build a dataset instance; if
    that instance's ``exists`` attribute is falsy, its ``download()`` method
    is invoked.

    Args:
        datasets: Optional iterable of dataset classes (or zero-argument
            factories). When omitted, the module-level ``DATASETS`` list is
            used, which matches the previous behaviour.
    """
    dataset_classes = DATASETS if datasets is None else datasets
    for dataset_class in dataset_classes:
        # Keep the class and the instance under separate names instead of
        # rebinding the loop variable.
        dataset = dataset_class()
        if not dataset.exists:
            dataset.download()
12+
13+
14+
if __name__ == "__main__":
15+
download_microdata()
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import time
2+
from tqdm import tqdm
3+
import numpy as np
4+
5+
6+
def calc_chunks(variables=None, count_chunks=5, logger=None, sim=None):
    """Pre-compute variables on ``sim`` by simulating households in chunks.

    The households of ``sim`` are split into ``count_chunks`` slices. For
    each slice a new simulation of the same type is built from the matching
    rows of ``sim.to_input_dataframe()``, the requested variables are
    calculated on it, and the per-chunk results are concatenated and written
    back into ``sim`` with ``set_input``.

    Args:
        variables: Iterable of variable names or ``(name, period)`` tuples.
            Bare names use ``sim.default_calculation_period``. Names missing
            from ``sim.tax_benefit_system.variables`` are ignored. ``None``
            is treated as an empty list. The caller's list is not modified.
        count_chunks: Number of household slices. With a value of 1 or less
            nothing is pre-computed.
        logger: Optional callable receiving the completed fraction
            (``i / count_chunks``) before each chunk is processed.
        sim: The simulation to fill in; it is also the return value.

    Returns:
        The same ``sim`` object that was passed in.
    """
    default_period = sim.default_calculation_period
    known_variables = sim.tax_benefit_system.variables

    # Normalise into a fresh list so the caller's list is never mutated.
    requested = []
    for entry in variables or []:
        if isinstance(entry, str):
            entry = (entry, default_period)
        variable, time_period = entry
        if variable in known_variables:
            requested.append((variable, time_period))
    # Drop repeated (variable, period) pairs so no result is appended twice.
    requested = list(dict.fromkeys(requested))

    if count_chunks <= 1 or not requested:
        return sim

    # NOTE(review): the household id period (2024) and the matching
    # "household_id__2024" column are hard-coded; presumably this must line
    # up with the dataset's own period - confirm before changing.
    households = sim.calculate("household_id", 2024).values
    # Ceiling division: never more than ``count_chunks`` non-empty slices.
    chunk_size = -(-len(households) // count_chunks)
    input_df = sim.to_input_dataframe()

    # Results are keyed by (variable, period) so the same variable requested
    # for two periods is kept apart. Each list is seeded with an empty float
    # array so the final concatenation promotes dtypes exactly as the
    # previous incremental concatenation did.
    chunk_results = {key: [np.array([])] for key in requested}

    for i in tqdm(range(count_chunks)):
        if logger is not None:
            logger(i / count_chunks)

        households_in_chunk = households[i * chunk_size : (i + 1) * chunk_size]
        if len(households_in_chunk) == 0:
            # Trailing slices can be empty when there are few households;
            # there is nothing to simulate for them.
            continue

        chunk_df = input_df[
            input_df["household_id__2024"].isin(households_in_chunk)
        ]
        subset_sim = type(sim)(dataset=chunk_df, reform=sim.reform)
        subset_sim.default_calculation_period = default_period

        for variable, time_period in requested:
            chunk_results[(variable, time_period)].append(
                subset_sim.calculate(variable, time_period).values
            )

    # Concatenate once per variable rather than re-copying on every chunk.
    # NOTE(review): this assumes the rows of the input dataframe are ordered
    # consistently with ``households`` - confirm against the dataset.
    for (variable, time_period), parts in chunk_results.items():
        sim.set_input(variable, time_period, np.concatenate(parts))

    return sim

policyengine_api/endpoints/economy/reform_impact.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
from policyengine_api.utils import hash_object
1414
from datetime import datetime
1515
import traceback
16+
from rq import Queue
17+
from rq.job import Job
18+
from redis import Redis
19+
import time
20+
21+
queue = Queue(connection=Redis())
1622

1723

1824
def ensure_economy_computed(
@@ -166,24 +172,43 @@ def set_reform_impact_data_routine(
166172
),
167173
)
168174
comment = lambda x: set_comment_on_job(x, *identifiers)
169-
comment("Computing baseline")
170-
baseline_economy = compute_economy(
171-
country_id,
172-
policy_id,
175+
176+
baseline_economy = queue.enqueue(
177+
compute_economy,
178+
country_id=country_id,
179+
policy_id=baseline_policy_id,
173180
region=region,
174181
time_period=time_period,
175182
options=options,
176183
policy_json=baseline_policy,
177184
)
178-
comment("Computing reform")
179-
reform_economy = compute_economy(
180-
country_id,
181-
policy_id,
185+
reform_economy = queue.enqueue(
186+
compute_economy,
187+
country_id=country_id,
188+
policy_id=policy_id,
182189
region=region,
183190
time_period=time_period,
184191
options=options,
185192
policy_json=reform_policy,
186193
)
194+
while baseline_economy.get_status() in ("queued", "started"):
195+
time.sleep(1)
196+
while reform_economy.get_status() in ("queued", "started"):
197+
time.sleep(1)
198+
if reform_economy.get_status() != "finished":
199+
reform_economy = {
200+
"status": "error",
201+
"message": "Error computing reform economy.",
202+
}
203+
else:
204+
reform_economy = reform_economy.result
205+
if baseline_economy.get_status() != "finished":
206+
baseline_economy = {
207+
"status": "error",
208+
"message": "Error computing baseline economy.",
209+
}
210+
else:
211+
baseline_economy = baseline_economy.result
187212
if baseline_economy["status"] != "ok" or reform_economy["status"] != "ok":
188213
local_database.query(
189214
"UPDATE reform_impact SET status = ?, message = ?, end_time = ?, reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND time_period = ? AND options_hash = ?",

policyengine_api/endpoints/economy/single_economy.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,46 @@
77
from policyengine_uk import Microsimulation
88
import time
99
import os
10+
from policyengine_api.endpoints.economy.chunks import calc_chunks
1011

1112

1213
def compute_general_economy(
13-
simulation: Microsimulation, country_id: str = None
14+
simulation: Microsimulation,
15+
country_id: str = None,
16+
simulation_type: str = None,
17+
comment=None,
1418
) -> dict:
19+
variables = [
20+
"labor_supply_behavioral_response",
21+
"employment_income_behavioral_response",
22+
"household_tax",
23+
"household_benefits",
24+
"household_state_income_tax",
25+
"weekly_hours_worked_behavioural_response_income_elasticity",
26+
"weekly_hours_worked_behavioural_response_substitution_elasticity",
27+
"household_net_income",
28+
"household_market_income",
29+
"in_poverty",
30+
"in_deep_poverty",
31+
"poverty_gap",
32+
"deep_poverty_gap",
33+
"income_tax",
34+
"national_insurance",
35+
"vat",
36+
"council_tax",
37+
"fuel_duty",
38+
"tax_credits",
39+
"universal_credit",
40+
"child_benefit",
41+
"state_pension",
42+
"pension_credit",
43+
"ni_employer",
44+
]
45+
# chunk_logger = lambda pct_complete: comment(f"Simulation {simulation_type}: {pct_complete:.0%}")
46+
calc_chunks(
47+
sim=simulation, variables=variables, count_chunks=4, logger=None
48+
)
49+
1550
total_tax = simulation.calculate("household_tax").sum()
1651
total_spending = simulation.calculate("household_benefits").sum()
1752

@@ -226,7 +261,7 @@ def compute_cliff_impact(
226261
}
227262

228263

229-
def compute_economy(
264+
def get_microsimulation(
230265
country_id: str,
231266
policy_id: str,
232267
region: str,
@@ -311,10 +346,41 @@ def compute_economy(
311346
"person_weight"
312347
).get_known_periods():
313348
simulation.delete_arrays("person_weight", time_period)
349+
print(f"Initialised simulation in {time.time() - start} seconds")
350+
351+
return simulation
314352

353+
354+
def compute_economy(
355+
country_id: str,
356+
policy_id: str,
357+
region: str,
358+
time_period: str,
359+
options: dict,
360+
policy_json: dict,
361+
simulation_type: str = None,
362+
comment=None,
363+
):
364+
simulation = get_microsimulation(
365+
country_id,
366+
policy_id,
367+
region,
368+
time_period,
369+
options,
370+
policy_json,
371+
)
315372
if options.get("target") == "cliff":
316373
return compute_cliff_impact(simulation)
317-
print(f"Initialised simulation in {time.time() - start} seconds")
318-
economy = compute_general_economy(simulation, country_id=country_id)
374+
start = time.time()
375+
try:
376+
economy = compute_general_economy(
377+
simulation,
378+
country_id=country_id,
379+
simulation_type=simulation_type,
380+
comment=comment,
381+
)
382+
except Exception as e:
383+
print(f"Error in economy computation: {e}")
384+
return {"status": "error", "message": str(e)}
319385
print(f"Computed economy in {time.time() - start} seconds")
320386
return {"status": "ok", "result": economy}

0 commit comments

Comments
 (0)