Skip to content

Commit 9330a9c

Browse files
authored
Merge pull request #359 from pymc-labs/summary_pymc
2 parents 2916688 + 267d7c7 commit 9330a9c

11 files changed: +572 -715 lines changed

causalpy/pymc_experiments.py

Lines changed: 31 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -102,29 +102,37 @@ def print_coefficients(self, round_to=None) -> None:
102102
... "progressbar": False
103103
... }),
104104
... )
105-
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
105+
>>> result.print_coefficients(round_to=1)
106106
Model coefficients:
107-
Intercept 1, 94% HDI [1, 1]
108-
post_treatment[T.True] 1, 94% HDI [0.9, 1]
109-
group 0.2, 94% HDI [0.09, 0.2]
110-
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
111-
sigma 0.08, 94% HDI [0.07, 0.1]
107+
Intercept 1, 94% HDI [1, 1]
108+
post_treatment[T.True] 1, 94% HDI [0.9, 1]
109+
group 0.2, 94% HDI [0.09, 0.2]
110+
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
111+
sigma 0.08, 94% HDI [0.07, 0.1]
112112
"""
113+
114+
def print_row(
115+
max_label_length: int, name: str, coeff_samples: xr.DataArray, round_to: int
116+
) -> None:
117+
"""Print one row of the coefficient table"""
118+
formatted_name = f" {name: <{max_label_length}}"
119+
formatted_val = f"{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
120+
print(f" {formatted_name} {formatted_val}")
121+
113122
print("Model coefficients:")
114123
coeffs = az.extract(self.idata.posterior, var_names="beta")
115-
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
116-
# the stats despite variable names of different lengths
124+
125+
# Determine the width of the longest label
126+
max_label_length = max(len(name) for name in self.labels + ["sigma"])
127+
117128
for name in self.labels:
118129
coeff_samples = coeffs.sel(coeffs=name)
119-
print(
120-
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
121-
)
122-
# add coeff for measurement std
130+
print_row(max_label_length, name, coeff_samples, round_to)
131+
132+
# Add coefficient for measurement std
123133
coeff_samples = az.extract(self.model.idata.posterior, var_names="sigma")
124134
name = "sigma"
125-
print(
126-
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
127-
)
135+
print_row(max_label_length, name, coeff_samples, round_to)
128136

129137

130138
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
@@ -160,13 +168,13 @@ class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
160168
... }
161169
... ),
162170
... )
163-
>>> result.summary(round_to=1) # doctest: +NUMBER
171+
>>> result.summary(round_to=1)
164172
==================================Pre-Post Fit==================================
165173
Formula: actual ~ 0 + a + g
166174
Model coefficients:
167-
a 0.6, 94% HDI [0.6, 0.6]
168-
g 0.4, 94% HDI [0.4, 0.4]
169-
sigma 0.8, 94% HDI [0.6, 0.9]
175+
a 0.6, 94% HDI [0.6, 0.6]
176+
g 0.4, 94% HDI [0.4, 0.4]
177+
sigma 0.8, 94% HDI [0.6, 0.9]
170178
"""
171179

172180
def __init__(
@@ -1181,10 +1189,10 @@ class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator):
11811189
Results:
11821190
Causal impact = 2, $CI_{94%}$[2, 2]
11831191
Model coefficients:
1184-
Intercept -0.5, 94% HDI [-1, 0.2]
1185-
C(group)[T.1] 2, 94% HDI [2, 2]
1186-
pre 1, 94% HDI [1, 1]
1187-
sigma 0.5, 94% HDI [0.5, 0.6]
1192+
Intercept -0.5, 94% HDI [-1, 0.2]
1193+
C(group)[T.1] 2, 94% HDI [2, 2]
1194+
pre 1, 94% HDI [1, 1]
1195+
sigma 0.5, 94% HDI [0.5, 0.6]
11881196
"""
11891197

11901198
def __init__(

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_did():
4343
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
4444
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
4545
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
46+
result.summary()
4647

4748

4849
# TODO: set up fixture for the banks dataset
@@ -98,6 +99,7 @@ def test_did_banks_simple():
9899
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
99100
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
100101
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
102+
result.summary()
101103

102104

103105
@pytest.mark.integration
@@ -149,6 +151,7 @@ def test_did_banks_multi():
149151
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
150152
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
151153
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
154+
result.summary()
152155

153156

154157
@pytest.mark.integration
@@ -174,6 +177,7 @@ def test_rd():
174177
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
175178
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
176179
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
180+
result.summary()
177181

178182

179183
@pytest.mark.integration
@@ -200,6 +204,7 @@ def test_rd_bandwidth():
200204
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
201205
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
202206
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
207+
result.summary()
203208

204209

205210
@pytest.mark.integration
@@ -229,6 +234,7 @@ def test_rd_drinking():
229234
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
230235
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
231236
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
237+
result.summary()
232238

233239

234240
def setup_regression_kink_data(kink):
@@ -281,6 +287,7 @@ def test_rkink():
281287
assert isinstance(result, cp.pymc_experiments.RegressionKink)
282288
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
283289
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
290+
result.summary()
284291

285292

286293
@pytest.mark.integration
@@ -307,6 +314,7 @@ def test_rkink_bandwidth():
307314
assert isinstance(result, cp.pymc_experiments.RegressionKink)
308315
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
309316
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
317+
result.summary()
310318

311319

312320
@pytest.mark.integration
@@ -336,6 +344,7 @@ def test_its():
336344
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
337345
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
338346
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
347+
result.summary()
339348

340349

341350
@pytest.mark.integration
@@ -366,6 +375,7 @@ def test_its_covid():
366375
assert isinstance(result, cp.pymc_experiments.InterruptedTimeSeries)
367376
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
368377
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
378+
result.summary()
369379

370380

371381
@pytest.mark.integration
@@ -392,6 +402,7 @@ def test_sc():
392402
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
393403
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
394404
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
405+
result.summary()
395406

396407

397408
@pytest.mark.integration
@@ -430,6 +441,7 @@ def test_sc_brexit():
430441
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
431442
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
432443
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
444+
result.summary()
433445

434446

435447
@pytest.mark.integration
@@ -455,6 +467,7 @@ def test_ancova():
455467
assert isinstance(result, cp.pymc_experiments.PrePostNEGD)
456468
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
457469
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
470+
result.summary()
458471

459472

460473
@pytest.mark.integration
@@ -485,6 +498,7 @@ def test_geolift1():
485498
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
486499
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
487500
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
501+
result.summary()
488502

489503

490504
@pytest.mark.integration

docs/source/notebooks/ancova_pymc.ipynb

Lines changed: 41 additions & 38 deletions
Large diffs are not rendered by default.

docs/source/notebooks/did_pymc.ipynb

Lines changed: 27 additions & 34 deletions
Large diffs are not rendered by default.

docs/source/notebooks/did_pymc_banks.ipynb

Lines changed: 58 additions & 72 deletions
Large diffs are not rendered by default.

docs/source/notebooks/its_pymc.ipynb

Lines changed: 51 additions & 58 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_pymc.ipynb

Lines changed: 93 additions & 121 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_pymc_drinking.ipynb

Lines changed: 73 additions & 124 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rkink_pymc.ipynb

Lines changed: 95 additions & 124 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_pymc.ipynb

Lines changed: 47 additions & 62 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_pymc_brexit.ipynb

Lines changed: 42 additions & 59 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)