Skip to content

Commit

Permalink
checker, cgen: fix match branches return type ref mismatch, when retu…
Browse files Browse the repository at this point in the history
…rn type exists interface or sumtype(fix #16203)
  • Loading branch information
shove70 committed Nov 9, 2023
1 parent cd2e36a commit faf6f4e
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 45 deletions.
16 changes: 10 additions & 6 deletions vlib/v/checker/checker.v
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ mut:
inside_assign bool
// doing_line_info int // a quick single file run when called with v -line-info (contains line nr to inspect)
// doing_line_path string // same, but stores the path being parsed
is_index_assign bool
comptime_call_pos int // needed for correctly checking use before decl for templates
goto_labels map[string]ast.GotoLabel // to check for unused goto labels
enum_data_type ast.Type
fn_return_type ast.Type
orm_table_fields map[string][]ast.StructField // known table structs
is_index_assign bool
is_added_ref_by_smartcast bool // type of last stmt in if/match branches maybe added a ref by smartcast (interface)
comptime_call_pos int // needed for correctly checking use before decl for templates
goto_labels map[string]ast.GotoLabel // to check for unused goto labels
enum_data_type ast.Type
fn_return_type ast.Type
orm_table_fields map[string][]ast.StructField // known table structs
//
v_current_commit_hash string // same as old C.V_CURRENT_COMMIT_HASH
}
Expand Down Expand Up @@ -3534,6 +3535,9 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type {
&& !c.prevent_sum_type_unwrapping_once
c.prevent_sum_type_unwrapping_once = false
mut typ := if is_sum_type_cast { obj.smartcasts.last() } else { obj.typ }
if typ.nr_muls() > obj.typ.nr_muls() {
c.is_added_ref_by_smartcast = true
}
if typ == 0 {
if mut obj.expr is ast.Ident {
if obj.expr.kind == .unresolved {
Expand Down
54 changes: 39 additions & 15 deletions vlib/v/checker/match.v
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type {
mut ret_type := ast.void_type
mut nbranches_with_return := 0
mut nbranches_without_return := 0
mut is_added_ref_by_smartcast := false
for mut branch in node.branches {
if node.is_expr {
c.is_added_ref_by_smartcast = false
c.stmts_ending_with_expression(mut branch.stmts)
is_added_ref_by_smartcast = c.is_added_ref_by_smartcast
c.is_added_ref_by_smartcast = false
} else {
c.stmts(mut branch.stmts)
}
Expand Down Expand Up @@ -73,23 +77,43 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type {
c.check_match_branch_last_stmt(stmt, node.expected_type, expr_type)
ret_type = node.expected_type
} else {
ret_type = expr_type
ret_type = if is_added_ref_by_smartcast && expr_type.is_ptr() {
expr_type.deref()
} else {
expr_type
}
}
} else if node.is_expr && ret_type.idx() != expr_type.idx() {
if (node.expected_type.has_flag(.option)
|| node.expected_type.has_flag(.result))
&& c.table.sym(stmt.typ).kind == .struct_
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
} else {
if node.is_expr && ret_type.idx() != expr_type.idx() {
if (node.expected_type.has_flag(.option)
|| node.expected_type.has_flag(.result))
&& c.table.sym(stmt.typ).kind == .struct_
&& c.type_implements(stmt.typ, ast.error_type, node.pos) {
stmt.expr = ast.CastExpr{
expr: stmt.expr
typname: 'IError'
typ: ast.error_type
expr_type: stmt.typ
pos: node.pos
}
stmt.typ = ast.error_type
} else {
c.check_match_branch_last_stmt(stmt, ret_type, expr_type)
}
}
if node.is_expr && stmt.typ != ast.error_type {
ret_sym := c.table.sym(ret_type)
stmt_sym := c.table.sym(stmt.typ)
if ret_sym.kind !in [.sum_type, .interface_]
&& stmt_sym.kind in [.sum_type, .interface_] {
c.error('return type mismatch, it should be `${ret_sym.name}`',
stmt.pos)
}
if ret_type.nr_muls() != stmt.typ.nr_muls() {
type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name
c.error('return type mismatch, it should be `${type_name}`',
stmt.pos)
}
stmt.typ = ast.error_type
} else {
c.check_match_branch_last_stmt(stmt, ret_type, expr_type)
}
}
} else if stmt !in [ast.Return, ast.BranchStmt] {
Expand Down
28 changes: 28 additions & 0 deletions vlib/v/checker/tests/match_return_mismatch_type_err.out
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,31 @@ vlib/v/checker/tests/match_return_mismatch_type_err.vv:4:10: error: return type
| ~~
5 | }
6 | println(a)
vlib/v/checker/tests/match_return_mismatch_type_err.vv:18:10: error: return type mismatch, it should be `&string`
16 | _ = match any {
17 | string { &any }
18 | else { variable }
| ~~~~~~~~
19 | }
20 |
vlib/v/checker/tests/match_return_mismatch_type_err.vv:23:10: error: return type mismatch, it should be `string`
21 | _ = match any {
22 | string { any }
23 | else { &variable }
| ^
24 | }
25 | }
vlib/v/checker/tests/match_return_mismatch_type_err.vv:36:10: error: return type mismatch, it should be `&string`
34 | _ = match any {
35 | string { &any }
36 | else { variable }
| ~~~~~~~~
37 | }
38 |
vlib/v/checker/tests/match_return_mismatch_type_err.vv:41:10: error: return type mismatch, it should be `string`
39 | _ = match any {
40 | string { any }
41 | else { &variable }
| ^
42 | }
43 | }
36 changes: 36 additions & 0 deletions vlib/v/checker/tests/match_return_mismatch_type_err.vv
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,39 @@ fn main() {
}
println(a)
}

// for test the returns both interface or non-interface
interface IAny {}

fn returns_both_interface_and_non_interface() {
any := IAny('abc')
variable := ''

_ = match any {
string { &any }
else { variable }
}

_ = match any {
string { any }
else { &variable }
}
}

// for test the returns both sumtype or non-sumtype
type SAny = int | string

fn returns_both_sumtype_and_non_sumtype() {
any := SAny('abc')
variable := ''

_ = match any {
string { &any }
else { variable }
}

_ = match any {
string { any }
else { &variable }
}
}
61 changes: 37 additions & 24 deletions vlib/v/gen/c/cgen.v
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ mut:
sql_parent_id string
sql_side SqlExprSide // left or right, to distinguish idents in `name == name`
strs_to_free0 []string // strings.Builder
// strs_to_free []string // strings.Builder
// tmp_arg_vars_to_free []string
// autofree_pregen map[string]string
// autofree_pregen_buf strings.Builder
// autofree_tmp_vars []string // to avoid redefining the same tmp vars in a single function
// nr_vars_to_free int
// doing_autofree_tmp bool
// strs_to_free []string // strings.Builder
// tmp_arg_vars_to_free []string
// autofree_pregen map[string]string
// autofree_pregen_buf strings.Builder
// autofree_tmp_vars []string // to avoid redefining the same tmp vars in a single function
// nr_vars_to_free int
// doing_autofree_tmp bool
comptime_for_method string // $for method in T.methods {}
comptime_for_method_var string // $for method in T.methods {}; the variable name
comptime_for_field_var string // $for field in T.fields {}; the variable name
Expand All @@ -214,23 +214,24 @@ mut:
// TypeOne, TypeTwo {}
// where an aggregate (at least two types) is generated
// sum type deref needs to know which index to deref because unions take care of the correct field
aggregate_type_idx int
branch_parent_pos int // used in BranchStmt (continue/break) for autofree stop position
returned_var_name string // to detect that a var doesn't need to be freed since it's being returned
infix_left_var_name string // a && if expr
called_fn_name string
timers &util.Timers = util.get_timers()
force_main_console bool // true when [console] used on fn main()
as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast)
obf_table map[string]string
referenced_fns shared map[string]bool // functions that have been referenced
nr_closures int
expected_cast_type ast.Type // for match expr of sumtypes
or_expr_return_type ast.Type // or { 0, 1 } return type
anon_fn bool
tests_inited bool
has_main bool
// main_fn_decl_node ast.FnDecl
aggregate_type_idx int
branch_parent_pos int // used in BranchStmt (continue/break) for autofree stop position
returned_var_name string // to detect that a var doesn't need to be freed since it's being returned
infix_left_var_name string // a && if expr
called_fn_name string
timers &util.Timers = util.get_timers()
force_main_console bool // true when [console] used on fn main()
as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast)
obf_table map[string]string
referenced_fns shared map[string]bool // functions that have been referenced
nr_closures int
expected_cast_type ast.Type // for match expr of sumtypes
expected_return_type ast.Type // for match expr of both interface and non-interface return type
or_expr_return_type ast.Type // or { 0, 1 } return type
anon_fn bool
tests_inited bool
has_main bool
// main_fn_decl_node ast.FnDecl
cur_mod ast.Module
cur_concrete_types []ast.Type // do not use table.cur_concrete_types because table is global, so should not be accessed by different threads
cur_fn &ast.FnDecl = unsafe { nil } // same here
Expand Down Expand Up @@ -1870,12 +1871,24 @@ fn (mut g Gen) stmts_with_tmp_var(stmts []ast.Stmt, tmp_var string) bool {
if !is_noreturn {
g.write('${tmp_var} = ')
}
if g.expected_return_type != 0 {
if stmt is ast.ExprStmt && stmt.typ != ast.error_type
&& stmt.typ.nr_muls() > g.expected_return_type.nr_muls() {
g.write('*'.repeat(stmt.typ.nr_muls() - g.expected_return_type.nr_muls()))
}
}
g.stmt(stmt)
if !g.out.last_n(2).contains(';') {
g.writeln(';')
}
}
} else {
if i == stmts.len - 1 && g.expected_return_type != 0 {
if stmt is ast.ExprStmt && stmt.typ != ast.error_type
&& stmt.typ.nr_muls() > g.expected_return_type.nr_muls() {
g.write('*'.repeat(stmt.typ.nr_muls() - g.expected_return_type.nr_muls()))
}
}
g.stmt(stmt)
if (g.inside_if_option || g.inside_if_result || g.inside_match_option
|| g.inside_match_result) && stmt is ast.ExprStmt {
Expand Down
4 changes: 4 additions & 0 deletions vlib/v/gen/c/match.v
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
if is_expr && tmp_var.len > 0 && g.table.sym(node.return_type).kind == .sum_type {
g.expected_cast_type = node.return_type
}
if is_expr && branch.stmts.len > 0 {
g.expected_return_type = node.return_type
}
g.stmts_with_tmp_var(branch.stmts, tmp_var)
g.expected_return_type = 0
g.expected_cast_type = 0
if g.inside_ternary == 0 {
g.writeln('}')
Expand Down
26 changes: 26 additions & 0 deletions vlib/v/tests/match_test.v
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,29 @@ fn test_noreturn() {
}
}
}

// for test the returns both interface and non-interface
interface Any {}

fn test_returns_both_interface_and_non_interface() {
any := Any('abc')

mut res := match any {
string { any }
else { 'literal' }
}
assert res == 'abc'

variable := ''
res = match any {
string { any }
else { variable }
}
assert res == 'abc'

res = match any {
string { &any }
else { &variable }
}
assert res == 'abc'
}

0 comments on commit faf6f4e

Please sign in to comment.