@@ -3,14 +3,15 @@ use std::ops::ControlFlow;
3
3
4
4
use clippy_utils:: comparisons:: { Rel , normalize_comparison} ;
5
5
use clippy_utils:: diagnostics:: span_lint_and_then;
6
+ use clippy_utils:: macros:: { find_assert_eq_args, first_node_macro_backtrace} ;
6
7
use clippy_utils:: source:: snippet;
7
8
use clippy_utils:: visitors:: for_each_expr_without_closures;
8
9
use clippy_utils:: { eq_expr_value, hash_expr, higher} ;
9
- use rustc_ast:: { LitKind , RangeLimits } ;
10
+ use rustc_ast:: { BinOpKind , LitKind , RangeLimits } ;
10
11
use rustc_data_structures:: packed:: Pu128 ;
11
12
use rustc_data_structures:: unhash:: UnindexMap ;
12
13
use rustc_errors:: { Applicability , Diag } ;
13
- use rustc_hir:: { BinOp , Block , Body , Expr , ExprKind , UnOp } ;
14
+ use rustc_hir:: { Block , Body , Expr , ExprKind , UnOp } ;
14
15
use rustc_lint:: { LateContext , LateLintPass } ;
15
16
use rustc_session:: declare_lint_pass;
16
17
use rustc_span:: source_map:: Spanned ;
@@ -97,7 +98,7 @@ enum LengthComparison {
97
98
///
98
99
/// E.g. for `v.len() > 5` this returns `Some((LengthComparison::IntLessThanLength, 5, v.len()))`
99
100
fn len_comparison < ' hir > (
100
- bin_op : BinOp ,
101
+ bin_op : BinOpKind ,
101
102
left : & ' hir Expr < ' hir > ,
102
103
right : & ' hir Expr < ' hir > ,
103
104
) -> Option < ( LengthComparison , usize , & ' hir Expr < ' hir > ) > {
@@ -112,7 +113,7 @@ fn len_comparison<'hir>(
112
113
113
114
// normalize comparison, `v.len() > 4` becomes `4 < v.len()`
114
115
// this simplifies the logic a bit
115
- let ( op, left, right) = normalize_comparison ( bin_op. node , left, right) ?;
116
+ let ( op, left, right) = normalize_comparison ( bin_op, left, right) ?;
116
117
match ( op, left. kind , right. kind ) {
117
118
( Rel :: Lt , int_lit_pat ! ( left) , _) => Some ( ( LengthComparison :: IntLessThanLength , left as usize , right) ) ,
118
119
( Rel :: Lt , _, int_lit_pat ! ( right) ) => Some ( ( LengthComparison :: LengthLessThanInt , right as usize , left) ) ,
@@ -134,18 +135,30 @@ fn assert_len_expr<'hir>(
134
135
cx : & LateContext < ' _ > ,
135
136
expr : & ' hir Expr < ' hir > ,
136
137
) -> Option < ( LengthComparison , usize , & ' hir Expr < ' hir > ) > {
137
- if let Some ( higher:: If { cond, then, .. } ) = higher:: If :: hir ( expr)
138
+ let ( cmp , asserted_len , slice_len ) = if let Some ( higher:: If { cond, then, .. } ) = higher:: If :: hir ( expr)
138
139
&& let ExprKind :: Unary ( UnOp :: Not , condition) = & cond. kind
139
140
&& let ExprKind :: Binary ( bin_op, left, right) = & condition. kind
140
-
141
- && let Some ( ( cmp, asserted_len, slice_len) ) = len_comparison ( * bin_op, left, right)
142
- && let ExprKind :: MethodCall ( method, recv, [ ] , _) = & slice_len. kind
143
- && cx. typeck_results ( ) . expr_ty_adjusted ( recv) . peel_refs ( ) . is_slice ( )
144
- && method. ident . name == sym:: len
145
-
146
141
// check if `then` block has a never type expression
147
142
&& let ExprKind :: Block ( Block { expr : Some ( then_expr) , .. } , _) = then. kind
148
143
&& cx. typeck_results ( ) . expr_ty ( then_expr) . is_never ( )
144
+ {
145
+ len_comparison ( bin_op. node , left, right) ?
146
+ } else if let Some ( ( macro_call, bin_op) ) = first_node_macro_backtrace ( cx, expr) . find_map ( |macro_call| {
147
+ match cx. tcx . get_diagnostic_name ( macro_call. def_id ) {
148
+ Some ( sym:: assert_eq_macro) => Some ( ( macro_call, BinOpKind :: Eq ) ) ,
149
+ Some ( sym:: assert_ne_macro) => Some ( ( macro_call, BinOpKind :: Ne ) ) ,
150
+ _ => None ,
151
+ }
152
+ } ) && let Some ( ( left, right, _) ) = find_assert_eq_args ( cx, expr, macro_call. expn )
153
+ {
154
+ len_comparison ( bin_op, left, right) ?
155
+ } else {
156
+ return None ;
157
+ } ;
158
+
159
+ if let ExprKind :: MethodCall ( method, recv, [ ] , _) = & slice_len. kind
160
+ && cx. typeck_results ( ) . expr_ty_adjusted ( recv) . peel_refs ( ) . is_slice ( )
161
+ && method. ident . name == sym:: len
149
162
{
150
163
Some ( ( cmp, asserted_len, recv) )
151
164
} else {
@@ -310,7 +323,7 @@ fn check_assert<'hir>(cx: &LateContext<'_>, expr: &'hir Expr<'hir>, map: &mut Un
310
323
indexes : mem:: take ( indexes) ,
311
324
is_first_highest : * is_first_highest,
312
325
slice,
313
- assert_span : expr. span ,
326
+ assert_span : expr. span . source_callsite ( ) ,
314
327
comparison,
315
328
asserted_len,
316
329
} ;
@@ -319,7 +332,7 @@ fn check_assert<'hir>(cx: &LateContext<'_>, expr: &'hir Expr<'hir>, map: &mut Un
319
332
indexes. push ( IndexEntry :: StrayAssert {
320
333
asserted_len,
321
334
comparison,
322
- assert_span : expr. span ,
335
+ assert_span : expr. span . source_callsite ( ) ,
323
336
slice,
324
337
} ) ;
325
338
}
0 commit comments