Improve comparator #227

Merged
merged 7 commits on Jul 19, 2023
Changes from 6 commits
151 changes: 96 additions & 55 deletions openfisca_france_data/comparator.py
@@ -304,7 +304,24 @@ class AbstractComparator(object):
filter_expr_by_label = None
period = None
messages = list()
survey_name = None
survey_scenario = None

def __init__(self):
name = self.get_name()
assert name is not None and isinstance(name, str)

figures_directory = Path(config.get("paths", "figures_directory")) / name

if not figures_directory.exists():
figures_directory.mkdir(parents = True, exist_ok = True)

self.figures_directory = figures_directory

def compute_aggregates_comparison(self, input_dataframe_by_entity = None):
pass

def compute_distibution_comparison(self, input_dataframe_by_entity = None):
pass

def get_name(self):
return self.name + "_" + str(self.period)
@@ -339,19 +356,10 @@ def get_test_dataframes(self, rebuild = False, noindivs = None):
}
return input_dataframe_by_entity, target_dataframe_by_entity

def compare(self, browse, load, verbose, debug, target_variables = None, period = None, rebuild = False, summary = False):
def compare(self, browse, load, verbose, debug, target_variables = None, period = None, rebuild = False, summary = False, compute_divergence = False):
"""Compare actual data with openfisca-france-data computation."""
log.setLevel(level = logging.DEBUG if verbose else logging.WARNING)

name = self.get_name()

assert name is not None and isinstance(name, str)

figures_directory = Path(config.get("paths", "figures_directory")) / name

if not figures_directory.exists():
figures_directory.mkdir(parents = True, exist_ok = True)

if target_variables is not None and isinstance(target_variables, str):
target_variables = [target_variables]

@@ -360,12 +368,14 @@ def compare(self, browse, load, verbose, debug, target_variables = None, period
if target_variables is None:
target_variables = self.default_target_variables

self.target_variables = target_variables

if period is not None:
period = int(period)

backup_directory = PurePath.joinpath(Path(config.get("paths", "backup")))
backup_directory.mkdir(parents = True, exist_ok = True)

name = self.name
backup_path = PurePath.joinpath(backup_directory, f"{name}_backup.h5")

if load:
@@ -386,24 +396,36 @@

log.debug(f"Test data has been prepared in {datetime.datetime.now() - start_time}")

# specific_figures_directory = PurePath.joinpath(figures_directory, self.name)
specific_figures_directory = figures_directory
specific_figures_directory.mkdir(parents = True, exist_ok = True)

result_by_variable = self.compute_divergence(
# input_dataframe_by_entity,
None, # To force load the data_table from hdf file
target_dataframe_by_entity,
specific_figures_directory,
target_variables = target_variables,
period = period,
summary = summary,
self.compute_aggregates_comparison(
input_dataframe_by_entity = input_dataframe_by_entity,
)

result = pd.concat(result_by_variable, ignore_index = True)
self.compute_distibution_comparison(input_dataframe_by_entity = input_dataframe_by_entity)

if compute_divergence:
result_by_variable, markdown_section_by_variable, markdown_summary_section_by_variable = self.compute_divergence(
input_dataframe_by_entity = None, # To force load the data_table from hdf file
target_dataframe_by_entity = target_dataframe_by_entity,
target_variables = target_variables,
period = period,
summary = summary,
)

# Deal with markdown_section

self.create_report(markdown_section_by_variable, markdown_summary_section_by_variable)

if result_by_variable is None:
return

result = pd.concat(result_by_variable, ignore_index = True)
else:
            self.create_report(None, None)

log.debug(f"Eveyrthing has been computed in {datetime.datetime.now() - start_time}")
del input_dataframe_by_entity, target_dataframe_by_entity


if browse:
start_browsing_time = datetime.datetime.now()
result = result.dropna(axis = 1, how = 'all')
@@ -434,7 +456,8 @@ def compare(self, browse, load, verbose, debug, target_variables = None, period
pdb.post_mortem(sys.exc_info()[2])
raise error

def compute_divergence(self, input_dataframe_by_entity, target_dataframe_by_entity, figures_directory, target_variables = None, period = None, summary = False):
def compute_divergence(self, input_dataframe_by_entity, target_dataframe_by_entity,
target_variables = None, period = None, summary = False):
"""
Compare openfisca-france-data computation with data targets.

@@ -443,12 +466,12 @@ def compute_divergence(self, input_dataframe_by_enti
        target_dataframe_by_entity (dict): Targets to match
figures_directory (path): Where to store the figures
"""
figures_directory = figures_directory.resolve()
figures_directory = self.figures_directory.resolve()
assert Path.exists(figures_directory)

if target_variables is None:
log.info(f"No target variables. Exiting divergence computation.")
return
return None, None, None

data = (
dict(input_dataframe_by_entity = input_dataframe_by_entity)
@@ -505,7 +528,6 @@ def compute_divergence(self, input_dataframe_by_enti
markdown_section_by_variable[variable] = variable_markdown_section

if summary:
# create_stats_by_period_figure(variable, result, period, figures_directory = figures_directory)
variable_markdown_summary_section = create_variable_markdown_summary_section(
variable,
stats,
@@ -514,39 +536,55 @@ def compute_divergence(self, input_dataframe_by_enti
if variable_markdown_summary_section is not None:
markdown_summary_section_by_variable[variable] = variable_markdown_summary_section

messages_markdown_section = """
return result_by_variable, markdown_section_by_variable, markdown_summary_section_by_variable

def compute_test_dataframes(self):
raise NotImplementedError

def create_report(self, markdown_section_by_variable, markdown_summary_section_by_variable):
figures_directory = self.figures_directory

if self.messages:
messages_markdown_section = """
Filtres appliqués:

""" + "\n".join(f"- {message}" for message in self.messages) + """
"""
with open(figures_directory / "filters.md", "w", encoding = 'utf-8') as filters_md_file:
filters_md_file.write(messages_markdown_section)

markdown_sections = list(filter(
lambda x: x is not None,
[messages_markdown_section] + list(markdown_section_by_variable.values()),
))
create_output_files(
markdown_sections,
figures_directory,
"variables",
)
if summary:
markdown_sections = list(filter(
lambda x: x is not None,
[messages_markdown_section] + list(markdown_summary_section_by_variable.values()),
))
else:
messages_markdown_section = ""

table_agregats_markdown = None
if PurePath.joinpath(figures_directory, "table_agregats.md").exists():
with open(figures_directory / "table_agregats.md", "r", encoding = 'utf-8') as table_agregats_md_file:
table_agregats_markdown = table_agregats_md_file.read()

distribution_comparison_markdown = None
if PurePath.joinpath(figures_directory, "distribution_comparison_md").exists():
with open(figures_directory / "distribution_comparison_md", "r", encoding = 'utf-8') as distribution_comparison_md_file:
distribution_comparison_markdown = distribution_comparison_md_file.read()

front_sections = [messages_markdown_section, table_agregats_markdown, distribution_comparison_markdown]
sections_by_filename = {
"variables": markdown_section_by_variable,
"summary_variables": markdown_summary_section_by_variable
}

for filename, section_by_variable in sections_by_filename.items():
if section_by_variable is not None:
markdown_sections = list(filter(
lambda x: x is not None,
front_sections + list(section_by_variable.values()),
))
else:
markdown_sections = list(filter(
                lambda x: x is not None, front_sections
))
create_output_files(
markdown_sections,
figures_directory,
"summary_variables",
filename,
)

return result_by_variable

def compute_test_dataframes(self):
raise NotImplementedError

def filter(self, data_frame):
for label, filter_expr in self.filter_expr_by_label.items():
obs_before = data_frame.noind.nunique()
@@ -561,8 +599,11 @@ def filter(self, data_frame):
log.info(log_message)
self.messages.append(log_message + "\n")

def get_survey_scenario(self, data = None):
survey_name = self.survey_name
def get_survey_scenario(self, data = None, survey_name = None):

if self.survey_scenario is not None:
return self.survey_scenario

return get_survey_scenario(
year = str(self.period),
data = data,
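
The refactored `create_report` assembles the final markdown from optional front sections (`filters.md`, `table_agregats.md`, `distribution_comparison_md`) plus the per-variable sections, dropping whatever is missing. Below is a minimal standalone sketch of that assembly pattern; it is not part of the diff, the file names mirror the code above, and the demo directory and sections are invented for illustration.

```python
# Standalone sketch of the report-assembly pattern: read optional front sections
# from the figures directory, append per-variable sections, filter out None.
from pathlib import Path
from typing import Dict, List, Optional


def assemble_report_sections(figures_directory: Path, section_by_variable: Optional[Dict[str, str]]) -> List[str]:
    front_sections: List[Optional[str]] = []
    for fragment in ("filters.md", "table_agregats.md", "distribution_comparison_md"):
        fragment_path = figures_directory / fragment
        # Missing fragments contribute None and are filtered out below.
        front_sections.append(
            fragment_path.read_text(encoding = "utf-8") if fragment_path.exists() else None
        )
    variable_sections = list(section_by_variable.values()) if section_by_variable else []
    return [section for section in front_sections + variable_sections if section is not None]


if __name__ == "__main__":
    figures_directory = Path("/tmp/demo_comparator")  # hypothetical figures_directory
    figures_directory.mkdir(parents = True, exist_ok = True)
    (figures_directory / "filters.md").write_text("Filtres appliqués:\n\n- demo filter\n", encoding = "utf-8")
    sections = assemble_report_sections(figures_directory, {"salaire_net": "## salaire_net\n(demo section)"})
    print("\n\n".join(sections))
```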