1
1
package com .cxytiandi .encrypt .core ;
2
2
3
3
import java .io .IOException ;
4
- import java .util .Enumeration ;
5
- import java .util .HashMap ;
6
- import java .util .List ;
7
- import java .util .Map ;
4
+ import java .util .*;
8
5
9
6
import javax .servlet .Filter ;
10
7
import javax .servlet .FilterChain ;
16
13
import javax .servlet .http .HttpServletRequest ;
17
14
import javax .servlet .http .HttpServletResponse ;
18
15
16
+ import com .cxytiandi .encrypt .util .RequestUriUtils ;
19
17
import org .slf4j .Logger ;
20
18
import org .slf4j .LoggerFactory ;
21
19
22
20
import com .cxytiandi .encrypt .algorithm .AesEncryptAlgorithm ;
23
21
import com .cxytiandi .encrypt .algorithm .EncryptAlgorithm ;
22
+ import org .springframework .util .AntPathMatcher ;
24
23
import org .springframework .util .CollectionUtils ;
25
24
import org .springframework .util .StringUtils ;
26
25
import org .springframework .web .bind .annotation .RequestMethod ;
26
+ import org .springframework .web .method .HandlerMethod ;
27
+ import org .springframework .web .servlet .DispatcherServlet ;
28
+ import org .springframework .web .servlet .HandlerExecutionChain ;
29
+ import org .springframework .web .servlet .HandlerMapping ;
27
30
28
31
/**
29
32
* 数据加解密过滤器
@@ -38,6 +41,8 @@ public class EncryptionFilter implements Filter {
38
41
39
42
private EncryptAlgorithm encryptAlgorithm = new AesEncryptAlgorithm ();
40
43
44
+ private DispatcherServlet dispatcherServlet ;
45
+
41
46
public EncryptionFilter () {
42
47
this .encryptionConfig = new EncryptionConfig ();
43
48
}
@@ -46,9 +51,15 @@ public EncryptionFilter(EncryptionConfig config) {
46
51
this .encryptionConfig = config ;
47
52
}
48
53
49
- public EncryptionFilter (EncryptionConfig config , EncryptAlgorithm encryptAlgorithm ) {
54
+ public EncryptionFilter (EncryptionConfig config , DispatcherServlet dispatcherServlet ) {
55
+ this .encryptionConfig = config ;
56
+ this .dispatcherServlet = dispatcherServlet ;
57
+ }
58
+
59
+ public EncryptionFilter (EncryptionConfig config , EncryptAlgorithm encryptAlgorithm , DispatcherServlet dispatcherServlet ) {
50
60
this .encryptionConfig = config ;
51
61
this .encryptAlgorithm = encryptAlgorithm ;
62
+ this .dispatcherServlet = dispatcherServlet ;
52
63
}
53
64
54
65
public EncryptionFilter (String key ) {
@@ -62,6 +73,8 @@ public EncryptionFilter(String key, List<String> responseEncryptUriList, List<St
62
73
this .encryptionConfig = new EncryptionConfig (key , responseEncryptUriList , requestDecryptUriList , responseCharset , debug );
63
74
}
64
75
76
+ private AntPathMatcher antPathMatcher = new AntPathMatcher ();
77
+
65
78
@ Override
66
79
public void init (FilterConfig filterConfig ) throws ServletException {
67
80
@@ -82,10 +95,10 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
82
95
return ;
83
96
}
84
97
85
- boolean decryptionStatus = this .contains (encryptionConfig .getRequestDecryptUriList (), uri , req .getMethod ());
86
- boolean encryptionStatus = this .contains (encryptionConfig .getResponseEncryptUriList (), uri , req .getMethod ());
87
- boolean decryptionIgnoreStatus = this .contains (encryptionConfig .getRequestDecryptUriIgnoreList (), uri , req .getMethod ());
88
- boolean encryptionIgnoreStatus = this .contains (encryptionConfig .getResponseEncryptUriIgnoreList (), uri , req .getMethod ());
98
+ boolean decryptionStatus = this .contains (encryptionConfig .getRequestDecryptUriList (), uri , req .getMethod (), req );
99
+ boolean encryptionStatus = this .contains (encryptionConfig .getResponseEncryptUriList (), uri , req .getMethod (), req );
100
+ boolean decryptionIgnoreStatus = this .contains (encryptionConfig .getRequestDecryptUriIgnoreList (), uri , req .getMethod (), req );
101
+ boolean encryptionIgnoreStatus = this .contains (encryptionConfig .getResponseEncryptUriIgnoreList (), uri , req .getMethod (), req );
89
102
90
103
// 没有配置具体加解密的URI默认全部都开启加解密
91
104
if (CollectionUtils .isEmpty (encryptionConfig .getRequestDecryptUriList ())
@@ -205,7 +218,7 @@ private void writeEncryptContent(String responseData, ServletResponse response)
205
218
}
206
219
}
207
220
208
- private boolean contains (List <String > list , String uri , String methodType ) {
221
+ private boolean contains (List <String > list , String uri , String methodType , HttpServletRequest request ) {
209
222
if (list .contains (uri )) {
210
223
return true ;
211
224
}
@@ -214,9 +227,51 @@ private boolean contains(List<String> list, String uri, String methodType) {
214
227
if (list .contains (prefixUri )) {
215
228
return true ;
216
229
}
230
+
231
+ // 优先用AntPathMatcher,其实用这个也够了,底层是一样的,下面用的方式兜底
232
+ for (String u : list ) {
233
+ boolean match = antPathMatcher .match (u , prefixUri );
234
+ if (match ) {
235
+ return true ;
236
+ }
237
+ }
238
+
239
+ try {
240
+ // 支持RestFul风格API
241
+ // 采用Spring MVC内置的匹配方式将当前请求匹配到对应的Controller Method上,获取注解进行匹配是否要加解密
242
+ HandlerExecutionChain handler = getHandler (request );
243
+ if (Objects .isNull (handler )) {
244
+ return false ;
245
+ }
246
+
247
+ if (Objects .nonNull (handler .getHandler ()) && handler .getHandler () instanceof HandlerMethod ) {
248
+ HandlerMethod handlerMethod = (HandlerMethod ) handler .getHandler ();
249
+ String apiUri = RequestUriUtils .getApiUri (handlerMethod .getClass (), handlerMethod .getMethod (), request .getContextPath ());
250
+ if (list .contains (apiUri )) {
251
+ return true ;
252
+ }
253
+ }
254
+ } catch (Exception e ) {
255
+ throw new RuntimeException (e );
256
+ }
217
257
return false ;
218
258
}
219
259
260
+ protected HandlerExecutionChain getHandler (HttpServletRequest request ) throws Exception {
261
+ if (Objects .isNull (dispatcherServlet )) {
262
+ return null ;
263
+ }
264
+ if (dispatcherServlet .getHandlerMappings () != null ) {
265
+ for (HandlerMapping mapping : dispatcherServlet .getHandlerMappings ()) {
266
+ HandlerExecutionChain handler = mapping .getHandler (request );
267
+ if (handler != null ) {
268
+ return handler ;
269
+ }
270
+ }
271
+ }
272
+ return null ;
273
+ }
274
+
220
275
@ Override
221
276
public void destroy () {
222
277
0 commit comments