From bffe8276804592d2c85b24bc1b7710dfe91283b8 Mon Sep 17 00:00:00 2001 From: Farhad Allian Date: Fri, 26 Jul 2024 13:48:49 +0100 Subject: [PATCH] fix: tests in pandas version > 2 --- causal_testing/surrogate/causal_surrogate_assisted.py | 11 ++++++++--- .../surrogate_tests/test_causal_surrogate_assisted.py | 5 ++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/causal_testing/surrogate/causal_surrogate_assisted.py b/causal_testing/surrogate/causal_surrogate_assisted.py index 74f309be..8c0f157f 100644 --- a/causal_testing/surrogate/causal_surrogate_assisted.py +++ b/causal_testing/surrogate/causal_surrogate_assisted.py @@ -8,7 +8,7 @@ from causal_testing.specification.causal_specification import CausalSpecification from causal_testing.testing.base_test_case import BaseTestCase from causal_testing.testing.estimators import CubicSplineRegressionEstimator - +import pandas as pd @dataclass class SimulationResult: @@ -18,6 +18,11 @@ class SimulationResult: fault: bool relationship: str + def to_dataframe(self) -> pd.DataFrame: + """Convert the simulation result data to a pandas DataFrame""" + data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()} + return pd.DataFrame(data_as_lists) + class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods """Class to be inherited with the search algorithm consisting of a search function and the fitness function of the @@ -87,14 +92,14 @@ def execute( self.simulator.startup() test_result = self.simulator.run_with_config(candidate_test_case) + test_result_df = test_result.to_dataframe() self.simulator.shutdown() if custom_data_aggregator is not None: if data_collector.data is not None: data_collector.data = custom_data_aggregator(data_collector.data, test_result.data) else: - data_collector.data = data_collector.data.append(test_result.data, ignore_index=True) - + data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True) if test_result.fault: print( f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with " diff --git a/tests/surrogate_tests/test_causal_surrogate_assisted.py b/tests/surrogate_tests/test_causal_surrogate_assisted.py index c5eb6e2c..54c93af1 100644 --- a/tests/surrogate_tests/test_causal_surrogate_assisted.py +++ b/tests/surrogate_tests/test_causal_surrogate_assisted.py @@ -231,4 +231,7 @@ def shutdown(self): pass def data_double_aggregator(data, new_data): - return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True) \ No newline at end of file + """Previously used data.append(new_data), however, pandas version >2 requires pd.concat() since append is now a private method. + Converting new_data to a pd.DataFrame is required to use pd.concat(). """ + new_data = pd.DataFrame([new_data]) + return pd.concat([data, new_data, new_data], ignore_index=True)