Skip to content

Commit

Permalink
Added marginal effects to summaries
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed Nov 28, 2024
1 parent 31dcb21 commit 41030b6
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 26 deletions.
3 changes: 2 additions & 1 deletion docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
## Version [0.8.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.8.0) - 2024-10-31
### Added
* Implemented randomization inference-based confidence intervals [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
* Added marginal effects to model summaries [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
### Fixed
* Removed unnecessary include and using statements
* Slightly sped up the randomization inference implementation and clarified it in the docs [#77](https://github.com/dscolby/CausalELM.jl/issues/77)
* Fixed the randomization inference index selection procedure for interrupted time series estimators
* Inlined certain methods to slightly improve performance [#79](https://github.com/dscolby/CausalELM.jl/issues/79)
* Inlined certain methods to slightly improve performance [#76](https://github.com/dscolby/CausalELM.jl/issues/76)

## Version [v0.7.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.7.0) - 2024-06-22
### Added
Expand Down
53 changes: 38 additions & 15 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ mutable struct InterruptedTimeSeries
Y₀::Array{Float64}
X₁::Array{Float64}
Y₁::Array{Float64}
marginal_effect::Float64
@model_config individual_effect
end

Expand Down Expand Up @@ -77,6 +78,7 @@ function InterruptedTimeSeries(
float(Y₀),
X₁,
float(Y₁),
NaN,
"difference",
true,
task,
Expand Down Expand Up @@ -137,6 +139,7 @@ julia> m5 = GComputation(x_df, t_df, y_df)
mutable struct GComputation <: CausalEstimator
@standard_input_data
@model_config average_effect
marginal_effect::Float64
ensemble::ELMEnsemble

function GComputation(
Expand Down Expand Up @@ -173,6 +176,7 @@ mutable struct GComputation <: CausalEstimator
num_feats,
num_neurons,
NaN,
NaN,
)
end
end
Expand Down Expand Up @@ -220,6 +224,7 @@ julia> m2 = DoubleMachineLearning(x_df, t_df, y_df)
mutable struct DoubleMachineLearning <: CausalEstimator
@standard_input_data
@model_config average_effect
marginal_effect::Float64
folds::Integer
end

Expand Down Expand Up @@ -256,6 +261,7 @@ function DoubleMachineLearning(
num_feats,
num_neurons,
NaN,
NaN,
folds,
)
end
Expand Down Expand Up @@ -285,6 +291,7 @@ julia> estimate_causal_effect!(m1)

fit!(learner)
its.causal_effect = predict(learner, its.X₁) .- its.Y₁
its.marginal_effect = mean(its.causal_effect)

return its.causal_effect
end
Expand All @@ -309,7 +316,9 @@ julia> estimate_causal_effect!(m1)
```
"""
@inline function estimate_causal_effect!(g::GComputation)
g.causal_effect = mean(g_formula!(g))
causal_effect, marginal_effect = g_formula!(g)
g.causal_effect, g.marginal_effect = mean(causal_effect), mean(marginal_effect)

return g.causal_effect
end

Expand All @@ -330,6 +339,7 @@ julia> g_formula!(m2)
"""
@inline function g_formula!(g) # Keeping this separate for reuse with S-Learning
covariates, y = hcat(g.X, g.T), g.Y
x₁, x₀ = hcat(g.X, ones(size(g.X, 1))), hcat(g.X, zeros(size(g.X, 1)))

if g.quantity_of_interest ("ITT", "ATE", "CATE")
Xₜ = hcat(covariates[:, 1:(end - 1)], ones(size(covariates, 1)))
Expand All @@ -350,10 +360,9 @@ julia> g_formula!(m2)
)

fit!(g.ensemble)

yₜ, yᵤ = predict(g.ensemble, Xₜ), predict(g.ensemble, Xᵤ)

return vec(yₜ) - vec(yᵤ)
return vec(yₜ) - vec(yᵤ), predict(g.ensemble, x₁) - predict(g.ensemble, x₀)
end

"""
Expand All @@ -374,27 +383,35 @@ julia> estimate_causal_effect!(m2)
"""
@inline function estimate_causal_effect!(DML::DoubleMachineLearning)
X, T, Y = generate_folds(DML.X, DML.T, DML.Y, DML.folds)
DML.causal_effect = 0
DML.causal_effect, DML.marginal_effect = 0, 0
Δ = var_type(DML.T) isa Binary ? 1.0 : 1.5e-8mean(DML.T)

# Cross fitting by training on the main folds and predicting residuals on the auxillary
for fld in 1:(DML.folds)
X_train, X_test = reduce(vcat, X[1:end .!== fld]), X[fld]
Y_train, Y_test = reduce(vcat, Y[1:end .!== fld]), Y[fld]
T_train, T_test = reduce(vcat, T[1:end .!== fld]), T[fld]

Ỹ, T̃ = predict_residuals(DML, X_train, X_test, Y_train, Y_test, T_train, T_test)
for fold in 1:(DML.folds)
X_train, X_test = reduce(vcat, X[1:end .!== fold]), X[fold]
Y_train, Y_test = reduce(vcat, Y[1:end .!== fold]), Y[fold]
T_train, T_test = reduce(vcat, T[1:end .!== fold]), T[fold]
T_train₊ = var_type(DML.T) isa Binary ? T_train .* 0 : T_train .+ Δ

Ỹ, T̃, T̃₊ = predict_residuals(
DML, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
)

DML.causal_effect +=\
DML.marginal_effect += (T̃₊\- DML.causal_effect) / Δ
end

DML.causal_effect /= DML.folds
DML.marginal_effect /= DML.folds

return DML.causal_effect
end

"""
predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test)
predict_residuals(D, x_train, x_test, y_train, y_test, t_train, t_test, t_train₊)
Predict treatment and outcome residuals for double machine learning or R-learning.
Predict treatment, outcome, and marginal effect residuals for double machine learning or
R-learning.
# Notes
This method should not be called directly.
Expand All @@ -406,7 +423,7 @@ julia> x_train, x_test = X[1:80, :], X[81:end, :]
julia> y_train, y_test = Y[1:80], Y[81:end]
julia> t_train, t_test = T[1:80], T[81:100]
julia> m1 = DoubleMachineLearning(X, T, Y)
julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test, zeros(100))
```
"""
@inline function predict_residuals(
Expand All @@ -417,6 +434,7 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
yₜₑ::Vector{Float64},
tₜᵣ::Vector{Float64},
tₜₑ::Vector{Float64},
tₜᵣ₊::Vector{Float64}
)
y = ELMEnsemble(
xₜᵣ, yₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
Expand All @@ -426,12 +444,17 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
xₜᵣ, tₜᵣ, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
)

t₊ = ELMEnsemble(
xₜᵣ, tₜᵣ₊, D.sample_size, D.num_machines, D.num_feats, D.num_neurons, D.activation
)

fit!(y)
fit!(t)
fit!(t₊) # Estimate a model with T + a finite difference

yₚᵣ, tₚᵣ = predict(y, xₜₑ), predict(t, xₜₑ)
yₚᵣ, tₚᵣ, tₚᵣ₊ = predict(y, xₜₑ), predict(t, xₜₑ), predict(t₊, xₜₑ)

return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ
return yₜₑ - yₚᵣ, tₜₑ - tₚᵣ, tₜₑ - tₚᵣ₊
end

"""
Expand Down
12 changes: 8 additions & 4 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ function summarize(mod; kwargs...)
"Standard Error",
"p-value",
"Lower 2.5% CI",
"Upper 97.5% CI"
"Upper 97.5% CI",
"Marginal Effect"
]

if haskey(kwargs, :inference) && kwargs[:inference] == true
Expand All @@ -82,7 +83,8 @@ function summarize(mod; kwargs...)
stderr,
p,
lower_ci,
upper_ci
upper_ci,
mod.marginal_effect
]

for (nicename, value) in zip(nicenames, values)
Expand Down Expand Up @@ -124,7 +126,8 @@ function summarize(its::InterruptedTimeSeries; kwargs...)
"Standard Error",
"p-value",
"Lower 2.5% CI",
"Upper 97.5% CI"
"Upper 97.5% CI",
"Marginal Effect"
]

values = [
Expand All @@ -140,7 +143,8 @@ function summarize(its::InterruptedTimeSeries; kwargs...)
stderr,
p,
l,
u
u,
its.marginal_effect
]

for (nicename, value) in zip(nicenames, values)
Expand Down
32 changes: 27 additions & 5 deletions src/metalearners.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ julia> m4 = SLearner(x_df, t_df, y_df)
mutable struct SLearner <: Metalearner
@standard_input_data
@model_config individual_effect
marginal_effect::Vector{Float64}
ensemble::ELMEnsemble

function SLearner(
Expand Down Expand Up @@ -76,6 +77,7 @@ mutable struct SLearner <: Metalearner
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
fill(NaN, size(T, 1)),
)
end
end
Expand Down Expand Up @@ -123,6 +125,7 @@ julia> m3 = TLearner(x_df, t_df, y_df)
mutable struct TLearner <: Metalearner
@standard_input_data
@model_config individual_effect
marginal_effect::Vector{Float64}
μ₀::ELMEnsemble
μ₁::ELMEnsemble

Expand Down Expand Up @@ -154,6 +157,7 @@ mutable struct TLearner <: Metalearner
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
fill(NaN, size(T, 1)),
)
end
end
Expand Down Expand Up @@ -201,6 +205,7 @@ julia> m3 = XLearner(x_df, t_df, y_df)
mutable struct XLearner <: Metalearner
@standard_input_data
@model_config individual_effect
marginal_effect::Vector{Float64}
μ₀::ELMEnsemble
μ₁::ELMEnsemble
ps::Array{Float64}
Expand Down Expand Up @@ -233,6 +238,7 @@ mutable struct XLearner <: Metalearner
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
fill(NaN, size(T, 1)),
)
end
end
Expand Down Expand Up @@ -278,6 +284,7 @@ julia> m2 = RLearner(x_df, t_df, y_df)
mutable struct RLearner <: Metalearner
@standard_input_data
@model_config individual_effect
marginal_effect::Vector{Float64}
folds::Integer
end

Expand Down Expand Up @@ -315,6 +322,7 @@ function RLearner(
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
fill(NaN, size(T, 1)),
folds,
)
end
Expand Down Expand Up @@ -363,6 +371,7 @@ julia> m3 = DoublyRobustLearner(X, T, Y, W=w)
mutable struct DoublyRobustLearner <: Metalearner
@standard_input_data
@model_config individual_effect
marginal_effect::Vector{Float64}
folds::Integer
end

Expand Down Expand Up @@ -398,6 +407,7 @@ function DoublyRobustLearner(
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
fill(NaN, size(T, 1)),
2,
)
end
Expand All @@ -421,7 +431,7 @@ julia> estimate_causal_effect!(m4)
```
"""
@inline function estimate_causal_effect!(s::SLearner)
s.causal_effect = g_formula!(s)
s.causal_effect, s.marginal_effect = g_formula!(s)
return s.causal_effect
end

Expand Down Expand Up @@ -458,6 +468,7 @@ julia> estimate_causal_effect!(m5)
fit!(t.μ₁)
predictionsₜ, predictionsᵪ = predict(t.μ₁, t.X), predict(t.μ₀, t.X)
t.causal_effect = @fastmath vec(predictionsₜ - predictionsᵪ)
t.marginal_effect = t.causal_effect

return t.causal_effect
end
Expand Down Expand Up @@ -488,6 +499,8 @@ julia> estimate_causal_effect!(m1)
(x.ps .* predict(μχ₀, x.X)) .+ ((1 .- x.ps) .* predict(μχ₁, x.X))
))

x.marginal_effect = x.causal_effect # Works since T is binary

return x.causal_effect
end

Expand All @@ -510,26 +523,34 @@ julia> estimate_causal_effect!(m1)
"""
@inline function estimate_causal_effect!(R::RLearner)
X, T̃, Ỹ = generate_folds(R.X, R.T, R.Y, R.folds)
T̃₊, Δ = similar(T̃), var_type(R.T) isa Binary ? 1.0 : 1.5e-8mean(R.T)
R.X, R.T, R.Y = reduce(vcat, X), reduce(vcat, T̃), reduce(vcat, Ỹ)

# Get residuals from out-of-fold predictions
for f in 1:(R.folds)
X_train, X_test = reduce(vcat, X[1:end .!== f]), X[f]
Y_train, Y_test = reduce(vcat, Ỹ[1:end .!== f]), Ỹ[f]
T_train, T_test = reduce(vcat, T̃[1:end .!== f]), T̃[f]
Ỹ[f], T̃[f] = predict_residuals(R, X_train, X_test, Y_train, Y_test, T_train, T_test)
T_train₊ = var_type(R.T) isa Binary ? T_train .* 0 : T_train .+ Δ
Ỹ[f], T̃[f], T̃₊[f] = predict_residuals(
R, X_train, X_test, Y_train, Y_test, T_train, T_test, T_train₊
)
end

# Using target transformation and the weight trick to minimize the causal loss
T̃², target = reduce(vcat, T̃).^2, reduce(vcat, Ỹ) ./ reduce(vcat, T̃)
Xʷ, Yʷ = R.X .* T̃², target .* T̃²

# Fit a weighted residual-on-residual model
T̃²₊, target₊ = reduce(vcat, T̃₊).^2, reduce(vcat, Ỹ) ./ reduce(vcat, T̃₊)
final_model = ELMEnsemble(
Xʷ, Yʷ, R.sample_size, R.num_machines, R.num_feats, R.num_neurons, R.activation
)
fit!(final_model)

# Using finite differences to calculate marginal effects
final_model₊ = deepcopy(final_model)
final_model₊.X, final_model₊.Y = R.X .* T̃²₊, target₊ .* T̃²₊
fit!(final_model); fit!(final_model₊)
R.causal_effect = predict(final_model, R.X)
R.marginal_effect = (predict(final_model₊, final_model.X) - R.causal_effect) ./ Δ

return R.causal_effect
end
Expand Down Expand Up @@ -563,6 +584,7 @@ julia> estimate_causal_effect!(m1)

causal_effect ./= 2
DRE.causal_effect = causal_effect
DRE.marginal_effect = causal_effect

return DRE.causal_effect
end
Expand Down
Loading

0 comments on commit 41030b6

Please sign in to comment.