Skip to content

Commit

Permalink
Implemented counterfactualconsistency
Browse files Browse the repository at this point in the history
dscolby committed Oct 30, 2023
1 parent 4f11df8 commit 2c3841b
Showing 5 changed files with 209 additions and 32 deletions.
3 changes: 3 additions & 0 deletions src/CausalELM.jl
Original file line number Diff line number Diff line change
@@ -34,6 +34,9 @@ end

const summarise = summarize

"""
    consecutive(v; f="minus")

Return the differences between neighbouring elements of `v`, i.e. `v[i+1] - v[i]` for every
valid `i`. The result has one element fewer than `v`.

The keyword `f` is accepted but is not consulted: subtraction is always applied.

# Examples
```julia-repl
julia> consecutive([1, 2, 3, 4, 5])
4-element Vector{Int64}:
 1
 1
 1
 1
```
"""
function consecutive(v::Vector{<:Real}; f::String="minus")
    last_start = length(v) - 1

    return [v[idx+1] - v[idx] for idx in 1:last_start]
end

include("activation.jl")
include("models.jl")
include("metrics.jl")
2 changes: 1 addition & 1 deletion src/metrics.jl
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ julia> mse([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0])
4
```
"""
function mse(y::Vector{Float64}, ŷ::Vector{Float64})
function mse(y::Vector{<:Real}, ŷ::Vector{<:Real})
if length(y) !== length(ŷ)
throw(DimensionMismatch("y and ̂y must be the same length"))
end
187 changes: 165 additions & 22 deletions src/model_validation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module ModelValidation

using ..Estimators: InterruptedTimeSeries, estimatecausaleffect!, GComputation
using CausalELM: mean, var
using ..Metrics: mse
using CausalELM: mean, consecutive
using LinearAlgebra: norm

"""
@@ -98,7 +99,7 @@ function testcovariateindependence(its::InterruptedTimeSeries; n::Int=1000)
results = Dict{String, Float64}()

for i in 1:size(all_vars, 2)-2
y = all_vars[:, i]
y = all_vars[:, i]
β = first(x\y)
p = pval(x, y, β, n=n)
results["Column " * string(i) * " p-value"] = p
@@ -260,38 +261,180 @@ function pval(x::Array{Float64}, y::Array{Float64}, β::Float64; n::Int=1000,
return p
end

function counterfactualconsistency(g::GComputation)
treatment_covariates, treatment_outcomes = g.X[g.T == 1, :], g.Y[g.T == 1]
= treatment_covariates\treatment_outcomes
observed_residual_variance = var(ŷ)
"""
    counterfactualconsistency(g; num_treatments)

Examine the counterfactual consistency assumption.

First, this function generates Jenks breaks based on outcome values for the treatment
group. Then, it replaces treatment statuses with the numbers corresponding to each group.
Next, it runs two linear regressions, one with and one without the fake treatment
assignments generated by the Jenks breaks. Finally, it subtracts the mean squared error of
the regression with real data from the mean squared error of the regression with the fake
treatment statuses. If this number is negative, it might indicate a violation of the
counterfactual consistency assumption or omitted variable bias.

# Examples
```julia-repl
julia> x, y, t = rand(100, 5), vec(rand(1:100, 100, 1)),
       Float64.([rand()<0.4 for i in 1:100])
julia> g_computer = GComputation(x, y, t, temporal=false)
julia> estimatecausaleffect!(g_computer)
julia> counterfactualconsistency(g_computer)
2.7653668647301795
```
"""
function counterfactualconsistency(g::GComputation; num_treatments::Int=5)
    # Only observations whose treatment indicator equals 1 take part in the check
    treated = g.T .== 1
    treatment_covariates = g.X[treated, :]
    treatment_outcomes = g.Y[treated]

    # Regression on the observed covariates only
    β_real = treatment_covariates\treatment_outcomes
    ŷ_real = treatment_covariates*β_real

    # Regression on the covariates plus the fictitious treatment column
    fake_treat = fake_treatments(treatment_outcomes; num_treatments)
    augmented = hcat(treatment_covariates, fake_treat)
    β_fake = Real.(augmented\treatment_outcomes)
    ŷ_fake = Real.(augmented*β_fake)

    return mse(treatment_outcomes, ŷ_fake) - mse(treatment_outcomes, ŷ_real)
end

"""
    jenks(y, k)

Generate optimal Jenks breaks.

For every class count from 2 to `k`, the partition of the sorted data with the highest
goodness of variance fit is found. The partition returned is the one selected by comparing
the change in slope of those best fits across class counts.

Note that this finds the optimal breaks by enumerating every possible combination of
breaks, so this can take a while.

# Arguments
- `y`: the values to classify.
- `k`: the largest number of classes to consider (default 5). At least 4 class counts are
  needed to have a change in slope to compare.

# Throws
- `ArgumentError`: if `k < 4` or if `y` has fewer than `k` elements.

# Examples
```julia-repl
julia> jenks(collect(1:10))
Vector{<:Real}[Real[1, 2], Real[3, 4], Real[5, 6, 7], Real[8, 9, 10]]
```
"""
function jenks(y::Vector{<:Real}, k::Int=5)
    # With fewer than 4 class counts there is no change in slope to rank, and with fewer
    # than k observations no k-way split exists; fail early with a clear message.
    k >= 4 || throw(ArgumentError("k must be at least 4, got $k"))
    length(y) >= k ||
        throw(ArgumentError("y must have at least k = $k elements, got $(length(y))"))

    y_sorted = sort(y)
    best_gvfs = Vector{Float64}(undef, k-1)
    candidate_breaks = Vector{Vector{Vector{<:Real}}}(undef, k-1)

    # Iterate through every possible number of breaks, find the best splits for that number
    # of breaks, add its GVF and associated splits to best_gvfs and candidate_breaks
    for i in 2:k
        partitions = split_vector_ways(y_sorted; n=i)
        best_current_gvf, best_index = findmax(gvf.(partitions))
        best_gvfs[i-1], candidate_breaks[i-1] = best_current_gvf, partitions[best_index]
    end

    # Slope of the best GVF between consecutive class counts, then its change.
    # NOTE(review): findmax picks the largest value of slope[i+1] - slope[i]; confirm this
    # is the intended "largest decrease in slope" criterion.
    rise, run = consecutive(best_gvfs), consecutive(collect(2:k))
    slope_change = consecutive(rise ./ run)
    _, change_idx = findmax(slope_change)

    return candidate_breaks[change_idx+1]
end

"""
    sdam(x)

Calculate the sum of squared deviations from the array mean of `x`.

# Examples
```julia-repl
julia> sdam([5, 4, 9, 10])
26.0
```
"""
function sdam(x::Vector{T}) where T <: Real
    array_mean = mean(x)

    return @fastmath sum((x .- array_mean).^2)
end

"""
ned(a, b)
scdm(x)
Calculate the normalized Euclidean distance between two vectors. Before calculating the
normalized Euclidean distance, both vectors are sorted and padded with zeros if they are of
different lengths.
Calculate the sum of squared deviations for class means for a set of sub arrays.
Examples
```julia-repl
julia> ned([1, 1, 1], [0, 0])
01.0
julia> ned([1, 1], [0, 0])
0.7653668647301795
julia> scdm([[4], [5, 9, 10]])
14.0
```
"""
function ned(a::Vector{T}, b::Vector{T}) where T <: Number
if length(a) !== length(b)
if length(a) > length(b)
b = reduce(vcat, (b, zeros(abs(length(a)-length(b)))))
else
a = reduce(vcat, (a, zeros(abs(length(a)-length(b)))))
function scdm(x::Vector{Vector{T}}) where T <: Real
    # One squared-deviation total per class, then added together
    per_class = sdam.(x)

    return @fastmath sum(per_class)
end

"""
    gvf(x)

Calculate the goodness of variance fit for a set of sub vectors.

The value is the share of the squared deviation of the pooled data (`sdam` of all elements)
that is not left within the classes (`scdm` of the sub vectors).

# Examples
```julia-repl
julia> gvf([[4, 5], [9, 10]])
0.96153846153
```
"""
function gvf(x::Vector{Vector{T}}) where T <: Real
    pooled = collect(Iterators.flatten(x))
    total_deviation = sdam(pooled)

    return (total_deviation - scdm(x)) / total_deviation
end

"""
    split_vector_ways(x; n)

Find every possible way to split a vector into `n` contiguous, non-empty sub vectors.

An empty collection is returned when `x` has fewer than `n` elements.

# Arguments
- `x`: the vector to split.
- `n`: the number of sub vectors in each split (default 5).

# Throws
- `ArgumentError`: if `n` is less than 1.

# Examples
```julia-repl
julia> split_vector_ways([1, 2, 3, 4, 5], n=3)
Vector{Vector{Real}}[[[1], [2], [3, 4, 5]], [[1], [2, 3], [4, 5]], [[1], [2, 3, 4], [5]],
    [[1, 2], [3], [4, 5]], [[1, 2], [3, 4], [5]], [[1, 2, 3], [4], [5]]]
```
"""
function split_vector_ways(x::Vector{<:Real}; n::Int=5)
    # Without this guard n <= 0 never reaches the base case and slices past the end of x
    n >= 1 || throw(ArgumentError("n must be at least 1, got $n"))
    n == 1 && return [[x]]

    ways = Vector{Vector{Vector{Real}}}()

    # The first sub vector takes i elements; the rest must still leave room for n-1
    # non-empty sub vectors, hence the upper bound.
    for i in 1:(length(x)-n+1)
        first_part = [x[1:i]]
        rest = x[i+1:end]
        for way in split_vector_ways(rest; n=n-1)
            push!(ways, [first_part; way])
        end
    end
    return ways
end

"""
fake_treatments(treatment_outcomes; num_treatments)
# Changing NaN to zero fixes divde by zero errors
@fastmath norm(replace(sort(a)./norm(a), NaN=>0) .- replace((sort(b)./norm(b)), NaN=>0))
Generate a vector of fake treatment statuses based on Jenks breaks.
Examples
```julia-repl
julia> outcomes = ones(100)
100-element Vector{Float64}:
1.0
1.0
1.0
1.0
julia> fake_treatments(outcomes)
Real[3, 3, 3, 3, 3, 3, 3, 3, 3, 3 … 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
```
"""
function fake_treatments(treatment_outcomes::Vector{<:Real}; num_treatments::Int=5)
    statuses = Vector{Real}(undef, length(treatment_outcomes))
    partitions = jenks(treatment_outcomes, num_treatments)

    # Create fictitious treatment statuses based on Jenks breaks: each outcome is labelled
    # with the index of the class that contains it. There is deliberately no early exit, so
    # a value present in several classes keeps the index of the last one.
    for (i, outcome) in enumerate(treatment_outcomes)
        for (idx, partition) in enumerate(partitions)
            if outcome in partition
                statuses[i] = idx
            end
        end
    end
    return statuses
end

end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -6,6 +6,11 @@ import CausalELM
@test CausalELM.var([1, 2, 3]) == 1
end

# `consecutive` is called once with the `f` keyword and once with its default; both calls
# are expected to return the differences between neighbouring elements.
@testset "Add and Subtract Consecutive Elements" begin
@test CausalELM.consecutive([1, 2, 3, 4, 5], f="minus") == [1, 1, 1, 1]
@test CausalELM.consecutive([1, 2, 3, 4, 5]) == [1, 1, 1, 1]
end

include("test_activation.jl")
include("test_models.jl")
include("test_metrics.jl")
44 changes: 35 additions & 9 deletions test/test_model_validation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Test
using CausalELM.Estimators: InterruptedTimeSeries, estimatecausaleffect!
using CausalELM.Estimators: InterruptedTimeSeries, GComputation, estimatecausaleffect!
using CausalELM.ModelValidation: pval, testcovariateindependence, testomittedpredictor,
supwald, validate, ned
supwald, validate, sdam, scdm, gvf, split_vector_ways, consecutive, jenks,
fake_treatments,counterfactualconsistency

x₀, y₀, x₁, y₁ = Float64.(rand(1:5, 100, 5)), randn(100), rand(1:5, (10, 5)), randn(10)
its = InterruptedTimeSeries(x₀, y₀, x₁, y₁)
@@ -11,6 +12,19 @@ wald_test = supwald(its)
ovb = testomittedpredictor(its)
its_validation = validate(its)

# Random fixture: a 100x5 covariate matrix, 100 integer outcomes drawn from 1:100, and a
# treatment indicator that is 1.0 where rand() < 0.4 and 0.0 otherwise
x, y, t = rand(100, 5), vec(rand(1:100, 100, 1)), Float64.([rand()<0.4 for i in 1:100])
g_computer = GComputation(x, y, t, temporal=false)
estimatecausaleffect!(g_computer)
# Outcomes of the observations whose treatment indicator equals 1
test_outcomes = g_computer.Y[g_computer.T .== 1]

# Test splits for Jenks breaks
# Every way to cut [1, 2, 3, 4, 5] into three contiguous, non-empty parts
three_split = Vector{Vector{Real}}[[[1], [2], [3, 4, 5]],
[[1], [2, 3], [4, 5]], [[1], [2, 3, 4], [5]], [[1, 2], [3], [4, 5]],
[[1, 2], [3, 4], [5]], [[1, 2, 3], [4], [5]]]

# Every way to cut [1, 2, 3, 4, 5] into two contiguous, non-empty parts
two_split = Vector{Vector{Real}}[[[1], [2, 3, 4, 5]], [[1, 2], [3, 4, 5]],
[[1, 2, 3], [4, 5]], [[1, 2, 3, 4], [5]]]

@testset "p-value Argument Validation" begin
@test_throws ArgumentError pval(rand(10, 1), rand(10), 0.5)
@test_throws ArgumentError pval(rand(10, 3), rand(10), 0.5)
@@ -52,11 +66,23 @@ end
@test length(its_validation) === 3
end

@testset "Normailzed Euclidean Distance" begin
@test ned([1, 2, 3], [1, 2, 3]) === 0.0
@test ned([1, 1, 1], [0, 0, 0]) === 1.0
@test ned([1, 1, 1], [0, 0]) === 1.0
@test ned([0, 0], [1, 1, 1]) === 1.0
@test ned([0, 0, 0], [0, 0, 0]) === 0.0
@test ned([1, 1], [1, 0]) 0.76536686
# Examples taken from https://www.ehdp.com/methods/jenks-natural-breaks-2.htm
@testset "Jenks Breaks" begin
    @test sdam([5, 4, 9, 10]) == 26
    @test scdm([[4], [5, 9, 10]]) == 14
    # Floating point results are compared approximately
    @test gvf([[4, 5], [9, 10]]) ≈ 0.96153846153
    @test gvf([[4], [5], [9, 10]]) ≈ 0.9807692307692307
    @test split_vector_ways([1, 2, 3, 4, 5]; n=3) == three_split
    @test split_vector_ways([1, 2, 3, 4, 5]; n=2) == two_split
    # The returned classes together must contain all ten input values
    @test length(collect(Base.Iterators.flatten(jenks(collect(1:10), 5)))) == 10

    # No class may be empty
    for group in jenks(collect(1:10), 5)
        @test !isempty(group)
    end
end

@testset "Counterfactual Consistency" begin
# One fake treatment status is produced per treated outcome
@test length(fake_treatments(test_outcomes)) == length(test_outcomes)
# NOTE(review): test_outcomes comes from unseeded random data, so this assumes Jenks always
# selects four classes for it — confirm this cannot select a different class count
@test sort(unique(fake_treatments(test_outcomes))) == [1, 2, 3, 4]
@test counterfactualconsistency(g_computer) isa Real
end

0 comments on commit 2c3841b

Please sign in to comment.