Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typeintersect: fix bounds merging during inner intersect_all. #55299

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 56 additions & 134 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ typedef struct jl_varbinding_t {
jl_value_t *lb;
jl_value_t *ub;
int8_t right; // whether this variable came from the right side of `A <: B`
int8_t occurs; // occurs in any position
int8_t occurs_inv; // occurs in invariant position
int8_t occurs_cov; // # of occurrences in covariant position
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
Expand Down Expand Up @@ -179,7 +178,7 @@ static int current_env_length(jl_stenv_t *e)
typedef struct {
int8_t *buf;
int rdepth;
int8_t _space[32]; // == 8 * 4
int8_t _space[24]; // == 8 * 3
jl_gcframe_t gcframe;
jl_value_t *roots[24]; // == 8 * 3
} jl_savedenv_t;
Expand Down Expand Up @@ -208,7 +207,6 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
roots[i++] = v->ub;
roots[i++] = (jl_value_t*)v->innervars;
}
se->buf[j++] = v->occurs;
se->buf[j++] = v->occurs_inv;
se->buf[j++] = v->occurs_cov;
se->buf[j++] = v->max_offset;
Expand Down Expand Up @@ -243,7 +241,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
ct->gcstack = &se->gcframe;
}
}
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space);
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space);
#ifdef __clang_gcanalyzer__
memset(se->buf, 0, len * 3);
#endif
Expand Down Expand Up @@ -290,7 +288,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
v->ub = roots[i++];
v->innervars = (jl_array_t*)roots[i++];
}
v->occurs = se->buf[j++];
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v->max_offset = se->buf[j++];
Expand All @@ -302,15 +299,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*));
}

static void clean_occurs(jl_stenv_t *e)
{
jl_varbinding_t *v = e->vars;
while (v) {
v->occurs = 0;
v = v->prev;
}
}

#define flip_offset(e) ((e)->Loffset *= -1)

// type utilities
Expand Down Expand Up @@ -599,6 +587,8 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)

static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
Expand Down Expand Up @@ -679,8 +669,6 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
// of determining whether the variable is concrete.
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
{
if (vb != NULL)
vb->occurs = 1;
if (vb != NULL && param) {
// saturate counters at 2; we don't need values bigger than that
if (param == 2 && e->invdepth > vb->depth0) {
Expand Down Expand Up @@ -915,7 +903,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
Expand Down Expand Up @@ -3312,7 +3300,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res = NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
save_env(e, &se, 1);
Expand Down Expand Up @@ -3341,7 +3329,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
vb.ub = vb.var->ub;
}
restore_env(e, &se, vb.constraintkind == 1 ? 1 : 0);
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
vb.occurs_cov = vb.occurs_inv = 0;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
}
Expand Down Expand Up @@ -4042,79 +4030,12 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
return jl_bottom_type;
}

static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
static int merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se, int count)
{
if (count == 0)
alloc_env(e, se, 1);
jl_value_t **roots = NULL;
int nroots = 0;
if (se->gcframe.nroots == JL_GC_ENCODE_PUSHARGS(1)) {
jl_svec_t *sv = (jl_svec_t*)se->roots[0];
assert(jl_is_svec(sv));
roots = jl_svec_data(sv);
nroots = jl_svec_len(sv);
}
else {
roots = se->roots;
nroots = se->gcframe.nroots >> 2;
}
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
while (v != NULL) {
if (count == 0) {
// need to initialize this
se->buf[m] = 0;
se->buf[m+1] = 0;
se->buf[m+2] = 0;
se->buf[m+3] = v->max_offset;
}
jl_value_t *b1, *b2;
if (v->occurs) {
// only merge lb/ub if this var occurs.
b1 = roots[n];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->lb;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
roots[n] = b1 ? simple_meet(b1, b2, 0) : b2;
b1 = roots[n+1];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->ub;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
roots[n+1] = b1 ? simple_join(b1, b2) : b2;
// record the meeted vars.
se->buf[m] = 1;
}
// `innervars` might be re-sorted inside `finish_unionall`.
// We'd better always merge it.
b1 = roots[n+2];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = (jl_value_t*)v->innervars;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
roots[n+2] = b2;
}
// always merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > se->buf[m+1])
se->buf[m+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[m+2])
se->buf[m+2] = v->occurs_cov;
// always merge max_offset by min
if (!v->intersected && v->max_offset < se->buf[m+3])
se->buf[m+3] = v->max_offset;
m = m + 4;
n = n + 3;
v = v->prev;
if (count == 0) {
save_env(e, me, 1);
return 1;
}
assert(n == nroots); (void)nroots;
return count + 1;
}

// merge untouched vars' info.
static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
{
jl_value_t **merged = NULL;
jl_value_t **saved = NULL;
int nroots = 0;
Expand All @@ -4136,47 +4057,49 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
}
assert(nroots == current_env_length(e) * 3);
assert(nroots % 3 == 0);
for (int n = 0, m = 0; n < nroots; n += 3, m += 4) {
if (merged[n] == NULL)
merged[n] = saved[n];
if (merged[n+1] == NULL)
merged[n+1] = saved[n+1];
jl_value_t *b1, *b2;
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
while (v != NULL) {
jl_value_t *b0, *b1, *b2;
// merge `lb`
b0 = saved[n];
b1 = merged[n];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->lb;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
merged[n] = (b1 == b0 || b2 == b0) ? b0 : simple_meet(b1, b2, 0);
// merge `ub`
b0 = saved[n+1];
b1 = merged[n+1];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->ub;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
merged[n+1] = (b1 == b0 || b2 == b0) ? b0 : simple_join(b1, b2);
// merge `innervars`
b1 = merged[n+2];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = saved[n+2];
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know this came from our GC frame
b2 = (jl_value_t*)v->innervars;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
merged[n+2] = b2;
}
me->buf[m] |= se->buf[m];
}
}

static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
{
jl_varbinding_t *v = e->vars;
// Here we pull in some typevar missed in fastpath.
while (v != NULL) {
v->occurs = v->occurs || jl_has_typevar(res, v->var);
assert(v->occurs == 0 || v->occurs == 1);
v = v->prev;
}
v = e->vars;
while (v != NULL) {
if (v->occurs == 1) {
jl_varbinding_t *v2 = e->vars;
while (v2 != NULL) {
if (v2 != v && v2->occurs == 0)
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
v2 = v2->prev;
}
}
// merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > me->buf[m])
me->buf[m] = v->occurs_inv;
if (v->occurs_cov > me->buf[m+1])
me->buf[m+1] = v->occurs_cov;
// merge max_offset by min
if (!v->intersected && v->max_offset < me->buf[m+2])
me->buf[m+2] = v->max_offset;
m = m + 3;
n = n + 3;
v = v->prev;
}
assert(n == nroots); (void)nroots;
return count + 1;
}

static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
Expand All @@ -4189,25 +4112,19 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_savedenv_t se, me;
save_env(e, &se, 1);
int niter = 0, total_iter = 0;
clean_occurs(e);
is[0] = intersect(x, y, e, 0); // root
if (is[0] != jl_bottom_type) {
expand_local_env(e, is[0]);
niter = merge_env(e, &me, niter);
}
if (is[0] != jl_bottom_type)
niter = merge_env(e, &me, &se, niter);
restore_env(e, &se, 1);
while (next_union_state(e, 1)) {
if (e->emptiness_only && is[0] != jl_bottom_type)
break;
e->Runions.depth = 0;
e->Runions.more = 0;

clean_occurs(e);
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type) {
expand_local_env(e, is[1]);
niter = merge_env(e, &me, niter);
}
if (is[1] != jl_bottom_type)
niter = merge_env(e, &me, &se, niter);
restore_env(e, &se, 1);
if (is[0] == jl_bottom_type)
is[0] = is[1];
Expand All @@ -4216,13 +4133,18 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
is[0] = jl_type_union(is, 2);
}
total_iter++;
if (niter > 4 || total_iter > 400000) {
if (has_next_union_state(e, 1) && (niter > 4 || total_iter > 400000)) {
is[0] = y;
// we give up precise intersection here, just restore the saved env
restore_env(e, &se, 1);
if (niter > 0) {
free_env(&me);
niter = 0;
}
break;
}
}
if (niter) {
final_merge_env(e, &me, &se);
restore_env(e, &me, 1);
free_env(&me);
}
Expand Down Expand Up @@ -4707,7 +4629,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
Expand Down
43 changes: 36 additions & 7 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2380,12 +2380,41 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
@testintersect(S, T, !Union{})
end

# A simple case which has a small local union.
# make sure the env is not widened too much when we intersect(Int8, Int8).
struct T48006{A1,A2,A3} end
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})
let S = Dict{Int, S1} where {F1, S1<:Union{Int8, Val{F1}}},
T = Dict{F2, S2} where {F2, S2<:Union{Int8, Val{F2}}}
@test_broken typeintersect(S, T) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
@test typeintersect(T, S) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
end

# Ensure inner `intersect_all` never under-esitimate.
let S = Tuple{F1, Dict{Int, S1}} where {F1, S1<:Union{Int8, Val{F1}}},
T = Tuple{Any, Dict{F2, S2}} where {F2, S2<:Union{Int8, Val{F2}}}
@test Tuple{Nothing, Dict{Int, Int8}} <: S
@test Tuple{Nothing, Dict{Int, Int8}} <: T
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(S, T)
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(T, S)
end

let S = Tuple{F1, Val{S1}} where {F1, S1<:Dict{F1}}
T = Tuple{Any, Val{S2}} where {F2, S2<:Union{map(T->Dict{T}, Base.BitInteger_types)...}}
ST = typeintersect(S, T)
TS = typeintersect(S, T)
for U in Base.BitInteger_types
@test Tuple{U, Val{Dict{U,Nothing}}} <: S
@test Tuple{U, Val{Dict{U,Nothing}}} <: T
@test Tuple{U, Val{Dict{U,Nothing}}} <: ST
@test Tuple{U, Val{Dict{U,Nothing}}} <: TS
end
end

#issue 55206
struct T55206{A,B<:Complex{A},C<:Union{Dict{Nothing},Dict{A}}} end
@testintersect(T55206, T55206{<:Any,<:Any,<:Dict{Nothing}}, T55206{A,<:Complex{A},<:Dict{Nothing}} where {A})
@testintersect(
Tuple{Dict{Int8, Int16}, Val{S1}} where {F1, S1<:AbstractSet{F1}},
Tuple{Dict{T1, T2}, Val{S2}} where {T1, T2, S2<:Union{Set{T1},Set{T2}}},
Tuple{Dict{Int8, Int16}, Val{S1}} where {S1<:Union{Set{Int8},Set{Int16}}}
)

f48167(::Type{Val{L2}}, ::Type{Union{Val{L1}, Set{R}}}) where {L1, R, L2<:L1} = 1
f48167(::Type{Val{L1}}, ::Type{Union{Val{L2}, Set{R}}}) where {L1, R, L2<:L1} = 2
Expand Down Expand Up @@ -2554,7 +2583,7 @@ end
let T = Tuple{Union{Type{T}, Type{S}}, Union{Val{T}, Val{S}}, Union{Val{T}, S}} where T<:Val{A} where A where S<:Val,
S = Tuple{Type{T}, T, Val{T}} where T<:(Val{S} where S<:Val)
# optimal = Union{}?
@test typeintersect(T, S) == Tuple{Type{A}, Union{Val{A}, Val{S} where S<:Union{Val, A}, Val{x} where x<:Val, Val{x} where x<:Union{Val, A}}, Val{A}} where A<:(Val{S} where S<:Val)
@test typeintersect(T, S) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {S<:Val, T<:Val}
@test typeintersect(S, T) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {T<:Val, S<:(Union{Val{A}, Val} where A)}
end

Expand Down