diff --git a/src/schema.jl b/src/schema.jl
index dd91550a..0253e451 100644
--- a/src/schema.jl
+++ b/src/schema.jl
@@ -53,6 +53,22 @@ Base.merge!(a::Schema, b::Schema) = (merge!(a.schema, b.schema); a)
 Base.keys(schema::Schema) = keys(schema.schema)
 Base.haskey(schema::Schema, key) = haskey(schema.schema, key)
 
+# Two schemas are equal when they hold the same number of entries and every key of
+# `a` is present in `b` with an equal value.
+function Base.:(==)(a::Schema, b::Schema)
+    a === b && return true
+    a.schema === b.schema && return true
+    length(a.schema) != length(b.schema) && return false
+    for key in keys(a)
+        !haskey(b, key) && return false
+        b[key] != a[key] && return false
+    end
+    return true
+end
+
+# Hashing delegates to the wrapped `schema` field.
+Base.hash(schema::Schema, h::UInt) = hash(schema.schema, h)
+
 """
     schema([terms::AbstractVector{<:AbstractTerm}, ]data, hints::Dict{Symbol})
     schema(term::AbstractTerm, data, hints::Dict{Symbol})
diff --git a/src/terms.jl b/src/terms.jl
index b49a9027..2533ac2a 100644
--- a/src/terms.jl
+++ b/src/terms.jl
@@ -2,6 +2,17 @@ abstract type AbstractTerm end
 const TermOrTerms = Union{AbstractTerm, Tuple{AbstractTerm, Vararg{AbstractTerm}}}
 const TupleTerm = Tuple{TermOrTerms, Vararg{TermOrTerms}}
 
+# Generic hash for terms: fold the hash of every field in `fieldnames(T)` into `h`.
+Base.hash(term::T, h::UInt) where {T<:AbstractTerm} =
+    foldl((h, x) -> hash(x, h), getfield(term, field) for field in fieldnames(T); init=h)
+
+# Generic equality for terms: both types must declare the same field names and every
+# pair of same-named fields must compare equal with `==`.
+function Base.:(==)(a::A, b::B) where {A<:AbstractTerm, B<:AbstractTerm}
+    fieldnames(A) == fieldnames(B) || return false
+    return all(getfield(a, field) == getfield(b, field) for field in fieldnames(A))
+end
+
 width(::T) where {T<:AbstractTerm} =
     throw(ArgumentError("terms of type $T have undefined width"))
 
@@ -127,7 +138,8 @@ FunctionTerm(forig::Fo, fanon::Fa, names::NTuple{N,Symbol},
     FunctionTerm{Fo, Fa, names}(forig, fanon, exorig, args_parsed)
 width(::FunctionTerm) = 1
 
 Base.:(==)(a::FunctionTerm, b::FunctionTerm) = a.forig == b.forig && a.exorig == b.exorig
+Base.hash(term::FunctionTerm, h::UInt) = hash(term.forig, hash(term.exorig, h))
 
 """
     InteractionTerm{Ts} <: AbstractTerm
@@ -191,6 +202,9 @@ via the [`implicit_intercept`](@ref) trait).
 struct InterceptTerm{HasIntercept} <: AbstractTerm end
 width(::InterceptTerm{H}) where {H} = H ? 1 : 0
 
+# `InterceptTerm` has no fields, so equality is decided by the type parameter alone.
+Base.:(==)(::InterceptTerm{T}, ::InterceptTerm{S}) where {T,S} = T == S
+
 # Typed terms
 
 """
diff --git a/test/schema.jl b/test/schema.jl
index 3f5a2219..c432282a 100644
--- a/test/schema.jl
+++ b/test/schema.jl
@@ -1,5 +1,4 @@
 @testset "schemas" begin
-
     using StatsModels: schema, apply_schema, FullRank
 
     @testset "no-op apply_schema" begin
@@ -70,4 +69,56 @@
 
     end
 
+    @testset "basic hash and equality" begin
+        f = @formula(y ~ 1 + a + log(b) + c + b & c)
+        y = rand(9)
+        b = rand(9)
+
+        df = (y = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3))
+        f = apply_schema(f, schema(f, df))
+        @test f == apply_schema(f, schema(f, df))
+
+        sch1 = schema(f, df)
+        sch2 = schema(f, df)
+        @test sch1 == sch2
+        @test sch1 !== sch2
+        @test hash(sch1) == hash(sch2)
+
+        # double categorical column c to test for invariance based on levels
+        df2 = (y = y, a = 1:9, b = b, c = [df.c; df.c])
+        @test schema(df) == schema(df2)
+        @test hash(schema(df)) == hash(schema(df2))
+        @test apply_schema(f, schema(df)) == apply_schema(f, schema(df2))
+
+        # different levels
+        df3 = (y = y, a = 1:9, b = b, c = repeat(["a", "b", "c"], 3))
+        @test schema(df) != schema(df3)
+
+        # different length, so different summary stats for continuous
+        df4 = (y = [df.y; df.y], a = [1:9; 1:9], b = [b; b], c = [df.c; df.c])
+        @test schema(df) != schema(df4)
+
+        # different names for some columns
+        df5 = (z = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3))
+        @test schema(df) != schema(df5)
+
+        # different values in continuous column `a` only (same levels in `c` as `df`)
+        df6 = (y = y, a = 2:10, b = b, c = repeat(["d", "e", "f"], 3))
+        @test schema(df) != schema(df6)
+
+        # same data, but every column renamed
+        df7 = (w = y, d = 1:9, x = b, z = repeat(["d", "e", "f"], 3))
+        @test schema(df) != schema(df7)
+
+        # missing column
+        df8 = (y = y, a = 1:9, c = repeat(["d", "e", "f"], 3))
+        @test schema(df) != schema(df8)
+
+        # different coding/hints
+        sch = schema(df, Dict(:c => DummyCoding(base="e")))
+        sch2 = schema(df, Dict(:c => EffectsCoding(base="e")))
+        sch3 = schema(df, Dict(:y => DummyCoding()))
+        @test sch != sch2
+        @test sch != sch3
+    end
 end
diff --git a/test/terms.jl b/test/terms.jl
index 09c0199d..91603467 100644
--- a/test/terms.jl
+++ b/test/terms.jl
@@ -30,26 +30,37 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
         @test t0.var == var([1,2,3])
         @test t0.min == 1.0
         @test t0.max == 3.0
+        @test t0 == concrete_term(t, [3, 2, 1])
+        @test hash(t0) == hash(concrete_term(t, [3, 2, 1]))
 
         t1 = concrete_term(t, [:a, :b, :c])
         @test t1.contrasts isa StatsModels.ContrastsMatrix{DummyCoding}
         @test string(t1) == "aaa"
         @test mimestring(t1) == "aaa(DummyCoding:3→2)"
+        @test t1 == concrete_term(t, [:a, :b, :c])
+        @test t1 !== concrete_term(t, [:a, :b, :c])
+        @test hash(t1) == hash(concrete_term(t, [:a, :b, :c]))
 
         t3 = concrete_term(t, [:a, :b, :c], DummyCoding())
         @test t3.contrasts isa StatsModels.ContrastsMatrix{DummyCoding}
         @test string(t3) == "aaa"
         @test mimestring(t3) == "aaa(DummyCoding:3→2)"
+        @test t1 == t3
+        @test hash(t1) == hash(t3)
 
         t2 = concrete_term(t, [:a, :a, :b], EffectsCoding())
         @test t2.contrasts isa StatsModels.ContrastsMatrix{EffectsCoding}
         @test mimestring(t2) == "aaa(EffectsCoding:2→1)"
         @test string(t2) == "aaa"
+        @test t2 == concrete_term(t, [:a, :a, :b], EffectsCoding())
+        @test hash(t2) == hash(concrete_term(t, [:a, :a, :b], EffectsCoding()))
+        @test t1 != t2
 
         t2full = concrete_term(t, [:a, :a, :b], StatsModels.FullDummyCoding())
         @test t2full.contrasts isa StatsModels.ContrastsMatrix{StatsModels.FullDummyCoding}
         @test mimestring(t2full) == "aaa(StatsModels.FullDummyCoding:2→2)"
         @test string(t2full) == "aaa"
+        @test t1 != t2full
     end
 
     @testset "term operators" begin
@@ -89,18 +100,6 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
         @test +a == a
     end
 
-    @testset "uniqueness of FunctionTerms" begin
-        f1 = @formula(y ~ lag(x,1) + lag(x,1))
-        f2 = @formula(y ~ lag(x,1))
-        f3 = @formula(y ~ lag(x,1) + lag(x,2))
-
-        @test f1.rhs == f2.rhs
-        @test f1.rhs != f3.rhs
-
-        ## addition of two identical function terms
-        @test f2.rhs + f2.rhs == f2.rhs
-    end
-
     @testset "expand nested tuples of terms during apply_schema" begin
         sch = schema((a=rand(10), b=rand(10), c=rand(10)))
 
@@ -173,6 +172,44 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
 
     end
 
+    @testset "equality of function terms" begin
+        # for now, we use `@formula` to construct the function terms
+        f1 = @formula(0 ~ (1 | x)).rhs
+        f2 = @formula(0 ~ (1 | x)).rhs
+        @test f1 !== f2
+        @test f1 == f2
+        @test hash(f1) == hash(f2)
+
+        f3 = @formula(0 ~ (1 % x)).rhs
+        @test f1 != f3
+        @test hash(f1) != hash(f3)
+
+        f4 = @formula(0 ~ (x | 1)).rhs
+        @test f1 != f4
+        @test hash(f1) != hash(f4)
+
+        f5 = @formula(0 ~ (1 & y | x)).rhs
+        @test f1 != f5
+        @test hash(f1) != hash(f5)
+
+        ff1 = @formula(y ~ 1 + x + x & y + (1 + x | g))
+        ff2 = @formula(y ~ 1 + x + x & y + (1 + x | g))
+        @test ff1 == ff2
+        @test hash(ff1) == hash(ff2)
+    end
+
+    @testset "uniqueness of FunctionTerms" begin
+        f1 = @formula(y ~ lag(x,1) + lag(x,1))
+        f2 = @formula(y ~ lag(x,1))
+        f3 = @formula(y ~ lag(x,1) + lag(x,2))
+
+        @test f1.rhs == f2.rhs
+        @test f1.rhs != f3.rhs
+
+        ## addition of two identical function terms
+        @test f2.rhs + f2.rhs == f2.rhs
+    end
+
     @testset "Tuple terms" begin
         using StatsModels: TermOrTerms, TupleTerm, Term
         a, b, c = Term.((:a, :b, :c))