Skip to content

Commit 61696f9

Browse files
authored
Merge pull request #29262 from JuliaLang/jn/interpret-phi
interpreter: fix bugs with ssair phi-node handling
2 parents 4851fab + 39f2561 commit 61696f9

File tree

8 files changed

+265
-77
lines changed

8 files changed

+265
-77
lines changed

base/stacktraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ const top_level_scope_sym = Symbol("top-level scope")
123123
using Base.Meta
124124
is_loc_meta(expr, kind) = isexpr(expr, :meta) && length(expr.args) >= 1 && expr.args[1] === kind
125125
function lookup(ip::Base.InterpreterIP)
126-
if ip.code isa Core.MethodInstance
126+
if ip.code isa Core.MethodInstance && ip.code.def isa Method
127127
codeinfo = ip.code.inferred
128128
func = ip.code.def.name
129129
file = ip.code.def.file

base/tuple.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ firstindex(@nospecialize t::Tuple) = 1
2121
lastindex(@nospecialize t::Tuple) = length(t)
2222
size(@nospecialize(t::Tuple), d) = (d == 1) ? length(t) : throw(ArgumentError("invalid tuple dimension $d"))
2323
axes(@nospecialize t::Tuple) = OneTo(length(t))
24-
@eval getindex(t::Tuple, i::Int) = getfield(t, i, $(Expr(:boundscheck)))
25-
@eval getindex(t::Tuple, i::Real) = getfield(t, convert(Int, i), $(Expr(:boundscheck)))
24+
@eval getindex(@nospecialize(t::Tuple), i::Int) = getfield(t, i, $(Expr(:boundscheck)))
25+
@eval getindex(@nospecialize(t::Tuple), i::Real) = getfield(t, convert(Int, i), $(Expr(:boundscheck)))
2626
getindex(t::Tuple, r::AbstractArray{<:Any,1}) = ([t[ri] for ri in r]...,)
2727
getindex(t::Tuple, b::AbstractArray{Bool,1}) = length(b) == length(t) ? getindex(t, findall(b)) : throw(BoundsError(t, b))
2828
getindex(t::Tuple, c::Colon) = t
@@ -38,7 +38,10 @@ _setindex(v, i::Integer) = ()
3838

3939
## iterating ##
4040

41-
iterate(t::Tuple, i::Int=1) = 1 <= i <= length(t) ? (@inbounds t[i], i+1) : nothing
41+
function iterate(@nospecialize(t::Tuple), i::Int=1)
42+
@_inline_meta
43+
return (1 <= i <= length(t)) ? (@inbounds t[i], i + 1) : nothing
44+
end
4245

4346
keys(@nospecialize t::Tuple) = OneTo(length(t))
4447

src/interpreter.c

Lines changed: 117 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ typedef struct {
2020
jl_module_t *module; // context for globals
2121
jl_value_t **locals; // slots for holding local slots and ssavalues
2222
jl_svec_t *sparam_vals; // method static parameters, if eval-ing a method body
23-
size_t last_branch; // Points at the last branch statement (for evaluating phi nodes)
24-
size_t ip; // Points to the currently-evaluating statement
25-
int preevaluation; // use special rules for pre-evaluating expressions
23+
size_t ip; // Leak the currently-evaluating statement index to backtrace capture
24+
int preevaluation; // use special rules for pre-evaluating expressions (deprecated--only for ccall handling)
2625
int continue_at; // statement index to jump to after leaving exception handler (0 if none)
2726
} interpreter_state;
2827

2928
#include "interpreter-stacktrace.c"
3029

3130
static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s);
32-
static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, int start, int toplevel);
31+
static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip, int toplevel);
3332

3433
int jl_is_toplevel_only_expr(jl_value_t *e);
3534

@@ -362,8 +361,6 @@ SECT_INTERP static void eval_stmt_value(jl_value_t *stmt, interpreter_state *s)
362361
{
363362
jl_value_t *res = eval_value(stmt, s);
364363
s->locals[jl_source_nslots(s->src) + s->ip] = res;
365-
if (!jl_is_phinode(stmt))
366-
s->last_branch = s->ip;
367364
}
368365

369366
SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
@@ -403,23 +400,7 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
403400
#endif
404401
return val;
405402
}
406-
if (jl_is_phinode(e)) {
407-
jl_array_t *edges = (jl_array_t*)jl_fieldref_noalloc(e, 0);
408-
ssize_t edge = -1;
409-
for (int i = 0; i < jl_array_len(edges); ++i) {
410-
size_t from = jl_unbox_long(jl_arrayref(edges, i));
411-
if (from == s->last_branch + 1) {
412-
edge = i;
413-
break;
414-
}
415-
}
416-
if (edge == -1) {
417-
// edges list doesn't contain last branch. this value should be unused.
418-
return NULL;
419-
}
420-
jl_value_t *val = jl_array_ptr_ref((jl_array_t*)jl_fieldref_noalloc(e, 1), edge);
421-
return eval_value(val, s);
422-
}
403+
assert(!jl_is_phinode(e) && !jl_is_phicnode(e) && !jl_is_upsilonnode(e) && "malformed AST");
423404
if (!jl_is_expr(e))
424405
return e;
425406
jl_expr_t *ex = (jl_expr_t*)e;
@@ -523,26 +504,115 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
523504
abort();
524505
}
525506

526-
SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, int start, int toplevel)
507+
// phi nodes don't behave like proper instructions, so we require a special interpreter to handle them
508+
SECT_INTERP static size_t eval_phi(jl_array_t *stmts, interpreter_state *s, size_t ns, size_t to)
509+
{
510+
size_t from = s->ip;
511+
size_t ip = to;
512+
unsigned nphi = 0;
513+
for (ip = to; ip < ns; ip++) {
514+
jl_value_t *e = jl_array_ptr_ref(stmts, ip);
515+
if (!jl_is_phinode(e))
516+
break;
517+
nphi += 1;
518+
}
519+
if (nphi) {
520+
jl_value_t **dest = &s->locals[jl_source_nslots(s->src) + to];
521+
jl_value_t **phis; // = (jl_value_t**)alloca(sizeof(jl_value_t*) * nphi);
522+
JL_GC_PUSHARGS(phis, nphi);
523+
for (unsigned i = 0; i < nphi; i++) {
524+
jl_value_t *e = jl_array_ptr_ref(stmts, to + i);
525+
assert(jl_is_phinode(e));
526+
jl_array_t *edges = (jl_array_t*)jl_fieldref_noalloc(e, 0);
527+
ssize_t edge = -1;
528+
size_t closest = to; // implicit edge has `to <= edge - 1 < to + i`
529+
// this is because we could see the following IR (all 1-indexed):
530+
// goto %3 unless %cond
531+
// %2 = phi ...
532+
// %3 = phi (1)[1 => %a], (2)[2 => %b]
533+
// from = 1, to = closest = 2, i = 1 --> edge = 2, edge_from = 2, from = 2
534+
for (unsigned j = 0; j < jl_array_len(edges); ++j) {
535+
size_t edge_from = jl_unbox_long(jl_arrayref(edges, j)); // 1-indexed
536+
if (edge_from == from + 1) {
537+
if (edge == -1)
538+
edge = j;
539+
}
540+
else if (closest < edge_from && edge_from < (to + i + 1)) {
541+
// if we found a nearer implicit branch from fall-through,
542+
// that occurred since the last explicit branch,
543+
// we should use the value from that edge instead
544+
edge = j;
545+
closest = edge_from;
546+
}
547+
}
548+
jl_value_t *val = NULL;
549+
unsigned n_oldphi = closest - to;
550+
if (n_oldphi) {
551+
// promote this implicit branch to a basic block start
552+
// and move all phi values to their position in edges
553+
// note that we might have already processed some phi nodes
554+
// in this basic block, so we need to be extra careful here
555+
// to ignore those
556+
for (unsigned j = 0; j < n_oldphi; j++) {
557+
dest[j] = phis[j];
558+
}
559+
for (unsigned j = n_oldphi; j < i; j++) {
560+
// move the rest to the start of phis
561+
phis[j - n_oldphi] = phis[j];
562+
phis[j] = NULL;
563+
}
564+
from = closest - 1;
565+
i -= n_oldphi;
566+
dest += n_oldphi;
567+
to += n_oldphi;
568+
nphi -= n_oldphi;
569+
}
570+
if (edge != -1) {
571+
// if edges list doesn't contain last branch, this value should be unused.
572+
jl_array_t *values = (jl_array_t*)jl_fieldref_noalloc(e, 1);
573+
val = jl_array_ptr_ref(values, edge);
574+
val = eval_value(val, s);
575+
}
576+
phis[i] = val;
577+
}
578+
// now move all phi values to their position in edges
579+
for (unsigned j = 0; j < nphi; j++) {
580+
dest[j] = phis[j];
581+
}
582+
JL_GC_POP();
583+
}
584+
return ip;
585+
}
586+
587+
SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip, int toplevel)
527588
{
528589
jl_handler_t __eh;
529-
s->ip = start;
530590
size_t ns = jl_array_len(stmts);
531591

532592
while (1) {
533-
if (s->ip >= ns)
593+
s->ip = ip;
594+
if (ip >= ns)
534595
jl_error("`body` expression must terminate in `return`. Use `block` instead.");
535596
if (toplevel)
536597
jl_get_ptls_states()->world_age = jl_world_counter;
537-
jl_value_t *stmt = jl_array_ptr_ref(stmts, s->ip);
598+
jl_value_t *stmt = jl_array_ptr_ref(stmts, ip);
599+
assert(!jl_is_phinode(stmt));
600+
size_t next_ip = ip + 1;
601+
assert(!jl_is_phinode(stmt) && !jl_is_phicnode(stmt) && "malformed AST");
538602
if (jl_is_gotonode(stmt)) {
539-
s->last_branch = s->ip;
540-
s->ip = jl_gotonode_label(stmt) - 1;
541-
continue;
603+
next_ip = jl_gotonode_label(stmt) - 1;
604+
}
605+
else if (jl_is_upsilonnode(stmt)) {
606+
jl_value_t *val = jl_fieldref_noalloc(stmt, 0);
607+
if (val)
608+
val = eval_value(val, s);
609+
jl_value_t *phic = s->locals[jl_source_nslots(s->src) + ip];
610+
assert(jl_is_ssavalue(phic));
611+
ssize_t id = ((jl_ssavalue_t*)phic)->id - 1;
612+
s->locals[jl_source_nslots(s->src) + id] = val;
542613
}
543614
else if (jl_is_expr(stmt)) {
544615
// Most exprs are allowed to end a BB by fall through
545-
s->last_branch = s->ip;
546616
jl_sym_t *head = ((jl_expr_t*)stmt)->head;
547617
assert(head != unreachable_sym);
548618
if (head == return_sym) {
@@ -554,7 +624,7 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
554624
if (jl_is_slot(lhs)) {
555625
ssize_t n = jl_slot_number(lhs);
556626
assert(n <= jl_source_nslots(s->src) && n > 0);
557-
s->locals[n-1] = rhs;
627+
s->locals[n - 1] = rhs;
558628
}
559629
else {
560630
jl_module_t *modu;
@@ -577,8 +647,7 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
577647
else if (head == goto_ifnot_sym) {
578648
jl_value_t *cond = eval_value(jl_exprarg(stmt, 0), s);
579649
if (cond == jl_false) {
580-
s->ip = jl_unbox_long(jl_exprarg(stmt, 1)) - 1;
581-
continue;
650+
next_ip = jl_unbox_long(jl_exprarg(stmt, 1)) - 1;
582651
}
583652
else if (cond != jl_true) {
584653
jl_type_error_rt("toplevel", "if", (jl_value_t*)jl_bool_type, cond);
@@ -605,29 +674,32 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
605674
for (size_t i = 0; i < jl_array_len(values); ++i) {
606675
jl_value_t *val = jl_array_ptr_ref(values, i);
607676
assert(jl_is_ssavalue(val));
608-
s->locals[jl_source_nslots(s->src) + ((jl_ssavalue_t*)val)->id - 1] = jl_box_ssavalue(catch_ip);
677+
size_t upsilon = ((jl_ssavalue_t*)val)->id - 1;
678+
assert(jl_is_upsilonnode(jl_array_ptr_ref(stmts, upsilon)));
679+
s->locals[jl_source_nslots(s->src) + upsilon] = jl_box_ssavalue(catch_ip + 1);
609680
}
681+
s->locals[jl_source_nslots(s->src) + catch_ip] = NULL;
610682
catch_ip += 1;
611683
}
612-
if (!jl_setjmp(__eh.eh_ctx,1)) {
613-
return eval_body(stmts, s, s->ip + 1, toplevel);
684+
if (!jl_setjmp(__eh.eh_ctx, 1)) {
685+
return eval_body(stmts, s, next_ip, toplevel);
614686
}
615-
else if (s->continue_at) {
616-
s->ip = s->continue_at;
687+
else if (s->continue_at) { // means we reached a :leave expression
688+
ip = s->continue_at;
617689
s->continue_at = 0;
618690
continue;
619691
}
620-
else {
692+
else { // a real exception
621693
#ifdef _OS_WINDOWS_
622694
if (jl_get_ptls_states()->exception_in_transit == jl_stackovf_exception)
623695
_resetstkoflw();
624696
#endif
625-
s->ip = jl_unbox_long(jl_exprarg(stmt, 0)) - 1;
697+
ip = catch_ip;
626698
continue;
627699
}
628700
}
629701
else if (head == leave_sym) {
630-
int hand_n_leave = jl_unbox_long(jl_exprarg(stmt,0));
702+
int hand_n_leave = jl_unbox_long(jl_exprarg(stmt, 0));
631703
assert(hand_n_leave > 0);
632704
// equivalent to jl_pop_handler(hand_n_leave) :
633705
jl_ptls_t ptls = jl_get_ptls_states();
@@ -636,7 +708,7 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
636708
eh = eh->prev;
637709
jl_eh_restore_state(eh);
638710
// pop jmp_bufs from stack
639-
s->continue_at = s->ip + 1;
711+
s->continue_at = next_ip;
640712
jl_longjmp(eh->eh_ctx, 1);
641713
}
642714
else if (head == const_sym) {
@@ -683,7 +755,6 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
683755
}
684756
}
685757
else if (jl_is_newvarnode(stmt)) {
686-
s->last_branch = s->ip;
687758
jl_value_t *var = jl_fieldref(stmt, 0);
688759
assert(jl_is_slot(var));
689760
ssize_t n = jl_slot_number(var);
@@ -696,7 +767,7 @@ SECT_INTERP static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s
696767
else {
697768
eval_stmt_value(stmt, s);
698769
}
699-
s->ip++;
770+
ip = eval_phi(stmts, s, ns, next_ip);
700771
}
701772
abort();
702773
}
@@ -766,6 +837,7 @@ SECT_INTERP CALLBACK_ABI void *jl_interpret_call_callback(interpreter_state *s,
766837
}
767838
s->locals = locals + 2;
768839
s->sparam_vals = args->lam->sparam_vals;
840+
s->preevaluation = 0;
769841
s->continue_at = 0;
770842
s->mi = args->lam;
771843
size_t i;

test/choosetests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function choosetests(choices = [])
107107
prepend!(tests, ["subarray"])
108108
end
109109

110-
compilertests = ["compiler/compiler", "compiler/validation"]
110+
compilertests = ["compiler/compiler", "compiler/validation", "compiler/ssair"]
111111

112112
if "compiler" in skip_tests
113113
filter!(x -> (x != "compiler" && !(x in compilertests)), tests)

test/compiler/compiler.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,13 +954,14 @@ let f, m
954954
f() = 0
955955
m = first(methods(f))
956956
m.source = Base.uncompressed_ast(m)::CodeInfo
957-
m.source.ssavaluetypes = 3
958-
m.source.codelocs = Int32[1, 1, 1]
959957
m.source.code = Any[
960958
Expr(:call, GlobalRef(Core, :svec), 1, 2, 3),
961959
Expr(:call, Core._apply, GlobalRef(Base, :+), SSAValue(1)),
962960
Expr(:return, SSAValue(2))
963961
]
962+
nstmts = length(m.source.code)
963+
m.source.ssavaluetypes = nstmts
964+
m.source.codelocs = fill(Int32(1), nstmts)
964965
@test @inferred(f()) == 6
965966
end
966967

0 commit comments

Comments
 (0)