Skip to content

Commit

Permalink
fix: tests in pandas version > 2
Browse files Browse the repository at this point in the history
  • Loading branch information
f-allian committed Jul 26, 2024
1 parent cff1776 commit bffe827
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 8 additions & 3 deletions causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.estimators import CubicSplineRegressionEstimator

import pandas as pd

@dataclass
class SimulationResult:
Expand All @@ -18,6 +18,11 @@ class SimulationResult:
fault: bool
relationship: str

def to_dataframe(self) -> pd.DataFrame:
"""Convert the simulation result data to a pandas DataFrame"""
data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()}
return pd.DataFrame(data_as_lists)


class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
Expand Down Expand Up @@ -87,14 +92,14 @@ def execute(

self.simulator.startup()
test_result = self.simulator.run_with_config(candidate_test_case)
test_result_df = test_result.to_dataframe()
self.simulator.shutdown()

if custom_data_aggregator is not None:
if data_collector.data is not None:
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
else:
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)

data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True)
if test_result.fault:
print(
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
Expand Down
5 changes: 4 additions & 1 deletion tests/surrogate_tests/test_causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,7 @@ def shutdown(self):
pass

def data_double_aggregator(data, new_data):
return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True)
"""Previously used data.append(new_data), however, pandas version >2 requires pd.concat() since append is now a private method.
Converting new_data to a pd.DataFrame is required to use pd.concat(). """
new_data = pd.DataFrame([new_data])
return pd.concat([data, new_data, new_data], ignore_index=True)

0 comments on commit bffe827

Please sign in to comment.