76
76
from .winners_and_losers .by_wealth_decile .by_wealth_decile import ByWealthDecile
77
77
78
78
79
- from typing import Dict
79
+ from typing import Dict , Type , Union
80
+
81
+ from policyengine .charts .inequality import InequalityImpactChart
82
+
80
83
81
84
class EconomicImpact :
82
85
"""
@@ -170,6 +173,21 @@ def __init__(self, reform: dict, country: str, dataset: str = None) -> None:
170
173
171
174
}
172
175
176
+
177
+ self .chart_generators : Dict [str , Type ] = {
178
+ "inequality" : InequalityImpactChart ,
179
+ }
180
+
181
+ self .composite_metrics : Dict [str , Dict [str , str ]] = {
182
+ "inequality" : {
183
+ "Gini index" : "inequality/gini" ,
184
+ "Top 1% share" : "inequality/top_1_pct_share" ,
185
+ "Top 10% share" : "inequality/top_10_pct_share" ,
186
+ }
187
+ }
188
+
189
+ self .metric_results : Dict [str , any ] = {}
190
+
173
191
def _get_simulation_class (self ) -> type :
174
192
"""
175
193
Get the appropriate Microsimulation class based on the country code.
@@ -203,4 +221,42 @@ def calculate(self, metric: str) -> dict:
203
221
"""
204
222
if metric not in self .metric_calculators :
205
223
raise ValueError (f"Unknown metric: { metric } " )
206
- return self .metric_calculators [metric ].calculate ()
224
+
225
+ if metric not in self .metric_results :
226
+ result = self .metric_calculators [metric ].calculate ()
227
+ self .metric_results [metric ] = result
228
+
229
+ return self .metric_results [metric ]
230
+
231
+ def _calculate_composite_metric (self , metric : str ) -> dict :
232
+ if metric not in self .composite_metrics :
233
+ raise ValueError (f"Unknown composite metric: { metric } " )
234
+
235
+ composite_data = {}
236
+ for key , sub_metric in self .composite_metrics [metric ].items ():
237
+ composite_data [key ] = self .calculate (sub_metric )
238
+
239
+ return composite_data
240
+
241
+ def chart (self , metric : str ) -> dict :
242
+ if metric in self .composite_metrics :
243
+ data = self ._calculate_composite_metric (metric )
244
+ elif metric in self .chart_generators :
245
+ data = self .calculate (metric )
246
+ else :
247
+ raise ValueError (f"Unknown metric for charting: { metric } " )
248
+
249
+ chart_generator = self .chart_generators .get (metric )
250
+ if not chart_generator :
251
+ raise ValueError (f"No chart generator found for metric: { metric } " )
252
+
253
+ return chart_generator (data = data ).generate_chart_data ()
254
+
255
+ def add_metric (self , metric : str , calculator : object , chart_generator : Type = None ):
256
+ self .metric_calculators [metric ] = calculator
257
+ if chart_generator :
258
+ self .chart_generators [metric ] = chart_generator
259
+
260
+ def add_composite_metric (self , name : str , components : Dict [str , str ], chart_generator : Type ):
261
+ self .composite_metrics [name ] = components
262
+ self .chart_generators [name ] = chart_generator
0 commit comments