Skip to content

Commit

Permalink
[Internals] Few simplifications (#34)
Browse files Browse the repository at this point in the history
* Change type hierarchy

* Simplify fitting

* Simplify test fitting

* add missing function

* pass tests

* Better version with dataframes

* adapt the show function

* Adapt example code in the docs
  • Loading branch information
lrnv authored May 10, 2024
1 parent 3cca726 commit e8b1751
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 94 deletions.
30 changes: 13 additions & 17 deletions docs/src/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ We are now interested in comparing the different groups of patients defined by v

```@example 2
pp_sex = fit(PoharPerme, @formula(Surv(time5,status5)~sex), colrec, slopop)
pp_males = pp_sex[pp_sex.sex .== :male,:estimator][1]
pp_females = pp_sex[pp_sex.sex .== :female,:estimator][1]
```

When comparing at time $1826$, we notice that the survival probability is slightly lower for men than for women ($0.433 < 0.449$). Women are also slightly more likely than men to die from other causes, since $0.0255 > 0.025$. Still, the differences are minimal. Let's confirm this with the Grafféo log-rank test:
Expand All @@ -188,14 +190,10 @@ The p-value is indeed above $0.05$. We cannot reject the null hypothesis $H_0$ a
As for the age, we will define two different groups: individuals aged 65 and above and those who are not.

```@example 2
colrec.age65 .= Bool(false)
for i in 1:nrow(colrec)
if colrec.age[i] >= 65*365.241
colrec.age65[i] = true
end
end
colrec.age65 .= ifelse.(colrec.age .>= 65*365.241, :old, :young)
pp_age65 = fit(PoharPerme, @formula(Surv(time5,status5)~age65), colrec, slopop)
pp_young = pp_age65[pp_age65.age65 .== :young, :estimator][1]
pp_old = pp_age65[pp_age65.age65 .== :old, :estimator][1]
```

Here, the difference between the two groups is much larger. In the first group, the individuals are aged under 65 and, at $5$ years, they have a $50.1$% chance of survival. On the other hand, the individuals aged 65 and up have a $40.1$% chance of survival.
Expand All @@ -214,29 +212,27 @@ When plotting both we get:


```@example 2
conf_int_men = confint(pp_sex[1]; level = 0.05)
conf_int_men = confint(pp_males; level = 0.05)
lower_bounds_men = [lower[1] for lower in conf_int_men]
upper_bounds_men = [upper[2] for upper in conf_int_men]
conf_int_women = confint(pp_sex[2]; level = 0.05)
conf_int_women = confint(pp_females; level = 0.05)
lower_bounds_women = [lower[1] for lower in conf_int_women]
upper_bounds_women = [upper[2] for upper in conf_int_women]
conf_int_under65 = confint(pp_age65[1]; level = 0.05)
conf_int_under65 = confint(pp_young; level = 0.05)
lower_bounds_under65 = [lower[1] for lower in conf_int_under65]
upper_bounds_under65 = [upper[2] for upper in conf_int_under65]
conf_int_65 = confint(pp_age65[2]; level = 0.05)
conf_int_65 = confint(pp_old; level = 0.05)
lower_bounds_65 = [lower[1] for lower in conf_int_65]
upper_bounds_65 = [upper[2] for upper in conf_int_65]
plot1 = plot(pp_sex[1].grid, pp_sex[1].Sₑ, ribbon=(pp_sex[1].Sₑ - lower_bounds_men, upper_bounds_men - pp_sex[1].Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "men")
plot1 = plot!(pp_sex[2].grid, pp_sex[2].Sₑ, ribbon=(pp_sex[2].Sₑ - lower_bounds_women, upper_bounds_women - pp_sex[2].Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "women")
plot2 = plot(pp_age65[1].grid, pp_age65[1].Sₑ, ribbon=(pp_age65[1].Sₑ - lower_bounds_under65, upper_bounds_under65 - pp_age65[1].Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "Under 65")
plot1 = plot(pp_males.grid, pp_males.Sₑ, ribbon=(pp_males.Sₑ - lower_bounds_men, upper_bounds_men - pp_males.Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "men")
plot1 = plot!(pp_females.grid, pp_females.Sₑ, ribbon=(pp_females.Sₑ - lower_bounds_women, upper_bounds_women - pp_females.Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "women")
plot2 = plot!(pp_age65[2].grid, pp_age65[2].Sₑ, ribbon=(pp_age65[2].Sₑ - lower_bounds_65, upper_bounds_65 - pp_age65[2].Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "65 and up")
plot2 = plot(pp_young.grid, pp_young.Sₑ, ribbon=(pp_young.Sₑ - lower_bounds_under65, upper_bounds_under65 - pp_young.Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "Under 65")
plot2 = plot!(pp_old.grid, pp_old.Sₑ, ribbon=(pp_old.Sₑ - lower_bounds_65, upper_bounds_65 - pp_old.Sₑ), xlab = "Time (days)", ylab = "Net survival", label = "65 and up")
plot(plot1, plot2, layout = (1, 2))
```
Expand Down
2 changes: 1 addition & 1 deletion src/EdererI.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct EdererIMethod end
struct EdererIMethod<:NPNSMethod end

"""
EdererI
Expand Down
2 changes: 1 addition & 1 deletion src/EdererII.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct EdererIIMethod end
struct EdererIIMethod<:NPNSMethod end

"""
EdererII
Expand Down
49 changes: 9 additions & 40 deletions src/GraffeoTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,48 +142,17 @@ end
# The fitting and formula interfaces should be here.

function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:GraffeoTest}
rate_predictors = String.([RateTables.predictors(rt)...])

expected_columns = [rate_predictors...,"age","year"]
missing_columns = filter(name -> !(name in names(df)), expected_columns)
if !isempty(missing_columns)
throw(ArgumentError("Missing columns in data: $missing_columns"))
end

strata = ones(nrow(df))
group = ones(nrow(df))
strata_terms = []
group_terms = []

if typeof(formula.rhs) == Term
group = select(df, StatsModels.termvars(formula.rhs))
group = [join(row, " ") for row in eachrow(group)]
elseif typeof(formula.rhs) <: FunctionTerm{typeof(Strata)}
strata = select(df, StatsModels.termvars(formula.rhs))
strata = [join(row, " ") for row in eachrow(strata)]
else
for myterm in formula.rhs
is_strata = typeof(myterm) <: FunctionTerm{typeof(Strata)}
if is_strata
append!(strata_terms, StatsModels.termvars(myterm))
else
push!(group_terms, Symbol(myterm))
end
end
end

if !isempty(group_terms)
group = select(df, group_terms)
group = [join(row, " ") for row in eachrow(group)]
end

if !isempty(strata_terms)
strata = select(df, strata_terms)
strata = [join(row, " ") for row in eachrow(strata)]
end
terms = StatsModels.termvars(formula.rhs)
tf = typeof(formula.rhs)
types = (tf <: AbstractTerm) ? [tf] : typeof.(formula.rhs)
are_strata = [t <: FunctionTerm{typeof(Strata)} for t in types]

formula = apply_schema(formula,schema(df))
resp = modelcols(formula.lhs,df)
strata = groupindices(groupby(df,terms[are_strata]))
group = groupindices(groupby(df,terms[(!).(are_strata)]))

resp = modelcols(apply_schema(formula,schema(df)).lhs,df)
rate_predictors = _get_rate_predictors(rt,df)

return GraffeoTest(resp[:,1], resp[:,2], df.age, df.year, select(df,rate_predictors), strata, group, rt)
end
2 changes: 1 addition & 1 deletion src/Hakulinen.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct HakulinenMethod end
struct HakulinenMethod<:NPNSMethod end

"""
Hakulinen
Expand Down
64 changes: 31 additions & 33 deletions src/NPNSEstimator.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
abstract type NonParametricEstimator <: StatisticalModel end # maybe this one is superfluous now.


struct NPNSEstimator{Method} <: NonParametricEstimator
abstract type NPNSMethod end
struct NPNSEstimator{Method} <: StatisticalModel
Sₑ::Vector{Float64}
∂Λₑ::Vector{Float64}
∂Λₒ::Vector{Float64}
∂Λₚ::Vector{Float64}
σₑ::Vector{Float64}
grid::Vector{Float64}
function NPNSEstimator{Method}(T, Δ, age, year, rate_preds, ratetable) where Method
function NPNSEstimator{Method}(T, Δ, age, year, rate_preds, ratetable) where Method<:NPNSMethod
grid = mk_grid(T,1) # precision is always 1 ?
∂Λₒ, ∂Λₚ, ∂σₑ = Λ(Method, T, Δ, age, year, rate_preds, ratetable, grid)
∂Λₑ = ∂Λₒ .- ∂Λₚ
Expand All @@ -22,7 +20,7 @@ function mk_grid(times,prec)
M = maximum(times)+1
return unique(sort([(1:prec:M)..., times..., M]))
end
function Λ(::Type{M}, T, Δ, age, year, rate_preds, ratetable, grid) where M
function Λ(::Type{M}, T, Δ, age, year, rate_preds, ratetable, grid) where M<:NPNSMethod
num_excess = zero(grid)
num_pop = zero(grid)
num_variance = zero(grid)
Expand All @@ -32,37 +30,31 @@ function Λ(::Type{M}, T, Δ, age, year, rate_preds, ratetable, grid) where M
return num_excess ./ den_excess, num_pop ./ den_pop, num_variance ./ (den_excess.^2)
end

function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:NPNSEstimator}
column_names = names(df)
rate_predictors = String.([RateTables.predictors(rt)...])

expected_columns = [rate_predictors...,"age","year"]
missing_columns = filter(name -> !(name in column_names), expected_columns)
if !isempty(missing_columns)
throw(ArgumentError("Missing columns in data: $missing_columns"))
"""
    _get_rate_predictors(rt, df)

Return the predictors of the rate table `rt` (as given by
`RateTables.predictors(rt)`, collected into a vector) after checking that `df`
has every column needed to query it.

# Throws
- `ArgumentError` if any of these predictors, or `:age` or `:year`, is not among
  the column names of `df`.
"""
function _get_rate_predictors(rt, df)
    prd = [RateTables.predictors(rt)...]
    # Column names are converted to symbols before the membership checks below.
    cl = Symbol.(names(df))
    if !(all(prd .∈ Ref(cl)) && (:age ∈ cl) && (:year ∈ cl))
        throw(ArgumentError("Missing columns in data : the chosen ratetable expects columns :age, :year and $(prd) to be present in the dataset."))
    end
    return prd
end

function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:NPNSEstimator}
rate_predictors = _get_rate_predictors(rt,df)
formula_applied = apply_schema(formula,schema(df))

if isa(formula.rhs, ConstantTerm)
# then there is no predictors.
if isa(formula.rhs, ConstantTerm) # No predictors
resp = modelcols(formula_applied.lhs, df)
return E(resp[:,1], resp[:,2], df.age, df.year, select(df,rate_predictors), rt)
else
nms = StatsModels.termnames(formula.rhs)
if isa(nms, String)
pred_names = [nms]
else
pred_names = nms
end
# we could simply group by the left side and apply fit() again, that would make sense.

new_df = groupby(df, pred_names)
pp = Vector{E}()
for i in 1:nrow(unique(df[!,pred_names]))
resp2 = modelcols(formula_applied.lhs, new_df[i])
push!(pp,E(resp2[:,1], resp2[:,2], new_df[i].age, new_df[i].year, select(new_df[i],rate_predictors), rt))
end
return pp
gdf = groupby(df, StatsModels.termnames(formula.rhs))
return rename(combine(gdf, dfᵢ -> begin
resp2 = modelcols(formula_applied.lhs, dfᵢ)
E(resp2[:,1], resp2[:,2], dfᵢ.age, dfᵢ.year, select(dfᵢ, rate_predictors), rt)
end
), :x1 => :estimator)
end
end

Expand All @@ -76,8 +68,14 @@ function StatsAPI.confint(npe::E; level::Real=0.05) where E <: NPNSEstimator
end

function Base.show(io::IO, npe::E) where E <: NPNSEstimator
lower_bounds = [lower[1] for lower in confint(npe; level = 0.05)]
upper_bounds = [upper[2] for upper in confint(npe; level = 0.05)]
df = DataFrame(Sₑ = npe.Sₑ, ∂Λₑ = npe.∂Λₑ, σₑ=npe.σₑ, lower_95_CI = lower_bounds, upper_95_CI = upper_bounds)
show(io, df)
compact = get(io, :compact, false)
if !compact
print(io, "$(E)(t ∈ $(extrema(npe.grid))) with summary stats:\n ")
lower_bounds = [lower[1] for lower in confint(npe; level = 0.05)]
upper_bounds = [upper[2] for upper in confint(npe; level = 0.05)]
df = DataFrame(Sₑ = npe.Sₑ, ∂Λₑ = npe.∂Λₑ, σₑ=npe.σₑ, lower_95_CI = lower_bounds, upper_95_CI = upper_bounds)
show(io, df)
else
print(io, "$(E)(t ∈ $(extrema(npe.grid)))")
end
end
2 changes: 1 addition & 1 deletion src/PoharPerme.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct PoharPermeMethod end
struct PoharPermeMethod<:NPNSMethod end

"""
PoharPerme
Expand Down

0 comments on commit e8b1751

Please sign in to comment.