Skip to content

Two-sided washover #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ The library offers the following classes:
* Washover for switchback experiments:
* `EmptyWashover`: no washover done at all.
* `ConstantWashover`: accepts a timedelta parameter and removes the data when we switch from A to B for the timedelta interval.
* `TwoSidedWashover`: accepts two timedelta parameters (one for before the switch, one for after) and removes the data within those intervals around each switch from A to B.
* `SimmetricWashover`: same as two sided, but with the same timedelta for both sides.
* Regarding analysis:
* `GeeExperimentAnalysis`: to run GEE analysis on the results of a clustered design
* `TTestClusteredAnalysis`: to run a t-test on aggregated data for clusters
Expand Down
10 changes: 9 additions & 1 deletion cluster_experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
StratifiedSwitchbackSplitter,
SwitchbackSplitter,
)
from cluster_experiments.washover import ConstantWashover, EmptyWashover, Washover
from cluster_experiments.washover import (
ConstantWashover,
EmptyWashover,
SimmetricWashover,
TwoSidedWashover,
Washover,
)

__all__ = [
"ExperimentAnalysis",
Expand All @@ -50,5 +56,7 @@
"PairedTTestClusteredAnalysis",
"EmptyWashover",
"ConstantWashover",
"TwoSidedWashover",
"SimmetricWashover",
"Washover",
]
122 changes: 105 additions & 17 deletions cluster_experiments/washover.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,17 @@ def washover(
return df


class ConstantWashover(Washover):
"""Constant washover - we drop all rows in the washover period when
there is a switch where the treatment is different."""
class TwoSidedWashover(Washover):
"""Two sided washover - we drop all rows before and after the switch within
the time deltas when there is a switch where the treatment is different."""

def __init__(self, washover_time_delta: datetime.timedelta):
self.washover_time_delta = washover_time_delta
def __init__(
self,
washover_time_delta_before: datetime.timedelta,
washover_time_delta_after: datetime.timedelta,
):
self.washover_time_delta_before = washover_time_delta_before
self.washover_time_delta_after = washover_time_delta_after

def washover(
self,
Expand All @@ -97,7 +102,7 @@ def washover(
cluster_cols: List[str],
original_time_col: Optional[str] = None,
) -> pd.DataFrame:
"""No washover - returns the same dataframe as input.
"""Two sided washover - removes rows around the switch.

Args:
df (pd.DataFrame): Input dataframe.
Expand All @@ -114,7 +119,7 @@ def washover(
from cluster_experiments import SwitchbackSplitter
from cluster_experiments import TwoSidedWashover

washover = ConstantWashover(washover_time_delta=datetime.timedelta(minutes=30))
washover = TwoSidedWashover(
    washover_time_delta_before=datetime.timedelta(minutes=30),
    washover_time_delta_after=datetime.timedelta(minutes=30),
)

n = 10
df = pd.DataFrame(
Expand Down Expand Up @@ -145,40 +150,123 @@ def washover(
if original_time_col
else _original_time_column(truncated_time_col)
)

# Cluster columns that do not involve time
non_time_cols = list(set(cluster_cols) - set([truncated_time_col]))

# For each cluster, we need to check if treatment has changed wrt last time
df_agg = df.drop_duplicates(subset=cluster_cols + [treatment_col]).copy()
df_agg["__changed"] = (
df_agg.groupby(non_time_cols)[treatment_col].shift(1)
!= df_agg[treatment_col]
)
df_agg = df_agg.loc[:, cluster_cols + ["__changed"]]
) & df_agg.groupby(non_time_cols)[treatment_col].shift(1).notnull()

# We also check if treatment changes for the next time
df_agg["__changed_next"] = (
df_agg.groupby(non_time_cols)[treatment_col].shift(-1)
!= df_agg[treatment_col]
) & df_agg.groupby(non_time_cols)[treatment_col].shift(-1).notnull()

# Calculate switch start of the next time
df_agg[f"__next_{truncated_time_col}"] = df_agg.groupby(non_time_cols)[
truncated_time_col
].shift(-1)

# Clean switch df
df_agg = df_agg.loc[
:,
cluster_cols
+ ["__changed", "__changed_next", f"__next_{truncated_time_col}"],
]
return (
df.merge(df_agg, on=cluster_cols, how="inner")
.assign(
__time_since_switch=lambda x: x[original_time_col].astype(
"datetime64[ns]"
)
- x[truncated_time_col].astype("datetime64[ns]"),
__after_washover=lambda x: x["__time_since_switch"]
> self.washover_time_delta,
__time_to_next_switch=lambda x: x[
f"__next_{truncated_time_col}"
].astype("datetime64[ns]")
- x[original_time_col].astype("datetime64[ns]"),
__after_washover=lambda x: (
x["__time_since_switch"] > self.washover_time_delta_after
),
__before_washover=lambda x: (
x["__time_to_next_switch"] > self.washover_time_delta_before
),
)
# add not changed in query
# if no change or too late after change, don't drop
.query("__after_washover or not __changed")
.drop(columns=["__time_since_switch", "__after_washover", "__changed"])
# if no change to next switch or too early before change, don't drop
.query("__before_washover or not __changed_next")
.drop(
columns=[
"__time_since_switch",
"__time_to_next_switch",
"__after_washover",
"__before_washover",
"__changed",
"__changed_next",
f"__next_{truncated_time_col}",
]
)
)

@classmethod
def from_config(cls, config) -> "Washover":
    """Build a two-sided washover from a config object.

    Args:
        config: Object exposing ``washover_time_delta_before`` and
            ``washover_time_delta_after`` attributes.

    Returns:
        An instance of ``cls`` built from the two configured time deltas.

    Raises:
        ValueError: If either time delta is missing (falsy) in the config.
    """
    # NOTE(review): the truthiness check also rejects a zero timedelta;
    # confirm whether a zero-length side should be allowed via config.
    if (
        not config.washover_time_delta_before
        or not config.washover_time_delta_after
    ):
        raise ValueError(
            f"Washover time deltas must be specified for TwoSidedWashover, while it is {config.washover_time_delta_before = } and {config.washover_time_delta_after = }"
        )
    # Use the same attributes that were validated above; each side gets its
    # own delta rather than a shared ``washover_time_delta``.
    return cls(
        washover_time_delta_before=config.washover_time_delta_before,
        washover_time_delta_after=config.washover_time_delta_after,
    )


class ConstantWashover(TwoSidedWashover):
    """Constant washover - we drop all rows in the washover period after
    the switch when the treatment is different.

    Implemented as a two-sided washover whose "before" interval is zero.
    """

    def __init__(self, washover_time_delta: datetime.timedelta):
        """Store the washover interval applied after each switch.

        Args:
            washover_time_delta: Interval passed to the parent as the
                "after" time delta; the "before" delta is fixed at zero.
        """
        super().__init__(
            washover_time_delta_before=datetime.timedelta(seconds=0),
            washover_time_delta_after=washover_time_delta,
        )

    @classmethod
    def from_config(cls, config) -> "Washover":
        """Build a constant washover from ``config.washover_time_delta``.

        Raises:
            ValueError: If ``config.washover_time_delta`` is falsy.
        """
        if not config.washover_time_delta:
            # The message names this class (it previously named a sibling).
            raise ValueError(
                f"Washover time delta must be specified for ConstantWashover, while it is {config.washover_time_delta = }"
            )
        return cls(washover_time_delta=config.washover_time_delta)


class SimmetricWashover(TwoSidedWashover):
    """Simmetric washover - we drop all rows in the washover period before
    and after the switch when the treatment is different.

    Both sides of the switch use the same time delta.
    """

    def __init__(self, washover_time_delta: datetime.timedelta):
        """Store one interval used on both sides of each switch.

        Args:
            washover_time_delta: Interval passed to the parent as both the
                "before" and the "after" time delta.
        """
        super().__init__(
            washover_time_delta_before=washover_time_delta,
            washover_time_delta_after=washover_time_delta,
        )

    @classmethod
    def from_config(cls, config) -> "Washover":
        """Build a simmetric washover from ``config.washover_time_delta``.

        Raises:
            ValueError: If ``config.washover_time_delta`` is falsy.
        """
        if not config.washover_time_delta:
            # Spelling matches the class name so the error is searchable.
            raise ValueError(
                f"Washover time delta must be specified for SimmetricWashover, while it is {config.washover_time_delta = }"
            )
        return cls(washover_time_delta=config.washover_time_delta)


# This is kept in here because of circular imports, need to rethink this
washover_mapping = {"": EmptyWashover, "constant_washover": ConstantWashover}
washover_mapping = {
"": EmptyWashover,
"constant_washover": ConstantWashover,
"two_sided_washover": TwoSidedWashover,
"simmetric_washover": SimmetricWashover,
}
17 changes: 17 additions & 0 deletions tests/splitter/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ def washover_base_df():
return df


@pytest.fixture
def simmetric_washover_base_df():
    """Four-row, single-city frame with one A -> B change between hours.

    The two "A" rows fall in the 00:00 hour and the two "B" rows in the
    01:00 hour; ``time`` is ``original___time`` floored to the hour.
    """
    raw_times = [
        "2022-01-01 00:20:00",
        "2022-01-01 00:49:00",
        "2022-01-01 01:14:00",
        "2022-01-01 01:31:00",
    ]
    frame = pd.DataFrame(
        {
            "original___time": [pd.to_datetime(stamp) for stamp in raw_times],
            "treatment": ["A", "A", "B", "B"],
            "city": ["TGN"] * 4,
        }
    )
    # Truncated (switch-level) time column, appended last as in assign().
    frame = frame.assign(
        time=frame["original___time"].dt.floor("1h", ambiguous="infer")
    )
    return frame


@pytest.fixture
def washover_df_no_switch():
df = pd.DataFrame(
Expand Down
48 changes: 43 additions & 5 deletions tests/splitter/test_washover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import pytest

from cluster_experiments import SwitchbackSplitter
from cluster_experiments.washover import ConstantWashover, EmptyWashover
from cluster_experiments.washover import (
ConstantWashover,
EmptyWashover,
SimmetricWashover,
TwoSidedWashover,
)


@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)])
@pytest.mark.parametrize("minutes, n_rows", [(30, 3), (10, 4), (15, 3)])
def test_constant_washover_base(minutes, n_rows, washover_base_df):

out_df = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)).washover(
Expand All @@ -18,14 +23,47 @@ def test_constant_washover_base(minutes, n_rows, washover_base_df):
)

assert len(out_df) == n_rows
assert (out_df["original___time"].dt.minute > minutes).all()


@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (15, 2), (12, 3), (10, 4)])
def test_simmetric_washover_base(minutes, n_rows, simmetric_washover_base_df):
    """Check how many rows remain after a simmetric washover of ``minutes``."""
    washover = SimmetricWashover(washover_time_delta=timedelta(minutes=minutes))

    remaining = washover.washover(
        df=simmetric_washover_base_df,
        truncated_time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
    )

    assert len(remaining) == n_rows


@pytest.mark.parametrize(
    "minutes_before, minutes_after, n_rows",
    [(30, 30, 2), (10, 10, 4), (15, 15, 2), (12, 15, 2)],
)
def test_2_sided_washover(
    minutes_before, minutes_after, n_rows, simmetric_washover_base_df
):
    """Check the surviving row count for separate before/after deltas."""
    before_delta = timedelta(minutes=minutes_before)
    after_delta = timedelta(minutes=minutes_after)
    washover = TwoSidedWashover(
        washover_time_delta_before=before_delta,
        washover_time_delta_after=after_delta,
    )

    remaining = washover.washover(
        df=simmetric_washover_base_df,
        truncated_time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
    )

    assert len(remaining) == n_rows


@pytest.mark.parametrize(
"minutes, n_rows, df",
[
(30, 4, "washover_df_no_switch"),
(30, 4 + 4, "washover_df_multi_city"),
(30, 5, "washover_df_no_switch"),
(30, 5 + 5, "washover_df_multi_city"),
],
)
def test_constant_washover_no_switch(minutes, n_rows, df, request):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
PowerAnalysis,
PowerConfig,
RandomSplitter,
SimmetricWashover,
StratifiedClusteredSplitter,
StratifiedSwitchbackSplitter,
SwitchbackSplitter,
TargetAggregation,
TTestClusteredAnalysis,
TwoSidedWashover,
UniformPerturbator,
)
from cluster_experiments.utils import _original_time_column
Expand All @@ -49,6 +51,8 @@
UniformPerturbator,
_original_time_column,
ConstantWashover,
TwoSidedWashover,
SimmetricWashover,
EmptyWashover,
BalancedSwitchbackSplitter,
StratifiedSwitchbackSplitter,
Expand Down