Commit

Implemented Tables.jl API and support for GPUs
dscolby committed Dec 15, 2024
1 parent d5cabb6 commit 1169985
Showing 13 changed files with 131 additions and 64 deletions.
34 changes: 33 additions & 1 deletion Manifest.toml
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "ef638d9b7dd3411a6b5c86406ac77e48f19d8d42"
project_hash = "48b0ecc3de09367019241b9866f1be8d1ab8f4cc"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand All @@ -13,6 +13,21 @@ deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0"

[[deps.DataAPI]]
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0"

[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"

[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
version = "1.11.0"
Expand All @@ -27,6 +42,11 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.27+1"

[[deps.OrderedCollections]]
git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.7.0"

[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -36,6 +56,18 @@ version = "1.11.0"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.1"

[[deps.Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.12.0"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Expand Up @@ -6,13 +6,15 @@ version = "0.8.0"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Aqua = "0.8"
DataFrames = "1.5"
Documenter = "1.2"
LinearAlgebra = "1.8"
Random = "1.8"
Tables = "1.12.0"
Test = "1.8"
julia = "1.8"

Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Expand Up @@ -112,4 +112,5 @@ CausalELM.clip_if_binary
CausalELM.@model_config
CausalELM.@standard_input_data
CausalELM.generate_folds
CausalELM.convert_if_table
```
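
The body of `convert_if_table` is not shown in this part of the diff. As a minimal sketch only, assuming the helper passes numeric arrays through unchanged and materializes any Tables.jl-compatible input as a matrix, it might look like the following. The name `convert_if_table_sketch` is hypothetical and the package's actual implementation may differ; `Tables.istable` and `Tables.matrix` are part of the Tables.jl API.

```julia
using Tables

# Hypothetical stand-in for CausalELM.convert_if_table (assumption, not the
# package's code). Numeric arrays are returned as-is, without copying.
convert_if_table_sketch(x::AbstractArray{<:Real}) = x

# Anything else must satisfy the Tables.jl interface and is turned into a Matrix.
function convert_if_table_sketch(x)
    Tables.istable(x) ||
        throw(ArgumentError("expected an AbstractArray or a Tables.jl table"))
    return Tables.matrix(x)
end

# A NamedTuple of equal-length vectors is a valid Tables.jl column table.
cols = (age = [31.0, 45.0, 27.0], income = [40.0, 85.0, 52.0])
M = convert_if_table_sketch(cols)   # 3×2 matrix, one column per field

A = rand(3, 2)
B = convert_if_table_sketch(A)      # same object as A
```

Dispatching on the argument type keeps the array path free of any Tables.jl overhead, which would matter if the input is a GPU array.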
6 changes: 3 additions & 3 deletions docs/src/guide/doublemachinelearning.md
Expand Up @@ -16,9 +16,9 @@ the residuals from the first stage models.

## Step 1: Initialize a Model
The DoubleMachineLearning constructor takes at least three arguments: covariates,
treatment statuses, and outcomes, all of which may be either an array or any struct that
implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary, count,
or continuous treatments and binary, count, continuous, or time to event outcomes.
treatment statuses, and outcomes, all of which may be either an AbstractArray or any struct
that implements the Tables.jl interface (e.g. DataFrames). This estimator supports binary,
count, or continuous treatments and binary, count, continuous, or time to event outcomes.

!!! note
Non-binary categorical outcomes are treated as continuous.
Expand Down
6 changes: 3 additions & 3 deletions docs/src/guide/gcomputation.md
Expand Up @@ -16,9 +16,9 @@ steps for using G-computation in CausalELM are below.

## Step 1: Initialize a Model
The GComputation constructor takes at least three arguments: covariates, treatment statuses,
outcomes, all of which can be either an array or any data structure that implements the
Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments and
binary, continuous, time to event, and count outcome variables.
and outcomes, all of which can be either an AbstractArray or any data structure that implements
the Tables.jl interface (e.g. DataFrames). This implementation supports binary treatments
and binary, continuous, time to event, and count outcome variables.

!!! note
Non-binary categorical outcomes are treated as continuous.
Expand Down
2 changes: 1 addition & 1 deletion docs/src/guide/its.md
Expand Up @@ -31,7 +31,7 @@ Estimating an interrupted time series design in CausalELM consists of three step
## Step 1: Initialize an interrupted time series estimator
The InterruptedTimeSeries constructor takes at least four arguments: pre-event covariates,
pre-event outcomes, post-event covariates, and post-event outcomes, all of which can be
either an array or any data structure that implements the Tables.jl interface (e.g.
either an AbstractArray or any data structure that implements the Tables.jl interface (e.g.
DataFrames). The interrupted time series estimator assumes outcomes are either continuous,
count, or time to event variables.

Expand Down
15 changes: 8 additions & 7 deletions docs/src/guide/metalearners.md
Expand Up @@ -30,13 +30,14 @@ continuous outcomes.
Kennedy, Edward H. "Towards optimal doubly robust estimation of heterogeneous causal
effects." Electronic Journal of Statistics 17, no. 2 (2023): 3008-3049.

# Initialize a Metalearner
# Step 1: Initialize a Metalearner
S-learners, T-learners, X-learners, R-learners, and doubly robust estimators all take at
least three arguments—covariates, treatment statuses, and outcomes, all of which can be
either an array or any struct that implements the Tables.jl interface (e.g. DataFrames). S,
T, X, and doubly robust learners support binary treatment variables and binary, continuous,
count, or time to event outcomes. The R-learning estimator supports binary, continuous, or
count treatment variables and binary, continuous, count, or time to event outcomes.
either an AbstractArray or any struct that implements the Tables.jl interface (e.g.
DataFrames). S, T, X, and doubly robust learners support binary treatment variables and
binary, continuous, count, or time to event outcomes. The R-learning estimator supports
binary, continuous, or count treatment variables and binary, continuous, count, or time to
event outcomes.

!!! note
Non-binary categorical outcomes are treated as continuous.
Expand Down Expand Up @@ -64,7 +65,7 @@ r_learner = RLearner(X, Y, T)
dr_learner = DoublyRobustLearner(X, T, Y)
```

# Estimate the CATE
# Step 2: Estimate the CATE
We can estimate the CATE for all the models by passing them to estimate_causal_effect!.
```julia
estimate_causal_effect!(s_learner)
Expand All @@ -74,7 +75,7 @@ estimate_causal_effect!(r_learner)
estimate_causal_effect!(dr_learner)
```

# Get a Summary
# Step 3: Get a Summary
We can get a summary of the model by passing the model to the summarize method.

!!! note
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Expand Up @@ -75,7 +75,7 @@ these libraries are:
econometrics, and biostatistics.

### Installation
CausalELM requires Julia version 1.7 or greater and can be installed from the REPL as shown
CausalELM requires Julia version 1.8 or greater and can be installed from the REPL as shown
below.
```julia
using Pkg
Expand Down
2 changes: 2 additions & 0 deletions docs/src/release_notes.md
Expand Up @@ -5,11 +5,13 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
### Added
* Implemented randomization inference-based confidence intervals [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
* Added marginal effects to model summaries [#78](https://github.com/dscolby/CausalELM.jl/issues/78)
* CausalELM models now support any AbstractArray data type, including GPU arrays such as CuArrays or similar structures for Mac, Intel, and AMD hardware [#37](https://github.com/dscolby/CausalELM.jl/issues/37)
### Fixed
* Removed unnecessary include and using statements
* Slightly sped up the randomization inference implementation and clarified it in the docs [#77](https://github.com/dscolby/CausalELM.jl/issues/77)
* Fixed the randomization inference index selection procedure for interrupted time series estimators
* Inlined certain methods to slightly improve performance [#76](https://github.com/dscolby/CausalELM.jl/issues/76)
* CausalELM models now support any data structure that implements the Tables.jl API, not just DataFrames

## Version [v0.7.0](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.7.0) - 2024-06-22
### Added
Expand Down
39 changes: 22 additions & 17 deletions src/estimators.jl
Expand Up @@ -7,10 +7,14 @@ abstract type CausalEstimator end
Initialize an interrupted time series estimator.
# Arguments
- `X₀::Any`: array or DataFrame of covariates from the pre-treatment period.
- `Y₁::Any`: array or DataFrame of outcomes from the pre-treatment period.
- `X₁::Any`: array or DataFrame of covariates from the post-treatment period.
- `Y₁::Any`: array or DataFrame of outcomes from the post-treatment period.
- `X₀::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
pre-treatment period.
- `Y₀::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
pre-treatment period.
- `X₁::Any`: AbstractArray or Tables.jl API compliant data structure of covariates from the
post-treatment period.
- `Y₁::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes from the
post-treatment period.
# Keywords
- `activation::Function=swish`: activation function to use.
Expand Down Expand Up @@ -44,10 +48,10 @@ julia> m3 = InterruptedTimeSeries(x₀_df, y₀_df, x₁_df, y₁_df)
```
"""
mutable struct InterruptedTimeSeries
X₀::Array{Float64}
Y₀::Array{Float64}
X₁::Array{Float64}
Y₁::Array{Float64}
X₀::AbstractArray{<: Real}
Y₀::AbstractArray{<: Real}
X₁::AbstractArray{<: Real}
Y₁::AbstractArray{<: Real}
marginal_effect::Float64
@model_config individual_effect
end
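
The hunk above loosens the field annotations from `Array{Float64}` to `AbstractArray{<: Real}`. A small sketch of what that change permits, using a hypothetical `Holder` struct rather than the package's own type: any array-like value with a real element type can be stored without first being copied into a dense `Float64` array.

```julia
# Hypothetical struct mirroring the new field annotation; not part of CausalELM.
struct Holder
    X::AbstractArray{<:Real}
end

base = rand(4, 2)

dense   = Holder(rand(Float32, 4, 2))   # Float32 matrix is stored as-is
sliced  = Holder(view(base, 1:2, :))    # SubArray that shares memory with base
counted = Holder(1:4)                   # integer range, no materialization

# Writing through the view is visible in the parent array.
sliced.X[1, 1] = -1.0
```

One general Julia trade-off worth keeping in mind: fields with abstract type annotations are not concretely typed, so code that reads them may fall back to dynamic dispatch. A parametric struct is a common alternative when that cost matters.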
Expand All @@ -65,7 +69,7 @@ function InterruptedTimeSeries(
autoregression::Bool=true,
)
# Convert to arrays
X₀, X₁, Y₀, Y₁ = Matrix{Float64}(X₀), Matrix{Float64}(X₁), Y₀[:, 1], Y₁[:, 1]
X₀, X₁, Y₀, Y₁ = convert_if_table.((X₀, X₁, Y₀, Y₁))

# Add autoregressive term
X₀ = ifelse(autoregression == true, reduce(hcat, (X₀, moving_average(Y₀))), X₀)
Expand Down Expand Up @@ -97,9 +101,9 @@ end
Initialize a G-Computation estimator.
# Arguments
- `X::Any`: array or DataFrame of covariates.
- `T::Any`: vector or DataFrame of treatment statuses.
- `Y::Any`: array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `quantity_of_interest::String`: ATE for average treatment effect or ATT for average
Expand Down Expand Up @@ -159,7 +163,7 @@ mutable struct GComputation <: CausalEstimator
end

# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

task = var_type(Y) isa Binary ? "classification" : "regression"
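
The new conversion line relies on broadcasting a function over a tuple and destructuring the result. A minimal illustration of that pattern, with a hypothetical `to_float` function standing in for `convert_if_table`:

```julia
# Stand-in for a per-argument conversion; the real helper is convert_if_table.
to_float(x) = float.(x)

X, T, Y = [1 2; 3 4], [0, 1], [5, 6]

# The trailing dot applies to_float once per element of the tuple and returns a
# tuple of results, which is then destructured back into the three names.
X, T, Y = to_float.((X, T, Y))
```

Unlike the replaced line, this form treats all three inputs identically. Whether a first-column slice of `T` and `Y` still happens depends on the helper itself, which this hunk does not show.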

Expand Down Expand Up @@ -187,9 +191,10 @@ end
Initialize a double machine learning estimator with cross fitting.
# Arguments
- `X::Any`: array or DataFrame of covariates of interest.
- `T::Any`: vector or DataFrame of treatment statuses.
- `Y::Any`: array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
interest.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: activation function to use.
Expand Down Expand Up @@ -240,7 +245,7 @@ function DoubleMachineLearning(
folds::Integer=5,
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

# Shuffle data with random indices
indices = shuffle(1:length(Y))
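
How `indices` is applied after this point is cut off in the hunk. As an illustration of the usual idiom, assuming the same permutation is used to reorder the rows of every input so that observations stay aligned, a sketch with made-up data and a hypothetical seed:

```julia
using Random

# Toy data: the second column of X is ten times the first, so row alignment
# is easy to check after shuffling.
X = [1.0 10.0; 2.0 20.0; 3.0 30.0]
T = [0, 1, 0]
Y = [1.5, 2.5, 3.5]

Random.seed!(1)                       # hypothetical seed, only for repeatability
indices = shuffle(1:length(Y))        # a random permutation of 1:3

# Index every input with the same permutation.
Xs, Ts, Ys = X[indices, :], T[indices], Y[indices]
```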
Expand Down
42 changes: 22 additions & 20 deletions src/metalearners.jl
Expand Up @@ -7,9 +7,9 @@ abstract type Metalearner end
Initialize an S-Learner.
# Arguments
- `X::Any`: an array or DataFrame of covariates.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: the activation function to use.
Expand Down Expand Up @@ -60,7 +60,7 @@ mutable struct SLearner <: Metalearner
)

# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

task = var_type(Y) isa Binary ? "classification" : "regression"

Expand Down Expand Up @@ -88,9 +88,9 @@ end
Initialize a T-Learner.
# Arguments
- `X::Any`: an array or DataFrame of covariates.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: the activation function to use.
Expand Down Expand Up @@ -140,7 +140,7 @@ mutable struct TLearner <: Metalearner
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

task = var_type(Y) isa Binary ? "classification" : "regression"

Expand Down Expand Up @@ -168,9 +168,9 @@ end
Initialize an X-Learner.
# Arguments
- `X::Any`: an array or DataFrame of covariates.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: the activation function to use.
Expand Down Expand Up @@ -221,7 +221,7 @@ mutable struct XLearner <: Metalearner
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

task = var_type(Y) isa Binary ? "classification" : "regression"

Expand Down Expand Up @@ -249,9 +249,10 @@ end
Initialize an R-Learner.
# Arguments
- `X::Any`: an array or DataFrame of covariates of interest.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
interest.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: the activation function to use.
Expand Down Expand Up @@ -301,7 +302,7 @@ function RLearner(
)

# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

# Shuffle data with random indices
indices = shuffle(1:length(Y))
Expand Down Expand Up @@ -333,9 +334,10 @@ end
Initialize a doubly robust CATE estimator.
# Arguments
- `X::Any`: an array or DataFrame of covariates of interest.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `X::Any`: AbstractArray or Tables.jl API compliant data structure of covariates of
interest.
- `T::Any`: AbstractArray or Tables.jl API compliant data structure of treatment statuses.
- `Y::Any`: AbstractArray or Tables.jl API compliant data structure of outcomes.
# Keywords
- `activation::Function=swish`: the activation function to use.
Expand Down Expand Up @@ -386,7 +388,7 @@ function DoublyRobustLearner(
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y = convert_if_table.((X, T, Y))

# Shuffle data with random indices
indices = shuffle(1:length(Y))
Expand Down