From 560b8fa14a77866ea3e94bec5e4e599b22478fab Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Fri, 20 Dec 2024 17:09:18 +0000 Subject: [PATCH] Fix Parliamentary constituencies chart returns errors #67 --- .../parliamentary_constituencies.py | 2 +- policyengine/simulation.py | 22 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/policyengine/outputs/macro/single/gov/local_areas/parliamentary_constituencies.py b/policyengine/outputs/macro/single/gov/local_areas/parliamentary_constituencies.py index 44e2ff8..9761c9c 100644 --- a/policyengine/outputs/macro/single/gov/local_areas/parliamentary_constituencies.py +++ b/policyengine/outputs/macro/single/gov/local_areas/parliamentary_constituencies.py @@ -81,7 +81,7 @@ def parliamentary_constituencies( ) if chart: - return plot_hex_map(result, "local_authorities") + return plot_hex_map(result, "parliamentary_constituencies") if code_index: return { diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 7cfb035..bf6f7b7 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -16,14 +16,15 @@ import pandas as pd import h5py from pathlib import Path +from typing import Literal class Simulation: """The top-level class through which all PE usage is carried out.""" - country: str + country: Literal["uk", "us"] """The country for which the simulation is being run.""" - scope: str + scope: Literal["macro", "household"] """The type of simulation being run (macro or household).""" data: dict | str | Dataset """The dataset being used for the simulation.""" @@ -49,10 +50,12 @@ class Simulation: def __init__( self, - country: str, - scope: str, + country: Literal["uk", "us"], + scope: Literal["macro", "household"], data: str | dict | None = None, - time_period: str | None = None, + time_period: str | None = Literal[ + 2024, 2025, 2026, 2027, 2028, 2029, 2030 + ], reform: dict | None = None, baseline: dict | None = None, verbose: bool = False, @@ -141,7 +144,7 @@ def calculate(self, output: str, force: bool = False, **kwargs) -> Any: if child_key not in parent: try: is_numeric_key = int(child_key) in parent - except KeyError: + except ValueError: is_numeric_key = False if is_numeric_key: child_key = int(child_key) @@ -162,9 +165,12 @@ def calculate(self, output: str, force: bool = False, **kwargs) -> Any: force or parent[child_key] is None or len(kwargs) > 0 ) and output in self.output_functions: output_function = self.output_functions[output] - parent[child_key] = node = output_function(self, **kwargs) + node = output_function(self, **kwargs) + if len(kwargs) == 0: + # Only save as part of the larger tree if no non-standard args are passed + parent[child_key] = node - if isinstance(node, dict): + if isinstance(node, dict) and len(kwargs) == 0: for child_key in node.keys(): self.calculate(output + "/" + str(child_key))