diff --git a/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java b/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java
index 0a34e733c..aa221c480 100644
--- a/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java
+++ b/src/main/java/org/openrewrite/java/spring/boot2/ConvertToSecurityDslVisitor.java
@@ -39,6 +39,9 @@ public class ConvertToSecurityDslVisitor
extends JavaIsoVisitor
{
public static final String FQN_CUSTOMIZER = "org.springframework.security.config.Customizer";
+ private static final JavaType.FullyQualified CUSTOMIZER_SHALLOW_TYPE =
+ (JavaType.ShallowClass) JavaType.buildType(FQN_CUSTOMIZER);
+
private final String securityFqn;
private final Collection convertableMethods;
@@ -79,7 +82,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation initialMethod
J.MethodInvocation method = super.visitMethodInvocation(initialMethod, executionContext);
if (isApplicableMethod(method)) {
J.MethodInvocation m = method;
- method = findDesiredReplacement(method)
+ method = createDesiredReplacement(method)
.map(newMethodType -> {
List chain = computeAndMarkChain();
boolean keepArg = keepArg(m.getSimpleName());
@@ -89,9 +92,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation initialMethod
.withName(m.getName().withSimpleName(newMethodType.getName()))
.withArguments(ListUtils.concat(
keepArg ? m.getArguments().get(0) : null,
- Collections.singletonList(chain.isEmpty()
- ? createDefaultsCall(newMethodType.getParameterTypes().get(keepArg ? 1 : 0))
- : createLambdaParam(paramName, newMethodType.getParameterTypes().get(keepArg ? 1 : 0), chain))
+ Collections.singletonList(chain.isEmpty() ?
+ createDefaultsCall() :
+ createLambdaParam(paramName, newMethodType.getParameterTypes().get(keepArg ? 1 : 0), chain))
)
);
})
@@ -171,37 +174,38 @@ private boolean hasHandleableArg(J.MethodInvocation m) {
&& !TypeUtils.isAssignableTo(FQN_CUSTOMIZER, m.getMethodType().getParameterTypes().get(0));
}
- private Optional findDesiredReplacement(J.MethodInvocation m) {
+ private Optional createDesiredReplacement(J.MethodInvocation m) {
JavaType.Method methodType = m.getMethodType();
if (methodType == null) {
return Optional.empty();
}
- JavaType.FullyQualified httpSecurityType = methodType.getDeclaringType();
+ JavaType.Parameterized customizerArgType = new JavaType.Parameterized(null,
+ CUSTOMIZER_SHALLOW_TYPE, Collections.singletonList(methodType.getReturnType()));
boolean keepArg = keepArg(m.getSimpleName());
- int expectedParamCount = keepArg ? 2 : 1;
- int customizerParamIndex = keepArg ? 1 : 0;
- return httpSecurityType.getMethods().stream()
- .filter(availableMethod -> availableMethod.getName().equals(methodRenames.getOrDefault(m.getSimpleName(), m.getSimpleName())) &&
- availableMethod.getParameterTypes().size() == expectedParamCount &&
- availableMethod.getParameterTypes().get(customizerParamIndex) instanceof JavaType.FullyQualified &&
- FQN_CUSTOMIZER.equals(((JavaType.FullyQualified) availableMethod.getParameterTypes().get(customizerParamIndex)).getFullyQualifiedName()))
- .findFirst();
+ List paramNames = keepArg ? ListUtils.concat(methodType.getParameterNames(), "arg1")
+ : Collections.singletonList("arg0");
+ List paramTypes = keepArg ? ListUtils.concat(methodType.getParameterTypes(), customizerArgType)
+ : Collections.singletonList(customizerArgType);
+ return Optional.of(methodType.withReturnType(methodType.getDeclaringType())
+ .withName(methodRenames.getOrDefault(methodType.getName(), methodType.getName()))
+ .withParameterNames(paramNames)
+ .withParameterTypes(paramTypes)
+ );
}
private boolean keepArg(String methodName) {
return argReplacements.containsKey(methodName) && argReplacements.get(methodName) == null;
}
- private Optional findDesiredReplacementForArg(J.MethodInvocation m) {
+ private Optional createDesiredReplacementForArg(J.MethodInvocation m) {
JavaType.Method methodType = m.getMethodType();
if (methodType == null || !hasHandleableArg(m) || !(methodType.getReturnType() instanceof JavaType.Class)) {
return Optional.empty();
}
- JavaType.Class returnType = (JavaType.Class) methodType.getReturnType();
- return returnType.getMethods().stream()
- .filter(availableMethod -> availableMethod.getName().equals(argReplacements.get(m.getSimpleName())) &&
- availableMethod.getParameterTypes().size() == 1)
- .findFirst();
+ return Optional.of(
+ methodType.withName(argReplacements.get(m.getSimpleName()))
+ .withDeclaringType((JavaType.FullyQualified) methodType.getReturnType())
+ );
}
// this method is unused in this repo, but, useful in Spring Tool Suite integration
@@ -232,7 +236,7 @@ private List computeAndMarkChain() {
List chain = new ArrayList<>();
Cursor cursor = getCursor();
J.MethodInvocation initialMethodInvocation = cursor.getValue();
- findDesiredReplacementForArg(initialMethodInvocation).ifPresent(methodType ->
+ createDesiredReplacementForArg(initialMethodInvocation).ifPresent(methodType ->
chain.add(initialMethodInvocation.withMethodType(methodType)
.withName(initialMethodInvocation.getName().withSimpleName(methodType.getName()))));
cursor = cursor.getParent(2);
@@ -272,13 +276,10 @@ private boolean isDisableMethod(J.MethodInvocation method) {
return new MethodMatcher("org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer disable()", true).matches(method);
}
- private J.MethodInvocation createDefaultsCall(JavaType type) {
- JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type);
- assert fullyQualified != null;
- JavaType.Method methodType = fullyQualified.getMethods().stream().filter(m -> "withDefaults".equals(m.getName()) && m.getParameterTypes().isEmpty() && m.getFlags().contains(Flag.Static)).findFirst().orElse(null);
- if (methodType == null) {
- throw new IllegalStateException();
- }
+ private J.MethodInvocation createDefaultsCall() {
+ JavaType.Method methodType = new JavaType.Method(null, 9, CUSTOMIZER_SHALLOW_TYPE, "withDefaults",
+ new JavaType.GenericTypeVariable(null, "T", JavaType.GenericTypeVariable.Variance.INVARIANT, null),
+ null, null, null, null);
maybeAddImport(methodType.getDeclaringType().getFullyQualifiedName(), methodType.getName());
return new J.MethodInvocation(Tree.randomId(), Space.EMPTY, Markers.EMPTY, null, null,
new J.Identifier(Tree.randomId(), Space.EMPTY, Markers.EMPTY, emptyList(), "withDefaults", null, null),