diff --git a/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java b/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java index 0a34e733c..aa221c480 100644 --- a/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java +++ b/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java @@ -39,6 +39,9 @@ public class ConvertToSecurityDslVisitor

extends JavaIsoVisitor

{ public static final String FQN_CUSTOMIZER = "org.springframework.security.config.Customizer"; + private static final JavaType.FullyQualified CUSTOMIZER_SHALLOW_TYPE = + (JavaType.ShallowClass) JavaType.buildType(FQN_CUSTOMIZER); + private final String securityFqn; private final Collection convertableMethods; @@ -79,7 +82,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation initialMethod J.MethodInvocation method = super.visitMethodInvocation(initialMethod, executionContext); if (isApplicableMethod(method)) { J.MethodInvocation m = method; - method = findDesiredReplacement(method) + method = createDesiredReplacement(method) .map(newMethodType -> { List chain = computeAndMarkChain(); boolean keepArg = keepArg(m.getSimpleName()); @@ -89,9 +92,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation initialMethod .withName(m.getName().withSimpleName(newMethodType.getName())) .withArguments(ListUtils.concat( keepArg ? m.getArguments().get(0) : null, - Collections.singletonList(chain.isEmpty() - ? createDefaultsCall(newMethodType.getParameterTypes().get(keepArg ? 1 : 0)) - : createLambdaParam(paramName, newMethodType.getParameterTypes().get(keepArg ? 1 : 0), chain)) + Collections.singletonList(chain.isEmpty() ? + createDefaultsCall() : + createLambdaParam(paramName, newMethodType.getParameterTypes().get(keepArg ? 1 : 0), chain)) ) ); }) @@ -171,37 +174,38 @@ private boolean hasHandleableArg(J.MethodInvocation m) { && !TypeUtils.isAssignableTo(FQN_CUSTOMIZER, m.getMethodType().getParameterTypes().get(0)); } - private Optional findDesiredReplacement(J.MethodInvocation m) { + private Optional createDesiredReplacement(J.MethodInvocation m) { JavaType.Method methodType = m.getMethodType(); if (methodType == null) { return Optional.empty(); } - JavaType.FullyQualified httpSecurityType = methodType.getDeclaringType(); + JavaType.Parameterized customizerArgType = new JavaType.Parameterized(null, + CUSTOMIZER_SHALLOW_TYPE, Collections.singletonList(methodType.getReturnType())); boolean keepArg = keepArg(m.getSimpleName()); - int expectedParamCount = keepArg ? 2 : 1; - int customizerParamIndex = keepArg ? 1 : 0; - return httpSecurityType.getMethods().stream() - .filter(availableMethod -> availableMethod.getName().equals(methodRenames.getOrDefault(m.getSimpleName(), m.getSimpleName())) && - availableMethod.getParameterTypes().size() == expectedParamCount && - availableMethod.getParameterTypes().get(customizerParamIndex) instanceof JavaType.FullyQualified && - FQN_CUSTOMIZER.equals(((JavaType.FullyQualified) availableMethod.getParameterTypes().get(customizerParamIndex)).getFullyQualifiedName())) - .findFirst(); + List paramNames = keepArg ? ListUtils.concat(methodType.getParameterNames(), "arg1") + : Collections.singletonList("arg0"); + List paramTypes = keepArg ? ListUtils.concat(methodType.getParameterTypes(), customizerArgType) + : Collections.singletonList(customizerArgType); + return Optional.of(methodType.withReturnType(methodType.getDeclaringType()) + .withName(methodRenames.getOrDefault(methodType.getName(), methodType.getName())) + .withParameterNames(paramNames) + .withParameterTypes(paramTypes) + ); } private boolean keepArg(String methodName) { return argReplacements.containsKey(methodName) && argReplacements.get(methodName) == null; } - private Optional findDesiredReplacementForArg(J.MethodInvocation m) { + private Optional createDesiredReplacementForArg(J.MethodInvocation m) { JavaType.Method methodType = m.getMethodType(); if (methodType == null || !hasHandleableArg(m) || !(methodType.getReturnType() instanceof JavaType.Class)) { return Optional.empty(); } - JavaType.Class returnType = (JavaType.Class) methodType.getReturnType(); - return returnType.getMethods().stream() - .filter(availableMethod -> availableMethod.getName().equals(argReplacements.get(m.getSimpleName())) && - availableMethod.getParameterTypes().size() == 1) - .findFirst(); + return Optional.of( + methodType.withName(argReplacements.get(m.getSimpleName())) + .withDeclaringType((JavaType.FullyQualified) methodType.getReturnType()) + ); } // this method is unused in this repo, but, useful in Spring Tool Suite integration @@ -232,7 +236,7 @@ private List computeAndMarkChain() { List chain = new ArrayList<>(); Cursor cursor = getCursor(); J.MethodInvocation initialMethodInvocation = cursor.getValue(); - findDesiredReplacementForArg(initialMethodInvocation).ifPresent(methodType -> + createDesiredReplacementForArg(initialMethodInvocation).ifPresent(methodType -> chain.add(initialMethodInvocation.withMethodType(methodType) .withName(initialMethodInvocation.getName().withSimpleName(methodType.getName())))); cursor = cursor.getParent(2); @@ -272,13 +276,10 @@ private boolean isDisableMethod(J.MethodInvocation method) { return new MethodMatcher("org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer disable()", true).matches(method); } - private J.MethodInvocation createDefaultsCall(JavaType type) { - JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type); - assert fullyQualified != null; - JavaType.Method methodType = fullyQualified.getMethods().stream().filter(m -> "withDefaults".equals(m.getName()) && m.getParameterTypes().isEmpty() && m.getFlags().contains(Flag.Static)).findFirst().orElse(null); - if (methodType == null) { - throw new IllegalStateException(); - } + private J.MethodInvocation createDefaultsCall() { + JavaType.Method methodType = new JavaType.Method(null, 9, CUSTOMIZER_SHALLOW_TYPE, "withDefaults", + new JavaType.GenericTypeVariable(null, "T", JavaType.GenericTypeVariable.Variance.INVARIANT, null), + null, null, null, null); maybeAddImport(methodType.getDeclaringType().getFullyQualifiedName(), methodType.getName()); return new J.MethodInvocation(Tree.randomId(), Space.EMPTY, Markers.EMPTY, null, null, new J.Identifier(Tree.randomId(), Space.EMPTY, Markers.EMPTY, emptyList(), "withDefaults", null, null),