diff --git a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java index 98fd0f2c6..1cdf0e133 100644 --- a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java @@ -28,6 +28,7 @@ import org.jetbrains.kotlin.fir.FirElement; import org.jetbrains.kotlin.fir.declarations.FirResolvedImport; import org.jetbrains.kotlin.fir.references.FirResolvedCallableReference; +import org.jetbrains.kotlin.fir.symbols.impl.FirConstructorSymbol; import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol; import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol; import org.jetbrains.kotlin.kdoc.psi.api.KDoc; @@ -286,6 +287,12 @@ public J visitCallableReferenceExpression(KtCallableReferenceExpression expressi expression.getReceiverExpression() ); } + if (reference != null && reference.getResolvedSymbol() instanceof FirConstructorSymbol) { + methodReferenceType = psiElementAssociations.getTypeMapping().methodDeclarationType( + ((FirConstructorSymbol) reference.getResolvedSymbol()).getFir(), + expression.getReceiverExpression() + ); + } JavaType.Variable fieldReferenceType = null; if (reference != null && reference.getResolvedSymbol() instanceof FirPropertySymbol) { fieldReferenceType = psiElementAssociations.getTypeMapping().variableType( diff --git a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java index f51aa068e..8935ec4da 100644 --- a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java +++ b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java @@ -1569,5 +1569,35 @@ public J.NewClass visitNewClass(J.NewClass newClass, Integer integer) { ) ); } + + @Test + void constructorMemberReferenceType() { + rewriteRun( + kotlin( + """ + open class A( + val foo : ( ( Any ) -> A) -> A + ) + class B : A ( foo = { x -> ( :: A ) ( x ) } ) { + @Suppress("UNUSED_PARAMETER") + fun mRef(a: Any) {} + } + """, spec -> spec.afterRecipe(cu -> { + AtomicBoolean found = new AtomicBoolean(false); + new KotlinIsoVisitor() { + @Override + public J.MemberReference visitMemberReference(J.MemberReference memberRef, Integer integer) { + if ("A".equals(memberRef.getReference().getSimpleName())) { + assertThat(memberRef.getMethodType().toString()).isEqualTo("A{name=,return=A,parameters=[kotlin.Function1, A>]}"); + found.set(true); + } + return super.visitMemberReference(memberRef, integer); + } + }.visit(cu, 0); + assertThat(found.get()).isEqualTo(true); + }) + ) + ); + } } }