Skip to content

Commit

Permalink
Merge pull request #34 from masterismail/charts
Browse files Browse the repository at this point in the history
added charting for inequality metrics
  • Loading branch information
nikhilwoodruff authored Sep 25, 2024
2 parents 22c36fc + 27f961d commit 17c1cf3
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
Empty file added policyengine/charts/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions policyengine/charts/inequality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import plotly.graph_objects as go
from policyengine_core.charts.formatting import *

class InequalityImpactChart:
def __init__(self, data=None) -> None:
if data is None:
raise ValueError("Data must be provided")

# Expecting data to contain baseline, reform, change, and change_percentage for each metric
self.data = data

def generate_chart_data(self):
# Data for the x-axis labels
metrics = ["Gini index", "Top 1% share", "Top 10% share"]

# Extract the change percentages, baseline, and reform values for hover text
change_percentages = [self.data[metric]['change_percentage'] for metric in metrics]
baseline_values = [self.data[metric]['baseline'] for metric in metrics]
reform_values = [self.data[metric]['reform'] for metric in metrics]

# Generate hover text for each metric
hover_texts = [
f"The reform would increase the {metric} by {change_percentages[i]}% from {baseline_values[i]} to {reform_values[i]}%"
if change_percentages[i] > 0
else f"The reform would decrease the {metric} by {change_percentages[i]}% from {baseline_values[i]} to {reform_values[i]}%"
for i, metric in enumerate(metrics)
]

# Create the bar chart figure
fig = go.Figure()

# Add a bar trace for the change percentages of each metric
fig.add_trace(go.Bar(
x=metrics, # Labels for each metric
y=change_percentages, # Change percentages for each metric
marker=dict(
color=[BLUE if change_percentages[i] > 0 else GRAY for i in range(len(change_percentages))], # Conditional color for each bar
line=dict(width=1),
),
text=[f"{percent}%" for percent in change_percentages], # Display percentage as text
textposition='outside', # Position text outside the bars
hovertemplate=f"<b>%{{x}}</b><br><br>%{{customdata}}<extra></extra>",
customdata=hover_texts # Hover text for each bar
))

# Update layout for the chart
fig.update_layout(
yaxis=dict(
tickformat=".1f", # Show y-values with one decimal place
ticksuffix="%",
title="Relative change" # Add percentage symbol
),
hoverlabel=dict(
bgcolor="white", # Background color of the hover label
font=dict(
color="black", # Text color of the hover label
size=16, # Font size
),
),
title="Impact of Reform on Inequality Metrics" # Add a title to the chart
)

format_fig(fig)

return fig
60 changes: 58 additions & 2 deletions policyengine/economic_impact/economic_impact.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@
from .winners_and_losers.by_wealth_decile.by_wealth_decile import ByWealthDecile


from typing import Dict
from typing import Dict, Type, Union

from policyengine.charts.inequality import InequalityImpactChart


class EconomicImpact:
"""
Expand Down Expand Up @@ -170,6 +173,21 @@ def __init__(self, reform: dict, country: str, dataset: str = None) -> None:

}


self.chart_generators: Dict[str, Type] = {
"inequality": InequalityImpactChart,
}

self.composite_metrics: Dict[str, Dict[str, str]] = {
"inequality": {
"Gini index": "inequality/gini",
"Top 1% share": "inequality/top_1_pct_share",
"Top 10% share": "inequality/top_10_pct_share",
}
}

self.metric_results: Dict[str, any] = {}

def _get_simulation_class(self) -> type:
"""
Get the appropriate Microsimulation class based on the country code.
Expand Down Expand Up @@ -203,4 +221,42 @@ def calculate(self, metric: str) -> dict:
"""
if metric not in self.metric_calculators:
raise ValueError(f"Unknown metric: {metric}")
return self.metric_calculators[metric].calculate()

if metric not in self.metric_results:
result = self.metric_calculators[metric].calculate()
self.metric_results[metric] = result

return self.metric_results[metric]

def _calculate_composite_metric(self, metric: str) -> dict:
if metric not in self.composite_metrics:
raise ValueError(f"Unknown composite metric: {metric}")

composite_data = {}
for key, sub_metric in self.composite_metrics[metric].items():
composite_data[key] = self.calculate(sub_metric)

return composite_data

def chart(self, metric: str) -> dict:
if metric in self.composite_metrics:
data = self._calculate_composite_metric(metric)
elif metric in self.chart_generators:
data = self.calculate(metric)
else:
raise ValueError(f"Unknown metric for charting: {metric}")

chart_generator = self.chart_generators.get(metric)
if not chart_generator:
raise ValueError(f"No chart generator found for metric: {metric}")

return chart_generator(data=data).generate_chart_data()

def add_metric(self, metric: str, calculator: object, chart_generator: Type = None):
self.metric_calculators[metric] = calculator
if chart_generator:
self.chart_generators[metric] = chart_generator

def add_composite_metric(self, name: str, components: Dict[str, str], chart_generator: Type):
self.composite_metrics[name] = components
self.chart_generators[name] = chart_generator

0 comments on commit 17c1cf3

Please sign in to comment.