diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index be8825d8d881a7..5a28ff53d14c7a 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -130,12 +130,13 @@ mut: inside_assign bool // doing_line_info int // a quick single file run when called with v -line-info (contains line nr to inspect) // doing_line_path string // same, but stores the path being parsed - is_index_assign bool - comptime_call_pos int // needed for correctly checking use before decl for templates - goto_labels map[string]ast.GotoLabel // to check for unused goto labels - enum_data_type ast.Type - fn_return_type ast.Type - orm_table_fields map[string][]ast.StructField // known table structs + is_index_assign bool + is_added_ref_by_smartcast bool // type of last stmt in if/match branches maybe added a ref by smartcast (interface) + comptime_call_pos int // needed for correctly checking use before decl for templates + goto_labels map[string]ast.GotoLabel // to check for unused goto labels + enum_data_type ast.Type + fn_return_type ast.Type + orm_table_fields map[string][]ast.StructField // known table structs // v_current_commit_hash string // same as old C.V_CURRENT_COMMIT_HASH } @@ -3534,6 +3535,9 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { && !c.prevent_sum_type_unwrapping_once c.prevent_sum_type_unwrapping_once = false mut typ := if is_sum_type_cast { obj.smartcasts.last() } else { obj.typ } + if typ.nr_muls() > obj.typ.nr_muls() { + c.is_added_ref_by_smartcast = true + } if typ == 0 { if mut obj.expr is ast.Ident { if obj.expr.kind == .unresolved { diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index d8302e79225f85..f6beca2a18b32d 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -39,9 +39,13 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { mut ret_type := ast.void_type mut nbranches_with_return := 0 mut nbranches_without_return := 0 + mut is_added_ref_by_smartcast := false for mut branch in node.branches { if node.is_expr { + c.is_added_ref_by_smartcast = false c.stmts_ending_with_expression(mut branch.stmts) + is_added_ref_by_smartcast = c.is_added_ref_by_smartcast + c.is_added_ref_by_smartcast = false } else { c.stmts(mut branch.stmts) } @@ -73,23 +77,43 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { c.check_match_branch_last_stmt(stmt, node.expected_type, expr_type) ret_type = node.expected_type } else { - ret_type = expr_type + ret_type = if is_added_ref_by_smartcast && expr_type.is_ptr() { + expr_type.deref() + } else { + expr_type + } } - } else if node.is_expr && ret_type.idx() != expr_type.idx() { - if (node.expected_type.has_flag(.option) - || node.expected_type.has_flag(.result)) - && c.table.sym(stmt.typ).kind == .struct_ - && c.type_implements(stmt.typ, ast.error_type, node.pos) { - stmt.expr = ast.CastExpr{ - expr: stmt.expr - typname: 'IError' - typ: ast.error_type - expr_type: stmt.typ - pos: node.pos + } else { + if node.is_expr && ret_type.idx() != expr_type.idx() { + if (node.expected_type.has_flag(.option) + || node.expected_type.has_flag(.result)) + && c.table.sym(stmt.typ).kind == .struct_ + && c.type_implements(stmt.typ, ast.error_type, node.pos) { + stmt.expr = ast.CastExpr{ + expr: stmt.expr + typname: 'IError' + typ: ast.error_type + expr_type: stmt.typ + pos: node.pos + } + stmt.typ = ast.error_type + } else { + c.check_match_branch_last_stmt(stmt, ret_type, expr_type) + } + } + if node.is_expr && stmt.typ != ast.error_type { + ret_sym := c.table.sym(ret_type) + stmt_sym := c.table.sym(stmt.typ) + if ret_sym.kind !in [.sum_type, .interface_] + && stmt_sym.kind in [.sum_type, .interface_] { + c.error('return type mismatch, it should be `${ret_sym.name}`', + stmt.pos) + } + if ret_type.nr_muls() != stmt.typ.nr_muls() { + type_name := '&'.repeat(ret_type.nr_muls()) + ret_sym.name + c.error('return type mismatch, it should be `${type_name}`', + stmt.pos) } - stmt.typ = ast.error_type - } else { - c.check_match_branch_last_stmt(stmt, ret_type, expr_type) } } } else if stmt !in [ast.Return, ast.BranchStmt] { diff --git a/vlib/v/checker/tests/match_return_mismatch_type_err.out b/vlib/v/checker/tests/match_return_mismatch_type_err.out index 3a8e8fe15c42d2..1bcd963845b8ea 100644 --- a/vlib/v/checker/tests/match_return_mismatch_type_err.out +++ b/vlib/v/checker/tests/match_return_mismatch_type_err.out @@ -5,3 +5,31 @@ vlib/v/checker/tests/match_return_mismatch_type_err.vv:4:10: error: return type | ~~ 5 | } 6 | println(a) +vlib/v/checker/tests/match_return_mismatch_type_err.vv:18:10: error: return type mismatch, it should be `&string` + 16 | _ = match any { + 17 | string { &any } + 18 | else { variable } + | ~~~~~~~~ + 19 | } + 20 | +vlib/v/checker/tests/match_return_mismatch_type_err.vv:23:10: error: return type mismatch, it should be `string` + 21 | _ = match any { + 22 | string { any } + 23 | else { &variable } + | ^ + 24 | } + 25 | } +vlib/v/checker/tests/match_return_mismatch_type_err.vv:36:10: error: return type mismatch, it should be `&string` + 34 | _ = match any { + 35 | string { &any } + 36 | else { variable } + | ~~~~~~~~ + 37 | } + 38 | +vlib/v/checker/tests/match_return_mismatch_type_err.vv:41:10: error: return type mismatch, it should be `string` + 39 | _ = match any { + 40 | string { any } + 41 | else { &variable } + | ^ + 42 | } + 43 | } diff --git a/vlib/v/checker/tests/match_return_mismatch_type_err.vv b/vlib/v/checker/tests/match_return_mismatch_type_err.vv index c422126014ced2..24c8ebecca33a0 100644 --- a/vlib/v/checker/tests/match_return_mismatch_type_err.vv +++ b/vlib/v/checker/tests/match_return_mismatch_type_err.vv @@ -5,3 +5,39 @@ fn main() { } println(a) } + +// for test the returns both interface or non-interface +interface IAny {} + +fn returns_both_interface_and_non_interface() { + any := IAny('abc') + variable := '' + + _ = match any { + string { &any } + else { variable } + } + + _ = match any { + string { any } + else { &variable } + } +} + +// for test the returns both sumtype or non-sumtype +type SAny = int | string + +fn returns_both_sumtype_and_non_sumtype() { + any := SAny('abc') + variable := '' + + _ = match any { + string { &any } + else { variable } + } + + _ = match any { + string { any } + else { &variable } + } +} diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 3771bcb0e5898e..d1b39312ebdf15 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -194,13 +194,13 @@ mut: sql_parent_id string sql_side SqlExprSide // left or right, to distinguish idents in `name == name` strs_to_free0 []string // strings.Builder - // strs_to_free []string // strings.Builder - // tmp_arg_vars_to_free []string - // autofree_pregen map[string]string - // autofree_pregen_buf strings.Builder - // autofree_tmp_vars []string // to avoid redefining the same tmp vars in a single function - // nr_vars_to_free int - // doing_autofree_tmp bool + // strs_to_free []string // strings.Builder + // tmp_arg_vars_to_free []string + // autofree_pregen map[string]string + // autofree_pregen_buf strings.Builder + // autofree_tmp_vars []string // to avoid redefining the same tmp vars in a single function + // nr_vars_to_free int + // doing_autofree_tmp bool comptime_for_method string // $for method in T.methods {} comptime_for_method_var string // $for method in T.methods {}; the variable name comptime_for_field_var string // $for field in T.fields {}; the variable name @@ -214,23 +214,24 @@ mut: // TypeOne, TypeTwo {} // where an aggregate (at least two types) is generated // sum type deref needs to know which index to deref because unions take care of the correct field - aggregate_type_idx int - branch_parent_pos int // used in BranchStmt (continue/break) for autofree stop position - returned_var_name string // to detect that a var doesn't need to be freed since it's being returned - infix_left_var_name string // a && if expr - called_fn_name string - timers &util.Timers = util.get_timers() - force_main_console bool // true when [console] used on fn main() - as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast) - obf_table map[string]string - referenced_fns shared map[string]bool // functions that have been referenced - nr_closures int - expected_cast_type ast.Type // for match expr of sumtypes - or_expr_return_type ast.Type // or { 0, 1 } return type - anon_fn bool - tests_inited bool - has_main bool - // main_fn_decl_node ast.FnDecl + aggregate_type_idx int + branch_parent_pos int // used in BranchStmt (continue/break) for autofree stop position + returned_var_name string // to detect that a var doesn't need to be freed since it's being returned + infix_left_var_name string // a && if expr + called_fn_name string + timers &util.Timers = util.get_timers() + force_main_console bool // true when [console] used on fn main() + as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast) + obf_table map[string]string + referenced_fns shared map[string]bool // functions that have been referenced + nr_closures int + expected_cast_type ast.Type // for match expr of sumtypes + expected_return_type ast.Type // for match expr of both interface and non-interface return type + or_expr_return_type ast.Type // or { 0, 1 } return type + anon_fn bool + tests_inited bool + has_main bool + // main_fn_decl_node ast.FnDecl cur_mod ast.Module cur_concrete_types []ast.Type // do not use table.cur_concrete_types because table is global, so should not be accessed by different threads cur_fn &ast.FnDecl = unsafe { nil } // same here @@ -1870,12 +1871,24 @@ fn (mut g Gen) stmts_with_tmp_var(stmts []ast.Stmt, tmp_var string) bool { if !is_noreturn { g.write('${tmp_var} = ') } + if g.expected_return_type != 0 { + if stmt is ast.ExprStmt && stmt.typ != ast.error_type + && stmt.typ.nr_muls() > g.expected_return_type.nr_muls() { + g.write('*'.repeat(stmt.typ.nr_muls() - g.expected_return_type.nr_muls())) + } + } g.stmt(stmt) if !g.out.last_n(2).contains(';') { g.writeln(';') } } } else { + if i == stmts.len - 1 && g.expected_return_type != 0 { + if stmt is ast.ExprStmt && stmt.typ != ast.error_type + && stmt.typ.nr_muls() > g.expected_return_type.nr_muls() { + g.write('*'.repeat(stmt.typ.nr_muls() - g.expected_return_type.nr_muls())) + } + } g.stmt(stmt) if (g.inside_if_option || g.inside_if_result || g.inside_match_option || g.inside_match_result) && stmt is ast.ExprStmt { diff --git a/vlib/v/gen/c/match.v b/vlib/v/gen/c/match.v index 3b9aa3a17e57f7..b069b9f5b397d3 100644 --- a/vlib/v/gen/c/match.v +++ b/vlib/v/gen/c/match.v @@ -230,7 +230,11 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str if is_expr && tmp_var.len > 0 && g.table.sym(node.return_type).kind == .sum_type { g.expected_cast_type = node.return_type } + if is_expr && branch.stmts.len > 0 { + g.expected_return_type = node.return_type + } g.stmts_with_tmp_var(branch.stmts, tmp_var) + g.expected_return_type = 0 g.expected_cast_type = 0 if g.inside_ternary == 0 { g.writeln('}') diff --git a/vlib/v/tests/match_test.v b/vlib/v/tests/match_test.v index e0c579d471af37..e5ba11293c7540 100644 --- a/vlib/v/tests/match_test.v +++ b/vlib/v/tests/match_test.v @@ -301,3 +301,29 @@ fn test_noreturn() { } } } + +// for test the returns both interface and non-interface +interface Any {} + +fn test_returns_both_interface_and_non_interface() { + any := Any('abc') + + mut res := match any { + string { any } + else { 'literal' } + } + assert res == 'abc' + + variable := '' + res = match any { + string { any } + else { variable } + } + assert res == 'abc' + + res = match any { + string { &any } + else { &variable } + } + assert res == 'abc' +}