From f8eca8f5b2a4e16f00f6f90afd57dc6a0800ae09 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 26 Jul 2024 18:45:18 +0800 Subject: [PATCH] typeintersect: fix bounds merging during inner `intersect_all`. This is a backport of #55299 --- src/subtype.c | 179 ++++++++++++++---------------------------------- test/subtype.jl | 43 ++++++++++-- 2 files changed, 88 insertions(+), 134 deletions(-) diff --git a/src/subtype.c b/src/subtype.c index 24c4d74440282..2f8f41d50c239 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -65,7 +65,6 @@ typedef struct jl_varbinding_t { jl_value_t *lb; jl_value_t *ub; int8_t right; // whether this variable came from the right side of `A <: B` - int8_t occurs; // occurs in any position int8_t occurs_inv; // occurs in invariant position int8_t occurs_cov; // # of occurrences in covariant position int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete @@ -168,7 +167,7 @@ static int current_env_length(jl_stenv_t *e) typedef struct { int8_t *buf; int rdepth; - int8_t _space[24]; // == 8 * 3 + int8_t _space[16]; // == 8 * 2 jl_gcframe_t gcframe; jl_value_t *roots[24]; } jl_savedenv_t; @@ -197,7 +196,6 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root) roots[i++] = v->ub; roots[i++] = (jl_value_t*)v->innervars; } - se->buf[j++] = v->occurs; se->buf[j++] = v->occurs_inv; se->buf[j++] = v->occurs_cov; v = v->prev; @@ -278,7 +276,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO v->ub = roots[i++]; v->innervars = (jl_array_t*)roots[i++]; } - v->occurs = se->buf[j++]; v->occurs_inv = se->buf[j++]; v->occurs_cov = se->buf[j++]; v = v->prev; @@ -289,15 +286,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*)); } -static void clean_occurs(jl_stenv_t *e) -{ - jl_varbinding_t *v = e->vars; - while (v) { - v->occurs = 0; - v = v->prev; - } -} - #define flip_offset(e) ((e)->Loffset *= -1) // type utilities @@ -586,6 +574,8 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi) static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param); +#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0) + static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT { jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; @@ -666,8 +656,6 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par // of determining whether the variable is concrete. static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT { - if (vb != NULL) - vb->occurs = 1; if (vb != NULL && param) { // saturate counters at 2; we don't need values bigger than that if (param == 2 && e->invdepth > vb->depth0) { @@ -898,7 +886,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e) static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param) { u = unalias_unionall(u, e); - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, e->invdepth, NULL, e->vars }; JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars); e->vars = &vb; @@ -3198,7 +3186,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ { jl_value_t *res = NULL; jl_savedenv_t se; - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, e->invdepth, NULL, e->vars }; JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars); save_env(e, &se, 1); @@ -3226,7 +3214,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ vb.ub = vb.var->ub; } restore_env(e, &se, vb.constraintkind == 1 ? 1 : 0); - vb.occurs = vb.occurs_cov = vb.occurs_inv = 0; + vb.occurs_cov = vb.occurs_inv = 0; res = intersect_unionall_(t, u, e, R, param, &vb); } } @@ -3893,73 +3881,12 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa return jl_bottom_type; } -static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count) +static int merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se, int count) { - if (count == 0) - alloc_env(e, se, 1); - jl_value_t **roots = NULL; - int nroots = 0; - if (se->gcframe.nroots == JL_GC_ENCODE_PUSHARGS(1)) { - jl_svec_t *sv = (jl_svec_t*)se->roots[0]; - assert(jl_is_svec(sv)); - roots = jl_svec_data(sv); - nroots = jl_svec_len(sv); - } - else { - roots = se->roots; - nroots = se->gcframe.nroots >> 2; - } - int n = 0; - jl_varbinding_t *v = e->vars; - v = e->vars; - while (v != NULL) { - if (count == 0) { - // need to initialize this - se->buf[n] = 0; - se->buf[n+1] = 0; - se->buf[n+2] = 0; - } - if (v->occurs) { - // only merge lb/ub/innervars if this var occurs. - jl_value_t *b1, *b2; - b1 = roots[n]; - JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame - b2 = v->lb; - JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots - roots[n] = b1 ? simple_meet(b1, b2, 0) : b2; - b1 = roots[n+1]; - JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame - b2 = v->ub; - JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots - roots[n+1] = b1 ? simple_join(b1, b2) : b2; - b1 = roots[n+2]; - JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame - b2 = (jl_value_t*)v->innervars; - JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots - if (b2 && b1 != b2) { - if (b1) - jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2); - else - roots[n+2] = b2; - } - // record the meeted vars. - se->buf[n] = 1; - } - // always merge occurs_inv/cov by max (never decrease) - if (v->occurs_inv > se->buf[n+1]) - se->buf[n+1] = v->occurs_inv; - if (v->occurs_cov > se->buf[n+2]) - se->buf[n+2] = v->occurs_cov; - n = n + 3; - v = v->prev; + if (count == 0) { + save_env(e, me, 1); + return 1; } - assert(n == nroots); (void)nroots; - return count + 1; -} - -// merge untouched vars' info. -static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se) -{ jl_value_t **merged = NULL; jl_value_t **saved = NULL; int nroots = 0; @@ -3981,47 +3908,46 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se) } assert(nroots == current_env_length(e) * 3); assert(nroots % 3 == 0); - for (int n = 0; n < nroots; n = n + 3) { - if (merged[n] == NULL) - merged[n] = saved[n]; - if (merged[n+1] == NULL) - merged[n+1] = saved[n+1]; - jl_value_t *b1, *b2; + int m = 0, n = 0; + jl_varbinding_t *v = e->vars; + while (v != NULL) { + jl_value_t *b0, *b1, *b2; + // merge `lb` + b0 = saved[n]; + b1 = merged[n]; + JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame + b2 = v->lb; + JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots + merged[n] = (b1 == b0 || b2 == b0) ? b0 : simple_meet(b1, b2, 0); + // merge `ub` + b0 = saved[n+1]; + b1 = merged[n+1]; + JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame + b2 = v->ub; + JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots + merged[n+1] = (b1 == b0 || b2 == b0) ? b0 : simple_join(b1, b2); + // merge `innervars` b1 = merged[n+2]; JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame - b2 = saved[n+2]; - JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know this came from our GC frame + b2 = (jl_value_t*)v->innervars; + JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots if (b2 && b1 != b2) { if (b1) jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2); else merged[n+2] = b2; } - me->buf[n] |= se->buf[n]; - } -} - -static void expand_local_env(jl_stenv_t *e, jl_value_t *res) -{ - jl_varbinding_t *v = e->vars; - // Here we pull in some typevar missed in fastpath. - while (v != NULL) { - v->occurs = v->occurs || jl_has_typevar(res, v->var); - assert(v->occurs == 0 || v->occurs == 1); - v = v->prev; - } - v = e->vars; - while (v != NULL) { - if (v->occurs == 1) { - jl_varbinding_t *v2 = e->vars; - while (v2 != NULL) { - if (v2 != v && v2->occurs == 0) - v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var)); - v2 = v2->prev; - } - } + // merge occurs_inv/cov by max (never decrease) + if (v->occurs_inv > me->buf[m]) + me->buf[m] = v->occurs_inv; + if (v->occurs_cov > me->buf[m+1]) + me->buf[m+1] = v->occurs_cov; + m = m + 2; + n = n + 3; v = v->prev; } + assert(n == nroots); (void)nroots; + return count + 1; } static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) @@ -4034,12 +3960,9 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) jl_savedenv_t se, me; save_env(e, &se, 1); int niter = 0, total_iter = 0; - clean_occurs(e); is[0] = intersect(x, y, e, 0); // root - if (is[0] != jl_bottom_type) { - expand_local_env(e, is[0]); - niter = merge_env(e, &me, niter); - } + if (is[0] != jl_bottom_type) + niter = merge_env(e, &me, &se, niter); restore_env(e, &se, 1); while (next_union_state(e, 1)) { if (e->emptiness_only && is[0] != jl_bottom_type) @@ -4047,12 +3970,9 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) e->Runions.depth = 0; e->Runions.more = 0; - clean_occurs(e); is[1] = intersect(x, y, e, 0); - if (is[1] != jl_bottom_type) { - expand_local_env(e, is[1]); - niter = merge_env(e, &me, niter); - } + if (is[1] != jl_bottom_type) + niter = merge_env(e, &me, &se, niter); restore_env(e, &se, 1); if (is[0] == jl_bottom_type) is[0] = is[1]; @@ -4061,13 +3981,18 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) is[0] = jl_type_union(is, 2); } total_iter++; - if (niter > 4 || total_iter > 400000) { + if (has_next_union_state(e, 1) && (niter > 4 || total_iter > 400000)) { is[0] = y; + // we give up precise intersection here, just restore the saved env + restore_env(e, &se, 1); + if (niter > 0) { + free_env(&me); + niter = 0; + } break; } } if (niter) { - final_merge_env(e, &me, &se); restore_env(e, &me, 1); free_env(&me); } @@ -4552,7 +4477,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) { static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot) { - jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; + jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; jl_value_t *nt; JL_GC_PUSH2(&vb.innervars, &nt); if (jl_is_unionall(u->body)) diff --git a/test/subtype.jl b/test/subtype.jl index 93deda2a9929a..9ee158e710c80 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -2374,12 +2374,41 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})}, @testintersect(S, T, !Union{}) end -# A simple case which has a small local union. -# make sure the env is not widened too much when we intersect(Int8, Int8). -struct T48006{A1,A2,A3} end -@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}}, - Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}}, - Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8}) +let S = Dict{Int, S1} where {F1, S1<:Union{Int8, Val{F1}}}, + T = Dict{F2, S2} where {F2, S2<:Union{Int8, Val{F2}}} + @test_broken typeintersect(S, T) == Dict{Int, S} where S<:Union{Val{Int}, Int8} + @test typeintersect(T, S) == Dict{Int, S} where S<:Union{Val{Int}, Int8} +end + +# Ensure inner `intersect_all` never under-esitimate. +let S = Tuple{F1, Dict{Int, S1}} where {F1, S1<:Union{Int8, Val{F1}}}, + T = Tuple{Any, Dict{F2, S2}} where {F2, S2<:Union{Int8, Val{F2}}} + @test Tuple{Nothing, Dict{Int, Int8}} <: S + @test Tuple{Nothing, Dict{Int, Int8}} <: T + @test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(S, T) + @test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(T, S) +end + +let S = Tuple{F1, Val{S1}} where {F1, S1<:Dict{F1}} + T = Tuple{Any, Val{S2}} where {F2, S2<:Union{map(T->Dict{T}, Base.BitInteger_types)...}} + ST = typeintersect(S, T) + TS = typeintersect(S, T) + for U in Base.BitInteger_types + @test Tuple{U, Val{Dict{U,Nothing}}} <: S + @test Tuple{U, Val{Dict{U,Nothing}}} <: T + @test Tuple{U, Val{Dict{U,Nothing}}} <: ST + @test Tuple{U, Val{Dict{U,Nothing}}} <: TS + end +end + +#issue 55206 +struct T55206{A,B<:Complex{A},C<:Union{Dict{Nothing},Dict{A}}} end +@testintersect(T55206, T55206{<:Any,<:Any,<:Dict{Nothing}}, T55206{A,<:Complex{A},<:Dict{Nothing}} where {A}) +@testintersect( + Tuple{Dict{Int8, Int16}, Val{S1}} where {F1, S1<:AbstractSet{F1}}, + Tuple{Dict{T1, T2}, Val{S2}} where {T1, T2, S2<:Union{Set{T1},Set{T2}}}, + Tuple{Dict{Int8, Int16}, Val{S1}} where {S1<:Union{Set{Int8},Set{Int16}}} +) f48167(::Type{Val{L2}}, ::Type{Union{Val{L1}, Set{R}}}) where {L1, R, L2<:L1} = 1 f48167(::Type{Val{L1}}, ::Type{Union{Val{L2}, Set{R}}}) where {L1, R, L2<:L1} = 2 @@ -2548,7 +2577,7 @@ end let T = Tuple{Union{Type{T}, Type{S}}, Union{Val{T}, Val{S}}, Union{Val{T}, S}} where T<:Val{A} where A where S<:Val, S = Tuple{Type{T}, T, Val{T}} where T<:(Val{S} where S<:Val) # optimal = Union{}? - @test typeintersect(T, S) == Tuple{Type{A}, Union{Val{A}, Val{S} where S<:Union{Val, A}, Val{x} where x<:Val, Val{x} where x<:Union{Val, A}}, Val{A}} where A<:(Val{S} where S<:Val) + @test typeintersect(T, S) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {S<:Val, T<:Val} @test typeintersect(S, T) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {T<:Val, S<:(Union{Val{A}, Val} where A)} end