From 5531176364eda3361c7e4598539279b4f35118b8 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 9 Mar 2023 14:54:37 +0100 Subject: [PATCH 1/6] simmetric mvp - tests not working --- README.md | 1 + cluster_experiments/__init__.py | 8 ++- cluster_experiments/washover.py | 111 ++++++++++++++++++++++++++++++++ tests/splitter/conftest.py | 17 +++++ tests/splitter/test_washover.py | 22 ++++++- tests/test_docs.py | 2 + 6 files changed, 159 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 53d41a90..0846e226 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,7 @@ The library offers the following classes: * Washover for switchback experiments: * `EmptyWashover`: no washover done at all. * `ConstantWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval. + * `SimmetricWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval, and also before the switch. * Regarding analysis: * `GeeExperimentAnalysis`: to run GEE analysis on the results of a clustered design * `TTestClusteredAnalysis`: to run a t-test on aggregated data for clusters diff --git a/cluster_experiments/__init__.py b/cluster_experiments/__init__.py index 379f1e54..ab042060 100644 --- a/cluster_experiments/__init__.py +++ b/cluster_experiments/__init__.py @@ -24,7 +24,12 @@ StratifiedSwitchbackSplitter, SwitchbackSplitter, ) -from cluster_experiments.washover import ConstantWashover, EmptyWashover, Washover +from cluster_experiments.washover import ( + ConstantWashover, + EmptyWashover, + SimmetricWashover, + Washover, +) __all__ = [ "ExperimentAnalysis", @@ -50,5 +55,6 @@ "PairedTTestClusteredAnalysis", "EmptyWashover", "ConstantWashover", + "SimmetricWashover", "Washover", ] diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index 661c1f9a..b4194f8c 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -169,5 +169,116 @@ def from_config(cls, config) -> "Washover": ) +class SimmetricWashover(Washover): + """Simmetric washover - we drop all rows before and after the switch within + the time delta when there is a switch where the treatment is different.""" + + def __init__(self, washover_time_delta: datetime.timedelta): + self.washover_time_delta = washover_time_delta + + def washover( + self, + df: pd.DataFrame, + time_col: str, + treatment_col: str, + cluster_cols: List[str], + ) -> pd.DataFrame: + """No washover - returns the same dataframe as input. + + Args: + df (pd.DataFrame): Input dataframe. + time_col (str): Name of the time column. + treatment_col (str): Name of the treatment column. + cluster_cols (List[str]): List of clusters of experiment. + + Returns: + pd.DataFrame: Same dataframe as input. + + Usage: + ```python + from cluster_experiments import SwitchbackSplitter + from cluster_experiments import ConstantWashover + + washover = SimmetricWashover(washover_time_delta=datetime.timedelta(minutes=30)) + + n = 10 + df = pd.DataFrame( + { + # Random time each minute in 2022-01-01, length 10 + "time": pd.date_range("2022-01-01", "2022-01-02", freq="1min")[ + np.random.randint(24 * 60, size=n) + ], + "city": random.choices(["TGN", "NYC", "LON", "REU"], k=n), + } + ) + + + splitter = SwitchbackSplitter( + washover=washover, + time_col="time", + cluster_cols=["city", "time"], + treatment_col="treatment", + switch_frequency="30T", + ) + + out_df = splitter.assign_treatment_df(df=washover_split_df) + + """ + # Cluster columns that do not involve time + non_time_cols = list(set(cluster_cols) - set([time_col])) + # For each cluster, we need to check if treatment has changed wrt last time + df_agg = df.drop_duplicates(subset=cluster_cols + [treatment_col]).copy() + df_agg["__changed"] = ( + df_agg.groupby(non_time_cols)[treatment_col].shift(1) + != df_agg[treatment_col] + ) + df_agg = df_agg.loc[:, cluster_cols + ["__changed"]] + + print( + df.merge(df_agg, on=cluster_cols, how="inner").assign( + __time_since_switch=lambda x: x[_original_time_column(time_col)].astype( + "datetime64[ns]" + ) + - x[time_col].astype("datetime64[ns]"), + __before_or_after_washover=lambda x: ( + x["__time_since_switch"] > self.washover_time_delta + ) + | (x["__time_since_switch"] < -self.washover_time_delta), + ) + ) + return ( + df.merge(df_agg, on=cluster_cols, how="inner") + .assign( + __time_since_switch=lambda x: x[_original_time_column(time_col)].astype( + "datetime64[ns]" + ) + - x[time_col].astype("datetime64[ns]"), + __before_or_after_washover=lambda x: ( + x["__time_since_switch"] > self.washover_time_delta + ) + | (x["__time_since_switch"] < -self.washover_time_delta), + ) + # add not changed in query + .query("__before_or_after_washover or not __changed") + .drop( + columns=[ + "__time_since_switch", + "__before_or_after_washover", + "__changed", + ] + ) + ) + + @classmethod + def from_config(cls, config) -> "Washover": + if not config.washover_time_delta: + raise ValueError( + f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + ) + return cls( + washover_time_delta=config.washover_time_delta, + ) + + # This is kept in here because of circular imports, need to rethink this washover_mapping = {"": EmptyWashover, "constant_washover": ConstantWashover} diff --git a/tests/splitter/conftest.py b/tests/splitter/conftest.py index 1cf12fb9..54415b0c 100644 --- a/tests/splitter/conftest.py +++ b/tests/splitter/conftest.py @@ -161,6 +161,23 @@ def washover_base_df(): return df +@pytest.fixture +def simmetric_washover_base_df(): + df = pd.DataFrame( + { + "original___time": [ + pd.to_datetime("2022-01-01 00:20:00"), + pd.to_datetime("2022-01-01 00:31:00"), + pd.to_datetime("2022-01-01 01:14:00"), + pd.to_datetime("2022-01-01 01:31:00"), + ], + "treatment": ["A", "A", "B", "B"], + "city": ["TGN"] * 4, + } + ).assign(time=lambda x: x["original___time"].dt.floor("1h", ambiguous="infer")) + return df + + @pytest.fixture def washover_df_no_switch(): df = pd.DataFrame( diff --git a/tests/splitter/test_washover.py b/tests/splitter/test_washover.py index 8526bd56..2eb72637 100644 --- a/tests/splitter/test_washover.py +++ b/tests/splitter/test_washover.py @@ -3,7 +3,11 @@ import pytest from cluster_experiments import SwitchbackSplitter -from cluster_experiments.washover import ConstantWashover, EmptyWashover +from cluster_experiments.washover import ( + ConstantWashover, + EmptyWashover, + SimmetricWashover, +) @pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)]) @@ -20,6 +24,22 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df): assert (out_df["original___time"].dt.minute > minutes).all() +# @pytest.mark.parametrize("minutes, n_rows", [(30, 1), (10, 4), (15, 3)]) +@pytest.mark.parametrize("minutes, n_rows", [(30, 1)]) +def test_simmetric_washover_base(minutes, n_rows, washover_base_df): + + out_df = SimmetricWashover(washover_time_delta=timedelta(minutes=minutes)).washover( + df=washover_base_df, + time_col="time", + cluster_cols=["city", "time"], + treatment_col="treatment", + ) + + print(out_df, minutes, n_rows) + assert len(out_df) == n_rows + assert (out_df["original___time"].dt.minute > minutes).all() + + @pytest.mark.parametrize( "minutes, n_rows, df", [ diff --git a/tests/test_docs.py b/tests/test_docs.py index cd94660d..01348c8f 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -19,6 +19,7 @@ PowerAnalysis, PowerConfig, RandomSplitter, + SimmetricWashover, StratifiedClusteredSplitter, StratifiedSwitchbackSplitter, SwitchbackSplitter, @@ -49,6 +50,7 @@ UniformPerturbator, _original_time_column, ConstantWashover, + SimmetricWashover, EmptyWashover, BalancedSwitchbackSplitter, StratifiedSwitchbackSplitter, From 648dd6ea30276f70a638552ef99667e2cbe185db Mon Sep 17 00:00:00 2001 From: David Date: Thu, 9 Mar 2023 14:59:31 +0100 Subject: [PATCH 2/6] next switch --- cluster_experiments/washover.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index b4194f8c..a84d914c 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -253,6 +253,9 @@ def washover( "datetime64[ns]" ) - x[time_col].astype("datetime64[ns]"), + # TODO: we have to get the next switch time + __time_to_next_switch=lambda x: x[time_col].astype("datetime64[ns]") + - x[_original_time_column(time_col)].astype("datetime64[ns]"), __before_or_after_washover=lambda x: ( x["__time_since_switch"] > self.washover_time_delta ) From 7df02a9070e518db022bbcb8abc96a1ac78ec30a Mon Sep 17 00:00:00 2001 From: David Date: Fri, 10 Mar 2023 11:25:27 +0100 Subject: [PATCH 3/6] add simmetric washover --- cluster_experiments/washover.py | 87 +++++++++++++++++++++------------ tests/splitter/conftest.py | 2 +- tests/splitter/test_washover.py | 9 ++-- 3 files changed, 62 insertions(+), 36 deletions(-) diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index a0c579ac..9d5aea91 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -97,7 +97,7 @@ def washover( cluster_cols: List[str], original_time_col: Optional[str] = None, ) -> pd.DataFrame: - """No washover - returns the same dataframe as input. + """Constant washover - removes rows after the switch. Args: df (pd.DataFrame): Input dataframe. @@ -190,17 +190,19 @@ def __init__(self, washover_time_delta: datetime.timedelta): def washover( self, df: pd.DataFrame, - time_col: str, + truncated_time_col: str, treatment_col: str, cluster_cols: List[str], + original_time_col: Optional[str] = None, ) -> pd.DataFrame: - """No washover - returns the same dataframe as input. + """Simmetric washover - removes rows simmetrically around the switch. Args: df (pd.DataFrame): Input dataframe. - time_col (str): Name of the time column. + truncated_time_col (str): Name of the truncated time column. treatment_col (str): Name of the treatment column. cluster_cols (List[str]): List of clusters of experiment. + original_time_col (Optional[str], optional): Name of the original time column. Returns: pd.DataFrame: Same dataframe as input. @@ -235,50 +237,71 @@ def washover( out_df = splitter.assign_treatment_df(df=washover_split_df) """ + # Set original time column + original_time_col = ( + original_time_col + if original_time_col + else _original_time_column(truncated_time_col) + ) + # Cluster columns that do not involve time - non_time_cols = list(set(cluster_cols) - set([time_col])) + non_time_cols = list(set(cluster_cols) - set([truncated_time_col])) + # For each cluster, we need to check if treatment has changed wrt last time df_agg = df.drop_duplicates(subset=cluster_cols + [treatment_col]).copy() df_agg["__changed"] = ( df_agg.groupby(non_time_cols)[treatment_col].shift(1) != df_agg[treatment_col] ) - df_agg = df_agg.loc[:, cluster_cols + ["__changed"]] - print( - df.merge(df_agg, on=cluster_cols, how="inner").assign( - __time_since_switch=lambda x: x[_original_time_column(time_col)].astype( - "datetime64[ns]" - ) - - x[time_col].astype("datetime64[ns]"), - __before_or_after_washover=lambda x: ( - x["__time_since_switch"] > self.washover_time_delta - ) - | (x["__time_since_switch"] < -self.washover_time_delta), - ) - ) + # We also check if treatment changes for the next time + df_agg["__changed_next"] = ( + df_agg.groupby(non_time_cols)[treatment_col].shift(-1) + != df_agg[treatment_col] + ) & df_agg.groupby(non_time_cols)[treatment_col].shift(-1).notnull() + + # Calculate switch start of the next time + df_agg[f"__next_{truncated_time_col}"] = df_agg.groupby(non_time_cols)[ + truncated_time_col + ].shift(-1) + + # Clean switch df + df_agg = df_agg.loc[ + :, + cluster_cols + + ["__changed", "__changed_next", f"__next_{truncated_time_col}"], + ] return ( df.merge(df_agg, on=cluster_cols, how="inner") .assign( - __time_since_switch=lambda x: x[_original_time_column(time_col)].astype( + __time_since_switch=lambda x: x[original_time_col].astype( "datetime64[ns]" ) - - x[time_col].astype("datetime64[ns]"), - # TODO: we have to get the next switch time - __time_to_next_switch=lambda x: x[time_col].astype("datetime64[ns]") - - x[_original_time_column(time_col)].astype("datetime64[ns]"), - __before_or_after_washover=lambda x: ( + - x[truncated_time_col].astype("datetime64[ns]"), + __time_to_next_switch=lambda x: x[ + f"__next_{truncated_time_col}" + ].astype("datetime64[ns]") + - x[original_time_col].astype("datetime64[ns]"), + __after_washover=lambda x: ( x["__time_since_switch"] > self.washover_time_delta - ) - | (x["__time_since_switch"] < -self.washover_time_delta), + ), + __before_washover=lambda x: ( + x["__time_to_next_switch"] > self.washover_time_delta + ), ) - # add not changed in query - .query("__before_or_after_washover or not __changed") + # if no change or too late after change, don't drop + .query("__after_washover or not __changed") + # if no change to next switch or too early before change, don't drop + .query("__before_washover or not __changed_next") .drop( columns=[ "__time_since_switch", - "__before_or_after_washover", + "__time_to_next_switch", + "__after_washover", + "__before_washover", "__changed", + "__changed_next", + f"__next_{truncated_time_col}", ] ) ) @@ -295,4 +318,8 @@ def from_config(cls, config) -> "Washover": # This is kept in here because of circular imports, need to rethink this -washover_mapping = {"": EmptyWashover, "constant_washover": ConstantWashover} +washover_mapping = { + "": EmptyWashover, + "constant_washover": ConstantWashover, + "simmetric_washover": SimmetricWashover, +} diff --git a/tests/splitter/conftest.py b/tests/splitter/conftest.py index 54415b0c..4f9958c1 100644 --- a/tests/splitter/conftest.py +++ b/tests/splitter/conftest.py @@ -167,7 +167,7 @@ def simmetric_washover_base_df(): { "original___time": [ pd.to_datetime("2022-01-01 00:20:00"), - pd.to_datetime("2022-01-01 00:31:00"), + pd.to_datetime("2022-01-01 00:49:00"), pd.to_datetime("2022-01-01 01:14:00"), pd.to_datetime("2022-01-01 01:31:00"), ], diff --git a/tests/splitter/test_washover.py b/tests/splitter/test_washover.py index e1514822..ec44b0fe 100644 --- a/tests/splitter/test_washover.py +++ b/tests/splitter/test_washover.py @@ -25,13 +25,12 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df): assert (out_df["original___time"].dt.minute > minutes).all() -# @pytest.mark.parametrize("minutes, n_rows", [(30, 1), (10, 4), (15, 3)]) -@pytest.mark.parametrize("minutes, n_rows", [(30, 1)]) -def test_simmetric_washover_base(minutes, n_rows, washover_base_df): +@pytest.mark.parametrize("minutes, n_rows", [(30, 1), (15, 2), (12, 3), (10, 4)]) +def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df): out_df = SimmetricWashover(washover_time_delta=timedelta(minutes=minutes)).washover( - df=washover_base_df, - time_col="time", + df=simmetric_washover_base_df, + truncated_time_col="time", cluster_cols=["city", "time"], treatment_col="treatment", ) From 8ef02822deddad334f24fdb35bdbdff091c612a9 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 10 Mar 2023 11:42:14 +0100 Subject: [PATCH 4/6] add 2-sided --- README.md | 3 ++- cluster_experiments/__init__.py | 2 ++ cluster_experiments/washover.py | 43 +++++++++++++++++++++++++-------- tests/splitter/test_washover.py | 25 +++++++++++++++++-- tests/test_docs.py | 2 ++ 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0846e226..ac648c37 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,8 @@ The library offers the following classes: * Washover for switchback experiments: * `EmptyWashover`: no washover done at all. * `ConstantWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval. - * `SimmetricWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval, and also before the switch. + * `TwoSidedWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval, and also before the switch. + * `SimmetricWashover`: same as two sided, but with the same timedelta for both sides. * Regarding analysis: * `GeeExperimentAnalysis`: to run GEE analysis on the results of a clustered design * `TTestClusteredAnalysis`: to run a t-test on aggregated data for clusters diff --git a/cluster_experiments/__init__.py b/cluster_experiments/__init__.py index ab042060..5bc416a5 100644 --- a/cluster_experiments/__init__.py +++ b/cluster_experiments/__init__.py @@ -28,6 +28,7 @@ ConstantWashover, EmptyWashover, SimmetricWashover, + TwoSidedWashover, Washover, ) @@ -55,6 +56,7 @@ "PairedTTestClusteredAnalysis", "EmptyWashover", "ConstantWashover", + "TwoSidedWashover", "SimmetricWashover", "Washover", ] diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index 9d5aea91..a85e4159 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -180,12 +180,17 @@ def from_config(cls, config) -> "Washover": ) -class SimmetricWashover(Washover): - """Simmetric washover - we drop all rows before and after the switch within +class TwoSidedWashover(Washover): + """Two sided washover - we drop all rows before and after the switch within the time delta when there is a switch where the treatment is different.""" - def __init__(self, washover_time_delta: datetime.timedelta): - self.washover_time_delta = washover_time_delta + def __init__( + self, + washover_time_delta_before: datetime.timedelta, + washover_time_delta_after: datetime.timedelta, + ): + self.washover_time_delta_before = washover_time_delta_before + self.washover_time_delta_after = washover_time_delta_after def washover( self, @@ -195,7 +200,7 @@ def washover( cluster_cols: List[str], original_time_col: Optional[str] = None, ) -> pd.DataFrame: - """Simmetric washover - removes rows simmetrically around the switch. + """Two sided washover - removes rows simmetrically around the switch. Args: df (pd.DataFrame): Input dataframe. @@ -212,7 +217,7 @@ def washover( from cluster_experiments import SwitchbackSplitter from cluster_experiments import ConstantWashover - washover = SimmetricWashover(washover_time_delta=datetime.timedelta(minutes=30)) + washover = TwoSidedWashover(washover_time_delta=datetime.timedelta(minutes=30)) n = 10 df = pd.DataFrame( @@ -252,7 +257,7 @@ def washover( df_agg["__changed"] = ( df_agg.groupby(non_time_cols)[treatment_col].shift(1) != df_agg[treatment_col] - ) + ) & df_agg.groupby(non_time_cols)[treatment_col].shift(1).notnull() # We also check if treatment changes for the next time df_agg["__changed_next"] = ( @@ -283,10 +288,10 @@ def washover( ].astype("datetime64[ns]") - x[original_time_col].astype("datetime64[ns]"), __after_washover=lambda x: ( - x["__time_since_switch"] > self.washover_time_delta + x["__time_since_switch"] > self.washover_time_delta_after ), __before_washover=lambda x: ( - x["__time_to_next_switch"] > self.washover_time_delta + x["__time_to_next_switch"] > self.washover_time_delta_before ), ) # if no change or too late after change, don't drop @@ -313,13 +318,31 @@ def from_config(cls, config) -> "Washover": f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" ) return cls( - washover_time_delta=config.washover_time_delta, + washover_time_delta_before=config.washover_time_delta, + washover_time_delta_after=config.washover_time_delta, + ) + + +class SimmetricWashover(TwoSidedWashover): + def __init__(self, washover_time_delta: datetime.timedelta): + super().__init__( + washover_time_delta_before=washover_time_delta, + washover_time_delta_after=washover_time_delta, ) + @classmethod + def from_config(cls, config) -> "Washover": + if not config.washover_time_delta: + raise ValueError( + f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + ) + return cls(washover_time_delta=config.washover_time_delta) + # This is kept in here because of circular imports, need to rethink this washover_mapping = { "": EmptyWashover, "constant_washover": ConstantWashover, + "two_sided_washover": TwoSidedWashover, "simmetric_washover": SimmetricWashover, } diff --git a/tests/splitter/test_washover.py b/tests/splitter/test_washover.py index ec44b0fe..43ee05ff 100644 --- a/tests/splitter/test_washover.py +++ b/tests/splitter/test_washover.py @@ -7,6 +7,7 @@ ConstantWashover, EmptyWashover, SimmetricWashover, + TwoSidedWashover, ) @@ -25,7 +26,7 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df): assert (out_df["original___time"].dt.minute > minutes).all() -@pytest.mark.parametrize("minutes, n_rows", [(30, 1), (15, 2), (12, 3), (10, 4)]) +@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (15, 2), (12, 3), (10, 4)]) def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df): out_df = SimmetricWashover(washover_time_delta=timedelta(minutes=minutes)).washover( @@ -35,11 +36,31 @@ def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df): treatment_col="treatment", ) - print(out_df, minutes, n_rows) assert len(out_df) == n_rows assert (out_df["original___time"].dt.minute > minutes).all() +@pytest.mark.parametrize( + "minutes_before, minutes_after, n_rows", + [(30, 30, 2), (10, 10, 4), (15, 15, 2), (12, 15, 2)], +) +def test_2_sided_washover( + minutes_before, minutes_after, n_rows, simmetric_washover_base_df +): + + out_df = TwoSidedWashover( + washover_time_delta_before=timedelta(minutes=minutes_before), + washover_time_delta_after=timedelta(minutes=minutes_after), + ).washover( + df=simmetric_washover_base_df, + truncated_time_col="time", + cluster_cols=["city", "time"], + treatment_col="treatment", + ) + + assert len(out_df) == n_rows + + @pytest.mark.parametrize( "minutes, n_rows, df", [ diff --git a/tests/test_docs.py b/tests/test_docs.py index 01348c8f..a9b1bfe7 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -25,6 +25,7 @@ SwitchbackSplitter, TargetAggregation, TTestClusteredAnalysis, + TwoSidedWashover, UniformPerturbator, ) from cluster_experiments.utils import _original_time_column @@ -50,6 +51,7 @@ UniformPerturbator, _original_time_column, ConstantWashover, + TwoSidedWashover, SimmetricWashover, EmptyWashover, BalancedSwitchbackSplitter, From b42f4977dc789fba51a63887c856c3959fa7df1d Mon Sep 17 00:00:00 2001 From: David Date: Fri, 10 Mar 2023 11:46:02 +0100 Subject: [PATCH 5/6] refactor constant washover --- cluster_experiments/washover.py | 114 +++++--------------------------- tests/splitter/test_washover.py | 8 +-- 2 files changed, 19 insertions(+), 103 deletions(-) diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index a85e4159..4d407ef0 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -82,104 +82,6 @@ def washover( return df -class ConstantWashover(Washover): - """Constant washover - we drop all rows in the washover period when - there is a switch where the treatment is different.""" - - def __init__(self, washover_time_delta: datetime.timedelta): - self.washover_time_delta = washover_time_delta - - def washover( - self, - df: pd.DataFrame, - truncated_time_col: str, - treatment_col: str, - cluster_cols: List[str], - original_time_col: Optional[str] = None, - ) -> pd.DataFrame: - """Constant washover - removes rows after the switch. - - Args: - df (pd.DataFrame): Input dataframe. - truncated_time_col (str): Name of the truncated time column. - treatment_col (str): Name of the treatment column. - cluster_cols (List[str]): List of clusters of experiment. - original_time_col (Optional[str], optional): Name of the original time column. - - Returns: - pd.DataFrame: Same dataframe as input. - - Usage: - ```python - from cluster_experiments import SwitchbackSplitter - from cluster_experiments import ConstantWashover - - washover = ConstantWashover(washover_time_delta=datetime.timedelta(minutes=30)) - - n = 10 - df = pd.DataFrame( - { - # Random time each minute in 2022-01-01, length 10 - "time": pd.date_range("2022-01-01", "2022-01-02", freq="1min")[ - np.random.randint(24 * 60, size=n) - ], - "city": random.choices(["TGN", "NYC", "LON", "REU"], k=n), - } - ) - - - splitter = SwitchbackSplitter( - washover=washover, - time_col="time", - cluster_cols=["city", "time"], - treatment_col="treatment", - switch_frequency="30T", - ) - - out_df = splitter.assign_treatment_df(df=washover_split_df) - - """ - # Set original time column - original_time_col = ( - original_time_col - if original_time_col - else _original_time_column(truncated_time_col) - ) - # Cluster columns that do not involve time - non_time_cols = list(set(cluster_cols) - set([truncated_time_col])) - # For each cluster, we need to check if treatment has changed wrt last time - df_agg = df.drop_duplicates(subset=cluster_cols + [treatment_col]).copy() - df_agg["__changed"] = ( - df_agg.groupby(non_time_cols)[treatment_col].shift(1) - != df_agg[treatment_col] - ) - df_agg = df_agg.loc[:, cluster_cols + ["__changed"]] - return ( - df.merge(df_agg, on=cluster_cols, how="inner") - .assign( - __time_since_switch=lambda x: x[original_time_col].astype( - "datetime64[ns]" - ) - - x[truncated_time_col].astype("datetime64[ns]"), - __after_washover=lambda x: x["__time_since_switch"] - > self.washover_time_delta, - ) - # add not changed in query - .query("__after_washover or not __changed") - .drop(columns=["__time_since_switch", "__after_washover", "__changed"]) - ) - - @classmethod - def from_config(cls, config) -> "Washover": - if not config.washover_time_delta: - raise ValueError( - f"Washover time delta must be specified for ConstantWashover, while it is {config.washover_time_delta = }" - ) - return cls( - washover_time_delta=config.washover_time_delta, - ) - - class TwoSidedWashover(Washover): """Two sided washover - we drop all rows before and after the switch within the time delta when there is a switch where the treatment is different.""" @@ -323,6 +225,22 @@ def from_config(cls, config) -> "Washover": ) +class ConstantWashover(TwoSidedWashover): + """Constant washover - we drop all rows in the washover period when + there is a switch where the treatment is different.""" + + def __init__(self, washover_time_delta: datetime.timedelta): + super().__init__(datetime.timedelta(seconds=0), washover_time_delta) + + @classmethod + def from_config(cls, config) -> "Washover": + if not config.washover_time_delta: + raise ValueError( + f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + ) + return cls(washover_time_delta=config.washover_time_delta) + + class SimmetricWashover(TwoSidedWashover): def __init__(self, washover_time_delta: datetime.timedelta): super().__init__( diff --git a/tests/splitter/test_washover.py b/tests/splitter/test_washover.py index 43ee05ff..017c2f16 100644 --- a/tests/splitter/test_washover.py +++ b/tests/splitter/test_washover.py @@ -11,7 +11,7 @@ ) -@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)]) +@pytest.mark.parametrize("minutes, n_rows", [(30, 3), (10, 4), (15, 3)]) def test_constant_washover_base(minutes, n_rows, washover_base_df): out_df = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)).washover( @@ -23,7 +23,6 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df): ) assert len(out_df) == n_rows - assert (out_df["original___time"].dt.minute > minutes).all() @pytest.mark.parametrize("minutes, n_rows", [(30, 2), (15, 2), (12, 3), (10, 4)]) @@ -37,7 +36,6 @@ def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df): ) assert len(out_df) == n_rows - assert (out_df["original___time"].dt.minute > minutes).all() @pytest.mark.parametrize( @@ -64,8 +62,8 @@ def test_2_sided_washover( @pytest.mark.parametrize( "minutes, n_rows, df", [ - (30, 4, "washover_df_no_switch"), - (30, 4 + 4, "washover_df_multi_city"), + (30, 5, "washover_df_no_switch"), + (30, 5 + 5, "washover_df_multi_city"), ], ) def test_constant_washover_no_switch(minutes, n_rows, df, request): From eeeb0ca1bffaa5f2cf067fd4501982b87c20b214 Mon Sep 17 00:00:00 2001 From: David Date: Wed, 5 Apr 2023 16:50:25 +0200 Subject: [PATCH 6/6] add comments --- cluster_experiments/washover.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cluster_experiments/washover.py b/cluster_experiments/washover.py index 4d407ef0..c0e50bcd 100644 --- a/cluster_experiments/washover.py +++ b/cluster_experiments/washover.py @@ -84,7 +84,7 @@ def washover( class TwoSidedWashover(Washover): """Two sided washover - we drop all rows before and after the switch within - the time delta when there is a switch where the treatment is different.""" + the time deltas when there is a switch where the treatment is different.""" def __init__( self, @@ -102,7 +102,7 @@ def washover( cluster_cols: List[str], original_time_col: Optional[str] = None, ) -> pd.DataFrame: - """Two sided washover - removes rows simmetrically around the switch. + """Two sided washover - removes rows around the switch. Args: df (pd.DataFrame): Input dataframe. @@ -215,9 +215,12 @@ def washover( @classmethod def from_config(cls, config) -> "Washover": - if not config.washover_time_delta: + if ( + not config.washover_time_delta_before + or not config.washover_time_delta_after + ): raise ValueError( - f"Washover time delta must be specified for SimetricWashover, while it is {config.washover_time_delta = }" + f"Washover time deltas must be specified for , while it is {config.washover_time_delta_before = } and {config.washover_time_delta_after = }" ) return cls( washover_time_delta_before=config.washover_time_delta, @@ -226,8 +229,8 @@ def from_config(cls, config) -> "Washover": class ConstantWashover(TwoSidedWashover): - """Constant washover - we drop all rows in the washover period when - there is a switch where the treatment is different.""" + """Constant washover - we drop all rows in the washover period after + the switch when the treatment is different.""" def __init__(self, washover_time_delta: datetime.timedelta): super().__init__(datetime.timedelta(seconds=0), washover_time_delta) @@ -242,6 +245,9 @@ def from_config(cls, config) -> "Washover": class SimmetricWashover(TwoSidedWashover): + """Simmetric washover - we drop all rows in the washover period before + and after the switch when the treatment is different.""" + def __init__(self, washover_time_delta: datetime.timedelta): super().__init__( washover_time_delta_before=washover_time_delta,