diff --git a/README.md b/README.md index 53d41a90..ac648c37 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,8 @@ The library offers the following classes: * Washover for switchback experiments: * `EmptyWashover`: no washover done at all. * `ConstantWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval. + * `TwoSidedWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval, and also before the switch. + * `SimmetricWashover`: same as two sided, but with the same timedelta for both sides. * Regarding analysis: * `GeeExperimentAnalysis`: to run GEE analysis on the results of a clustered design * `TTestClusteredAnalysis`: to run a t-test on aggregated data for clusters diff --git a/cluster_experiments/__init__.py b/cluster_experiments/__init__.py index 379f1e54..5bc416a5 100644 --- a/cluster_experiments/__init__.py +++ b/cluster_experiments/__init__.py @@ -24,7 +24,13 @@ StratifiedSwitchbackSplitter, SwitchbackSplitter, ) -from cluster_experiments.washover import ConstantWashover, EmptyWashover, Washover +from cluster_experiments.washover import ( + ConstantWashover, + EmptyWashover, + SimmetricWashover, + TwoSidedWashover, + Washover, +) __all__ = [ "ExperimentAnalysis", @@ -50,5 +56,7 @@ "PairedTTestClusteredAnalysis", "EmptyWashover", "ConstantWashover", + "TwoSidedWashover", + "SimmetricWashover", "Washover", ] diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index 72928a61..c0e50bcd 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -82,12 +82,17 @@ def washover( return df -class ConstantWashover(Washover): - """Constant washover - we drop all rows in the washover period when - there is a switch where the treatment is different.""" +class TwoSidedWashover(Washover): + """Two sided washover - we drop all rows before and after the switch within + the time deltas when there is a switch where the treatment is different.""" - def __init__(self, washover_time_delta: datetime.timedelta): - self.washover_time_delta = washover_time_delta + def __init__( + self, + washover_time_delta_before: datetime.timedelta, + washover_time_delta_after: datetime.timedelta, + ): + self.washover_time_delta_before = washover_time_delta_before + self.washover_time_delta_after = washover_time_delta_after def washover( self, @@ -97,7 +102,7 @@ def washover( cluster_cols: List[str], original_time_col: Optional[str] = None, ) -> pd.DataFrame: - """No washover - returns the same dataframe as input. + """Two sided washover - removes rows around the switch. Args: df (pd.DataFrame): Input dataframe. @@ -114,7 +119,7 @@ def washover( from cluster_experiments import SwitchbackSplitter from cluster_experiments import ConstantWashover - washover = ConstantWashover(washover_time_delta=datetime.timedelta(minutes=30)) + washover = TwoSidedWashover(washover_time_delta=datetime.timedelta(minutes=30)) n = 10 df = pd.DataFrame( @@ -145,15 +150,34 @@ def washover( if original_time_col else _original_time_column(truncated_time_col) ) + # Cluster columns that do not involve time non_time_cols = list(set(cluster_cols) - set([truncated_time_col])) + # For each cluster, we need to check if treatment has changed wrt last time df_agg = df.drop_duplicates(subset=cluster_cols + [treatment_col]).copy() df_agg["__changed"] = ( df_agg.groupby(non_time_cols)[treatment_col].shift(1) != df_agg[treatment_col] - ) - df_agg = df_agg.loc[:, cluster_cols + ["__changed"]] + ) & df_agg.groupby(non_time_cols)[treatment_col].shift(1).notnull() + + # We also check if treatment changes for the next time + df_agg["__changed_next"] = ( + df_agg.groupby(non_time_cols)[treatment_col].shift(-1) + != df_agg[treatment_col] + ) & df_agg.groupby(non_time_cols)[treatment_col].shift(-1).notnull() + + # Calculate switch start of the next time + df_agg[f"__next_{truncated_time_col}"] = df_agg.groupby(non_time_cols)[ + truncated_time_col + ].shift(-1) + + # Clean switch df + df_agg = df_agg.loc[ + :, + cluster_cols + + ["__changed", "__changed_next", f"__next_{truncated_time_col}"], + ] return ( df.merge(df_agg, on=cluster_cols, how="inner") .assign( @@ -161,24 +185,88 @@ def washover( "datetime64[ns]" ) - x[truncated_time_col].astype("datetime64[ns]"), - __after_washover=lambda x: x["__time_since_switch"] - > self.washover_time_delta, + __time_to_next_switch=lambda x: x[ + f"__next_{truncated_time_col}" + ].astype("datetime64[ns]") + - x[original_time_col].astype("datetime64[ns]"), + __after_washover=lambda x: ( + x["__time_since_switch"] > self.washover_time_delta_after + ), + __before_washover=lambda x: ( + x["__time_to_next_switch"] > self.washover_time_delta_before + ), ) - # add not changed in query + # if no change or too late after change, don't drop .query("__after_washover or not __changed") - .drop(columns=["__time_since_switch", "__after_washover", "__changed"]) + # if no change to next switch or too early before change, don't drop + .query("__before_washover or not __changed_next") + .drop( + columns=[ + "__time_since_switch", + "__time_to_next_switch", + "__after_washover", + "__before_washover", + "__changed", + "__changed_next", + f"__next_{truncated_time_col}", + ] + ) ) @classmethod def from_config(cls, config) -> "Washover": - if not config.washover_time_delta: + if ( + not config.washover_time_delta_before + or not config.washover_time_delta_after + ): raise ValueError( - f"Washover time delta must be specified for ConstantWashover, while it is {config.washover_time_delta = }" + f"Washover time deltas must be specified for , while it is {config.washover_time_delta_before = } and {config.washover_time_delta_after = }" ) return cls( - washover_time_delta=config.washover_time_delta, + washover_time_delta_before=config.washover_time_delta, + washover_time_delta_after=config.washover_time_delta, ) +class ConstantWashover(TwoSidedWashover): + """Constant washover - we drop all rows in the washover period after + the switch when the treatment is different.""" + + def __init__(self, washover_time_delta: datetime.timedelta): + super().__init__(datetime.timedelta(seconds=0), washover_time_delta) + + @classmethod + def from_config(cls, config) -> "Washover": + if not config.washover_time_delta: + raise ValueError( + f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + ) + return cls(washover_time_delta=config.washover_time_delta) + + +class SimmetricWashover(TwoSidedWashover): + """Simmetric washover - we drop all rows in the washover period before + and after the switch when the treatment is different.""" + + def __init__(self, washover_time_delta: datetime.timedelta): + super().__init__( + washover_time_delta_before=washover_time_delta, + washover_time_delta_after=washover_time_delta, + ) + + @classmethod + def from_config(cls, config) -> "Washover": + if not config.washover_time_delta: + raise ValueError( + f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + ) + return cls(washover_time_delta=config.washover_time_delta) + + # This is kept in here because of circular imports, need to rethink this -washover_mapping = {"": EmptyWashover, "constant_washover": ConstantWashover} +washover_mapping = { + "": EmptyWashover, + "constant_washover": ConstantWashover, + "two_sided_washover": TwoSidedWashover, + "simmetric_washover": SimmetricWashover, +} diff --git a/tests/splitter/conftest.py b/tests/splitter/conftest.py index 1cf12fb9..4f9958c1 100644 --- a/tests/splitter/conftest.py +++ b/tests/splitter/conftest.py @@ -161,6 +161,23 @@ def washover_base_df(): return df +@pytest.fixture +def simmetric_washover_base_df(): + df = pd.DataFrame( + { + "original___time": [ + pd.to_datetime("2022-01-01 00:20:00"), + pd.to_datetime("2022-01-01 00:49:00"), + pd.to_datetime("2022-01-01 01:14:00"), + pd.to_datetime("2022-01-01 01:31:00"), + ], + "treatment": ["A", "A", "B", "B"], + "city": ["TGN"] * 4, + } + ).assign(time=lambda x: x["original___time"].dt.floor("1h", ambiguous="infer")) + return df + + @pytest.fixture def washover_df_no_switch(): df = pd.DataFrame( diff --git a/tests/splitter/test_washover.py b/tests/splitter/test_washover.py index e3dcb24d..017c2f16 100644 --- a/tests/splitter/test_washover.py +++ b/tests/splitter/test_washover.py @@ -3,10 +3,15 @@ import pytest from cluster_experiments import SwitchbackSplitter -from cluster_experiments.washover import ConstantWashover, EmptyWashover +from cluster_experiments.washover import ( + ConstantWashover, + EmptyWashover, + SimmetricWashover, + TwoSidedWashover, +) -@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)]) +@pytest.mark.parametrize("minutes, n_rows", [(30, 3), (10, 4), (15, 3)]) def test_constant_washover_base(minutes, n_rows, washover_base_df): out_df = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)).washover( @@ -18,14 +23,47 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df): ) assert len(out_df) == n_rows - assert (out_df["original___time"].dt.minute > minutes).all() + + +@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (15, 2), (12, 3), (10, 4)]) +def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df): + + out_df = SimmetricWashover(washover_time_delta=timedelta(minutes=minutes)).washover( + df=simmetric_washover_base_df, + truncated_time_col="time", + cluster_cols=["city", "time"], + treatment_col="treatment", + ) + + assert len(out_df) == n_rows + + +@pytest.mark.parametrize( + "minutes_before, minutes_after, n_rows", + [(30, 30, 2), (10, 10, 4), (15, 15, 2), (12, 15, 2)], +) +def test_2_sided_washover( + minutes_before, minutes_after, n_rows, simmetric_washover_base_df +): + + out_df = TwoSidedWashover( + washover_time_delta_before=timedelta(minutes=minutes_before), + washover_time_delta_after=timedelta(minutes=minutes_after), + ).washover( + df=simmetric_washover_base_df, + truncated_time_col="time", + cluster_cols=["city", "time"], + treatment_col="treatment", + ) + + assert len(out_df) == n_rows @pytest.mark.parametrize( "minutes, n_rows, df", [ - (30, 4, "washover_df_no_switch"), - (30, 4 + 4, "washover_df_multi_city"), + (30, 5, "washover_df_no_switch"), + (30, 5 + 5, "washover_df_multi_city"), ], ) def test_constant_washover_no_switch(minutes, n_rows, df, request): diff --git a/tests/test_docs.py b/tests/test_docs.py index cd94660d..a9b1bfe7 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -19,11 +19,13 @@ PowerAnalysis, PowerConfig, RandomSplitter, + SimmetricWashover, StratifiedClusteredSplitter, StratifiedSwitchbackSplitter, SwitchbackSplitter, TargetAggregation, TTestClusteredAnalysis, + TwoSidedWashover, UniformPerturbator, ) from cluster_experiments.utils import _original_time_column @@ -49,6 +51,8 @@ UniformPerturbator, _original_time_column, ConstantWashover, + TwoSidedWashover, + SimmetricWashover, EmptyWashover, BalancedSwitchbackSplitter, StratifiedSwitchbackSplitter,