@@ -9,7 +9,7 @@ use rustc_hir::intravisit::FnKind;
9
9
use rustc_hir:: { Block , Body , Expr , ExprKind , FnDecl , LangItem , MatchSource , PatKind , QPath , StmtKind } ;
10
10
use rustc_lint:: { LateContext , LateLintPass , LintContext } ;
11
11
use rustc_middle:: lint:: in_external_macro;
12
- use rustc_middle:: ty:: subst:: GenericArgKind ;
12
+ use rustc_middle:: ty:: { self , subst:: GenericArgKind , Ty } ;
13
13
use rustc_session:: { declare_lint_pass, declare_tool_lint} ;
14
14
use rustc_span:: def_id:: LocalDefId ;
15
15
use rustc_span:: source_map:: Span ;
@@ -175,7 +175,7 @@ impl<'tcx> LateLintPass<'tcx> for Return {
175
175
} else {
176
176
RetReplacement :: Empty
177
177
} ;
178
- check_final_expr ( cx, body. value , vec ! [ ] , replacement) ;
178
+ check_final_expr ( cx, body. value , vec ! [ ] , replacement, None ) ;
179
179
} ,
180
180
FnKind :: ItemFn ( ..) | FnKind :: Method ( ..) => {
181
181
check_block_return ( cx, & body. value . kind , sp, vec ! [ ] ) ;
@@ -188,11 +188,11 @@ impl<'tcx> LateLintPass<'tcx> for Return {
188
188
fn check_block_return < ' tcx > ( cx : & LateContext < ' tcx > , expr_kind : & ExprKind < ' tcx > , sp : Span , mut semi_spans : Vec < Span > ) {
189
189
if let ExprKind :: Block ( block, _) = expr_kind {
190
190
if let Some ( block_expr) = block. expr {
191
- check_final_expr ( cx, block_expr, semi_spans, RetReplacement :: Empty ) ;
191
+ check_final_expr ( cx, block_expr, semi_spans, RetReplacement :: Empty , None ) ;
192
192
} else if let Some ( stmt) = block. stmts . iter ( ) . last ( ) {
193
193
match stmt. kind {
194
194
StmtKind :: Expr ( expr) => {
195
- check_final_expr ( cx, expr, semi_spans, RetReplacement :: Empty ) ;
195
+ check_final_expr ( cx, expr, semi_spans, RetReplacement :: Empty , None ) ;
196
196
} ,
197
197
StmtKind :: Semi ( semi_expr) => {
198
198
// Remove ending semicolons and any whitespace ' ' in between.
@@ -202,7 +202,7 @@ fn check_block_return<'tcx>(cx: &LateContext<'tcx>, expr_kind: &ExprKind<'tcx>,
202
202
span_find_starting_semi ( cx. sess ( ) . source_map ( ) , semi_span. with_hi ( sp. hi ( ) ) ) ;
203
203
semi_spans. push ( semi_span_to_remove) ;
204
204
}
205
- check_final_expr ( cx, semi_expr, semi_spans, RetReplacement :: Empty ) ;
205
+ check_final_expr ( cx, semi_expr, semi_spans, RetReplacement :: Empty , None ) ;
206
206
} ,
207
207
_ => ( ) ,
208
208
}
@@ -216,6 +216,7 @@ fn check_final_expr<'tcx>(
216
216
semi_spans : Vec < Span > , /* containing all the places where we would need to remove semicolons if finding an
217
217
* needless return */
218
218
replacement : RetReplacement < ' tcx > ,
219
+ match_ty_opt : Option < Ty < ' _ > > ,
219
220
) {
220
221
let peeled_drop_expr = expr. peel_drop_temps ( ) ;
221
222
match & peeled_drop_expr. kind {
@@ -244,7 +245,22 @@ fn check_final_expr<'tcx>(
244
245
RetReplacement :: Expr ( snippet, applicability)
245
246
}
246
247
} else {
247
- replacement
248
+ match match_ty_opt {
249
+ Some ( match_ty) => {
250
+ match match_ty. kind ( ) {
251
+ // If the code got till here with
252
+ // tuple not getting detected before it,
253
+ // then we are sure it's going to be Unit
254
+ // type
255
+ ty:: Tuple ( _) => RetReplacement :: Unit ,
256
+ // We don't want to anything in this case
257
+ // cause we can't predict what the user would
258
+ // want here
259
+ _ => return ,
260
+ }
261
+ } ,
262
+ None => replacement,
263
+ }
248
264
} ;
249
265
250
266
if !cx. tcx . hir ( ) . attrs ( expr. hir_id ) . is_empty ( ) {
@@ -268,8 +284,9 @@ fn check_final_expr<'tcx>(
268
284
// note, if without else is going to be a type checking error anyways
269
285
// (except for unit type functions) so we don't match it
270
286
ExprKind :: Match ( _, arms, MatchSource :: Normal ) => {
287
+ let match_ty = cx. typeck_results ( ) . expr_ty ( peeled_drop_expr) ;
271
288
for arm in arms. iter ( ) {
272
- check_final_expr ( cx, arm. body , semi_spans. clone ( ) , RetReplacement :: Unit ) ;
289
+ check_final_expr ( cx, arm. body , semi_spans. clone ( ) , RetReplacement :: Unit , Some ( match_ty ) ) ;
273
290
}
274
291
} ,
275
292
// if it's a whole block, check it
@@ -293,6 +310,7 @@ fn emit_return_lint(cx: &LateContext<'_>, ret_span: Span, semi_spans: Vec<Span>,
293
310
if ret_span. from_expansion ( ) {
294
311
return ;
295
312
}
313
+
296
314
let applicability = replacement. applicability ( ) . unwrap_or ( Applicability :: MachineApplicable ) ;
297
315
let return_replacement = replacement. to_string ( ) ;
298
316
let sugg_help = replacement. sugg_help ( ) ;
0 commit comments