Skip to content

Commit 17c1cf3

Browse files
Merge pull request #34 from masterismail/charts
added charting for inequality metrics
2 parents 22c36fc + 27f961d commit 17c1cf3

File tree

3 files changed

+123
-2
lines changed

3 files changed

+123
-2
lines changed

policyengine/charts/__init__.py

Whitespace-only changes.

policyengine/charts/inequality.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import plotly.graph_objects as go
2+
from policyengine_core.charts.formatting import *
3+
4+
class InequalityImpactChart:
5+
def __init__(self, data=None) -> None:
6+
if data is None:
7+
raise ValueError("Data must be provided")
8+
9+
# Expecting data to contain baseline, reform, change, and change_percentage for each metric
10+
self.data = data
11+
12+
def generate_chart_data(self):
13+
# Data for the x-axis labels
14+
metrics = ["Gini index", "Top 1% share", "Top 10% share"]
15+
16+
# Extract the change percentages, baseline, and reform values for hover text
17+
change_percentages = [self.data[metric]['change_percentage'] for metric in metrics]
18+
baseline_values = [self.data[metric]['baseline'] for metric in metrics]
19+
reform_values = [self.data[metric]['reform'] for metric in metrics]
20+
21+
# Generate hover text for each metric
22+
hover_texts = [
23+
f"The reform would increase the {metric} by {change_percentages[i]}% from {baseline_values[i]} to {reform_values[i]}%"
24+
if change_percentages[i] > 0
25+
else f"The reform would decrease the {metric} by {change_percentages[i]}% from {baseline_values[i]} to {reform_values[i]}%"
26+
for i, metric in enumerate(metrics)
27+
]
28+
29+
# Create the bar chart figure
30+
fig = go.Figure()
31+
32+
# Add a bar trace for the change percentages of each metric
33+
fig.add_trace(go.Bar(
34+
x=metrics, # Labels for each metric
35+
y=change_percentages, # Change percentages for each metric
36+
marker=dict(
37+
color=[BLUE if change_percentages[i] > 0 else GRAY for i in range(len(change_percentages))], # Conditional color for each bar
38+
line=dict(width=1),
39+
),
40+
text=[f"{percent}%" for percent in change_percentages], # Display percentage as text
41+
textposition='outside', # Position text outside the bars
42+
hovertemplate=f"<b>%{{x}}</b><br><br>%{{customdata}}<extra></extra>",
43+
customdata=hover_texts # Hover text for each bar
44+
))
45+
46+
# Update layout for the chart
47+
fig.update_layout(
48+
yaxis=dict(
49+
tickformat=".1f", # Show y-values with one decimal place
50+
ticksuffix="%",
51+
title="Relative change" # Add percentage symbol
52+
),
53+
hoverlabel=dict(
54+
bgcolor="white", # Background color of the hover label
55+
font=dict(
56+
color="black", # Text color of the hover label
57+
size=16, # Font size
58+
),
59+
),
60+
title="Impact of Reform on Inequality Metrics" # Add a title to the chart
61+
)
62+
63+
format_fig(fig)
64+
65+
return fig

policyengine/economic_impact/economic_impact.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@
7676
from .winners_and_losers.by_wealth_decile.by_wealth_decile import ByWealthDecile
7777

7878

79-
from typing import Dict
79+
from typing import Dict, Type, Union
80+
81+
from policyengine.charts.inequality import InequalityImpactChart
82+
8083

8184
class EconomicImpact:
8285
"""
@@ -170,6 +173,21 @@ def __init__(self, reform: dict, country: str, dataset: str = None) -> None:
170173

171174
}
172175

176+
177+
self.chart_generators: Dict[str, Type] = {
178+
"inequality": InequalityImpactChart,
179+
}
180+
181+
self.composite_metrics: Dict[str, Dict[str, str]] = {
182+
"inequality": {
183+
"Gini index": "inequality/gini",
184+
"Top 1% share": "inequality/top_1_pct_share",
185+
"Top 10% share": "inequality/top_10_pct_share",
186+
}
187+
}
188+
189+
self.metric_results: Dict[str, any] = {}
190+
173191
def _get_simulation_class(self) -> type:
174192
"""
175193
Get the appropriate Microsimulation class based on the country code.
@@ -203,4 +221,42 @@ def calculate(self, metric: str) -> dict:
203221
"""
204222
if metric not in self.metric_calculators:
205223
raise ValueError(f"Unknown metric: {metric}")
206-
return self.metric_calculators[metric].calculate()
224+
225+
if metric not in self.metric_results:
226+
result = self.metric_calculators[metric].calculate()
227+
self.metric_results[metric] = result
228+
229+
return self.metric_results[metric]
230+
231+
def _calculate_composite_metric(self, metric: str) -> dict:
232+
if metric not in self.composite_metrics:
233+
raise ValueError(f"Unknown composite metric: {metric}")
234+
235+
composite_data = {}
236+
for key, sub_metric in self.composite_metrics[metric].items():
237+
composite_data[key] = self.calculate(sub_metric)
238+
239+
return composite_data
240+
241+
def chart(self, metric: str) -> dict:
242+
if metric in self.composite_metrics:
243+
data = self._calculate_composite_metric(metric)
244+
elif metric in self.chart_generators:
245+
data = self.calculate(metric)
246+
else:
247+
raise ValueError(f"Unknown metric for charting: {metric}")
248+
249+
chart_generator = self.chart_generators.get(metric)
250+
if not chart_generator:
251+
raise ValueError(f"No chart generator found for metric: {metric}")
252+
253+
return chart_generator(data=data).generate_chart_data()
254+
255+
def add_metric(self, metric: str, calculator: object, chart_generator: Type = None):
256+
self.metric_calculators[metric] = calculator
257+
if chart_generator:
258+
self.chart_generators[metric] = chart_generator
259+
260+
def add_composite_metric(self, name: str, components: Dict[str, str], chart_generator: Type):
261+
self.composite_metrics[name] = components
262+
self.chart_generators[name] = chart_generator

0 commit comments

Comments
 (0)