Skip to content

Commit

Permalink
Really fixed confidence intervals
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed Nov 6, 2024
1 parent 376a4e0 commit 7198361
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
45 changes: 27 additions & 18 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ function quantities_of_interest(mod, n)
null_dist = generate_null_distribution(mod, n)
avg_effect = mod isa Metalearner ? mean(mod.causal_effect) : mod.causal_effect
pvalue, stderr = p_value_and_std_err(null_dist, avg_effect)
lb, ub = confidence_interval(null_dist)
lb, ub = confidence_interval(null_dist, avg_effect)

return pvalue, stderr, lb, ub
end
Expand All @@ -268,13 +268,13 @@ function quantities_of_interest(mod::InterruptedTimeSeries, n, mean_effect)
metric = ifelse(mean_effect, mean, sum)
effect = metric(mod.causal_effect)
pvalue, stderr = p_value_and_std_err(null_dist, effect)
lb, ub = confidence_interval(null_dist)
lb, ub = confidence_interval(null_dist, effect)

return pvalue, stderr, lb, ub
end

"""
confidence_interval(null_dist)
confidence_interval(null_dist, effect)
Compute 95% confidence intervals via randomization inference.
Expand All @@ -289,24 +289,33 @@ julia> x, t, y = rand(100, 5), [rand()<0.4 for i in 1:100], rand(1:100, 100, 1)
julia> g_computer = GComputation(x, t, y)
julia> estimate_causal_effect!(g_computer)
julia> null_dist = CausalELM.generate_null_distribution(g_computer, 1000)
julia> confidence_interval(null_dist)
julia> confidence_interval(null_dist, g_computer.causal_effect)
(-0.45147664642089147, 0.45147664642089147)
```
"""
function confidence_interval(null_dist)
sorted_null_dist, n = sort(null_dist), length(null_dist)
low_idx, high_idx = 0.025 * (n - 1), 0.975 * (n - 1)

lb = if isinteger(low_idx)
sorted_null_dist[Int(low_idx)]
else
mean(sorted_null_dist[floor(Int, low_idx):ceil(Int, low_idx)])
end

ub = if isinteger(high_idx)
sorted_null_dist[Int(high_idx)]
else
mean(sorted_null_dist[floor(Int, high_idx):ceil(Int, high_idx)])
function confidence_interval(null_dist, effect)
# Grid to search that probably includes the lower and upper bounds and is pretty precise
max_magnitude_val = maximum(abs.(null_dist))
grid = range(
start=effect - 2max_magnitude_val,
stop=effect + 2max_magnitude_val,
length=4length(null_dist)
)
lb, ub = Inf, -Inf
low_idx, high_idx = 1, length(grid)

# Start from the smallest and largest values until we get p > 0.05 and break out
while (isinf(lb) || isinf(ub)) && (low_idx < high_idx)
left_p_val, _ = p_value_and_std_err(null_dist, grid[low_idx])
right_p_val, _ = p_value_and_std_err(null_dist, grid[high_idx])

lb = left_p_val > 0.05 && isinf(lb) ? grid[low_idx] : lb
ub = right_p_val > 0.05 && isinf(ub) ? grid[high_idx] : ub

(isinf(lb) == false && isinf(ub) == false) && break

low_idx += 1
high_idx -= 1
end

return lb, ub
Expand Down
24 changes: 17 additions & 7 deletions test/test_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ g_computer = GComputation(x, t, y)
estimate_causal_effect!(g_computer)
g_inference = CausalELM.generate_null_distribution(g_computer, 1000)
p1, stderr1 = CausalELM.p_value_and_std_err(g_inference, CausalELM.mean(g_inference))
lb1, ub1 = CausalELM.confidence_interval(g_inference)
lb1, ub1 = CausalELM.confidence_interval(g_inference, g_computer.causal_effect)
p11, stderr11, lb11, ub11 = CausalELM.quantities_of_interest(g_computer, 100)
summary1 = summarize(g_computer, n=100, inference=true)

dm = DoubleMachineLearning(x, t, y)
estimate_causal_effect!(dm)
dm_inference = CausalELM.generate_null_distribution(dm, 1000)
p2, stderr2 = CausalELM.p_value_and_std_err(dm_inference, CausalELM.mean(dm_inference))
lb2, ub2 = CausalELM.confidence_interval(dm_inference)
lb2, ub2 = CausalELM.confidence_interval(dm_inference, dm.causal_effect)
summary2 = summarize(dm, n=100)

# With a continuous treatment variable
Expand All @@ -27,7 +27,9 @@ dm_continuous_inference = CausalELM.generate_null_distribution(dm_continuous, 10
p3, stderr3 = CausalELM.p_value_and_std_err(
dm_continuous_inference, CausalELM.mean(dm_continuous_inference)
)
lb3, ub3 = CausalELM.confidence_interval(dm_continuous_inference)
lb3, ub3 = CausalELM.confidence_interval(
dm_continuous_inference, dm_continuous.causal_effect
)
summary3 = summarize(dm_continuous, n=100)

x₀, y₀, x₁, y₁ = rand(1:100, 100, 5), rand(100), rand(10, 5), rand(10)
Expand All @@ -39,7 +41,9 @@ summary4_inference = summarize(its, n=100, inference=true)
# Null distributions for the mean and cumulative changes
its_inference1 = CausalELM.generate_null_distribution(its, 1000, true)
its_inference2 = CausalELM.generate_null_distribution(its, 10, false)
lb4, ub4 = CausalELM.confidence_interval(its_inference1)
lb4, ub4 = CausalELM.confidence_interval(
its_inference1, CausalELM.mean(its.causal_effect)
)
p4, stderr4 = CausalELM.p_value_and_std_err(its_inference1, CausalELM.mean(its_inference1))
p44, stderr44, lb44, ub44 = CausalELM.quantities_of_interest(its, 100, true)

Expand All @@ -50,7 +54,9 @@ summary5 = summarize(slearner, n=100)
tlearner = TLearner(x, t, y)
estimate_causal_effect!(tlearner)
tlearner_inference = CausalELM.generate_null_distribution(tlearner, 1000)
lb6, ub6 = CausalELM.confidence_interval(tlearner_inference)
lb6, ub6 = CausalELM.confidence_interval(
tlearner_inference, CausalELM.mean(tlearner.causal_effect)
)
p6, stderr6 = CausalELM.p_value_and_std_err(
tlearner_inference, CausalELM.mean(tlearner_inference)
)
Expand All @@ -60,7 +66,9 @@ summary6 = summarize(tlearner, n=100)
xlearner = XLearner(x, t, y)
estimate_causal_effect!(xlearner)
xlearner_inference = CausalELM.generate_null_distribution(xlearner, 1000)
lb7, ub7 = CausalELM.confidence_interval(xlearner_inference)
lb7, ub7 = CausalELM.confidence_interval(
xlearner_inference, CausalELM.mean(xlearner.causal_effect)
)
p7, stderr7 = CausalELM.p_value_and_std_err(
xlearner_inference, CausalELM.mean(xlearner_inference)
)
Expand All @@ -74,7 +82,9 @@ summary9 = summarize(rlearner, n=100)
dr_learner = DoublyRobustLearner(x, t, y)
estimate_causal_effect!(dr_learner)
dr_learner_inference = CausalELM.generate_null_distribution(dr_learner, 1000)
lb8, ub8 = CausalELM.confidence_interval(dr_learner_inference)
lb8, ub8 = CausalELM.confidence_interval(
dr_learner_inference, CausalELM.mean(dr_learner.causal_effect)
)
p8, stderr8 = CausalELM.p_value_and_std_err(
dr_learner_inference, CausalELM.mean(dr_learner_inference)
)
Expand Down

0 comments on commit 7198361

Please sign in to comment.