From b89b0ddc6e6014975b1c81603d09cdae15a13c57 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 8 Dec 2023 10:42:54 -0600 Subject: [PATCH 1/3] Fix URL resolution for base URI with absolute path If a client uses a base URI with an absolute path (i.e. 'https://connectrpc.com/api/'), then any RPC methods should be resolved relative to the absolute path (instead of ignoring it as they do today). Fixes #149. --- .../com/connectrpc/impl/ProtocolClient.kt | 6 +- .../com/connectrpc/impl/ProtocolClientTest.kt | 155 +++++++++--------- 2 files changed, 75 insertions(+), 86 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 129a1991..fc0576ff 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -34,7 +34,6 @@ import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.suspendCancellableCoroutine -import java.net.URL import java.util.concurrent.CountDownLatch import kotlin.coroutines.resume @@ -268,8 +267,5 @@ class ProtocolClient( ) } - private fun urlFromMethodSpec(methodSpec: MethodSpec): URL { - val host = config.baseUri.resolve("/${methodSpec.path}") - return host.toURL() - } + private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = config.baseUri.resolve(methodSpec.path).toURL() } diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 9177e823..9a03af24 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -44,23 +44,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") client.unary( "input", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) { _ -> } + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @Test @@ -69,23 +61,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") client.unary( "input", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) { _ -> } + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @Test @@ -94,23 +78,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") CoroutineScope(Dispatchers.IO).launch { client.stream( emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.BIDI, - ), + createMethodSpec(StreamType.BIDI), ) + val captor = argumentCaptor() + verify(httpClient).stream(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } @@ -120,23 +96,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") CoroutineScope(Dispatchers.IO).launch { client.stream( emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.BIDI, - ), + createMethodSpec(StreamType.BIDI), ) + val captor = argumentCaptor() + verify(httpClient).stream(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } @@ -145,25 +113,12 @@ class ProtocolClientTest { whenever(codec.encodingName()).thenReturn("testing") whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") client.unary( "", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) {} - - // Use HTTP client to determine and verify the final URL. val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") @@ -174,27 +129,65 @@ class ProtocolClientTest { whenever(codec.encodingName()).thenReturn("testing") whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") client.unary( "", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) {} + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") + } - // Use HTTP client to determine and verify the final URL. + @Test + fun finalUrlRelativeBaseURI() { + whenever(codec.encodingName()).thenReturn("testing") + whenever(codec.serialize(any())).thenReturn(Buffer()) + whenever(serializationStrategy.codec(any())).thenReturn(codec) + val client = createClient("https://connectrpc.com/api") + client.unary( + "", + emptyMap(), + createMethodSpec(StreamType.UNARY), + ) {} val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } + + @Test + fun finalUrlAbsoluteBaseURI() { + whenever(codec.encodingName()).thenReturn("testing") + whenever(codec.serialize(any())).thenReturn(Buffer()) + whenever(serializationStrategy.codec(any())).thenReturn(codec) + val client = createClient("https://connectrpc.com/api/") + client.unary( + "", + emptyMap(), + createMethodSpec(StreamType.UNARY), + ) {} + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") + } + + private fun createClient(host: String): ProtocolClient { + return ProtocolClient( + httpClient = httpClient, + config = ProtocolClientConfig( + host = host, + serializationStrategy = serializationStrategy, + ), + ) + } + + private fun createMethodSpec(streamType: StreamType): MethodSpec { + return MethodSpec( + path = "com.connectrpc.SomeService/Service", + String::class, + String::class, + streamType, + ) + } } From 3b426b86af5650a6221ecd28f51d63b41f9b467a Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 8 Dec 2023 12:10:09 -0600 Subject: [PATCH 2/3] review comments --- .../main/kotlin/com/connectrpc/impl/ProtocolClient.kt | 11 ++++++++++- .../kotlin/com/connectrpc/impl/ProtocolClientTest.kt | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index fc0576ff..cb9ebec6 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -34,6 +34,7 @@ import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.suspendCancellableCoroutine +import java.net.URI import java.util.concurrent.CountDownLatch import kotlin.coroutines.resume @@ -47,6 +48,14 @@ class ProtocolClient( private val config: ProtocolClientConfig, ) : ProtocolClientInterface { + private val baseURIWithTrailingSlash = if (config.baseUri.path != null && config.baseUri.path.endsWith('/')) { + config.baseUri + } else { + val path = config.baseUri.path ?: "" + URI(config.baseUri.scheme, config.baseUri.userInfo, config.baseUri.host, config.baseUri.port, "$path/", + config.baseUri.query, config.baseUri.fragment) + } + override fun unary( request: Input, headers: Headers, @@ -267,5 +276,5 @@ class ProtocolClient( ) } - private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = config.baseUri.resolve(methodSpec.path).toURL() + private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = baseURIWithTrailingSlash.resolve(methodSpec.path).toURL() } diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 9a03af24..04ad2a94 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -153,7 +153,7 @@ class ProtocolClientTest { ) {} val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) - assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } @Test From 811c202202e6bd3694be13829dab2cfa3ac6a4f2 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 8 Dec 2023 12:12:37 -0600 Subject: [PATCH 3/3] fix lint --- .../main/kotlin/com/connectrpc/impl/ProtocolClient.kt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index cb9ebec6..5ec3f534 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -52,8 +52,15 @@ class ProtocolClient( config.baseUri } else { val path = config.baseUri.path ?: "" - URI(config.baseUri.scheme, config.baseUri.userInfo, config.baseUri.host, config.baseUri.port, "$path/", - config.baseUri.query, config.baseUri.fragment) + URI( + config.baseUri.scheme, + config.baseUri.userInfo, + config.baseUri.host, + config.baseUri.port, + "$path/", + config.baseUri.query, + config.baseUri.fragment, + ) } override fun unary(