diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 129a1991..5ec3f534 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -34,7 +34,7 @@ import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.suspendCancellableCoroutine -import java.net.URL +import java.net.URI import java.util.concurrent.CountDownLatch import kotlin.coroutines.resume @@ -48,6 +48,21 @@ class ProtocolClient( private val config: ProtocolClientConfig, ) : ProtocolClientInterface { + private val baseURIWithTrailingSlash = if (config.baseUri.path != null && config.baseUri.path.endsWith('/')) { + config.baseUri + } else { + val path = config.baseUri.path ?: "" + URI( + config.baseUri.scheme, + config.baseUri.userInfo, + config.baseUri.host, + config.baseUri.port, + "$path/", + config.baseUri.query, + config.baseUri.fragment, + ) + } + override fun unary( request: Input, headers: Headers, @@ -268,8 +283,5 @@ class ProtocolClient( ) } - private fun urlFromMethodSpec(methodSpec: MethodSpec): URL { - val host = config.baseUri.resolve("/${methodSpec.path}") - return host.toURL() - } + private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = baseURIWithTrailingSlash.resolve(methodSpec.path).toURL() } diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 9177e823..04ad2a94 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -44,23 +44,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") client.unary( "input", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) { _ -> } + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @Test @@ -69,23 +61,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") client.unary( "input", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) { _ -> } + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @Test @@ -94,23 +78,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") CoroutineScope(Dispatchers.IO).launch { client.stream( emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.BIDI, - ), + createMethodSpec(StreamType.BIDI), ) + val captor = argumentCaptor() + verify(httpClient).stream(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } @@ -120,23 +96,15 @@ class ProtocolClientTest { whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") CoroutineScope(Dispatchers.IO).launch { client.stream( emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.BIDI, - ), + createMethodSpec(StreamType.BIDI), ) + val captor = argumentCaptor() + verify(httpClient).stream(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } @@ -145,25 +113,12 @@ class ProtocolClientTest { whenever(codec.encodingName()).thenReturn("testing") whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com") client.unary( "", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) {} - - // Use HTTP client to determine and verify the final URL. val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") @@ -174,27 +129,65 @@ class ProtocolClientTest { whenever(codec.encodingName()).thenReturn("testing") whenever(codec.serialize(any())).thenReturn(Buffer()) whenever(serializationStrategy.codec(any())).thenReturn(codec) - val client = ProtocolClient( - httpClient = httpClient, - config = ProtocolClientConfig( - host = "https://connectrpc.com/", - serializationStrategy = serializationStrategy, - ), - ) + val client = createClient("https://connectrpc.com/") client.unary( "", emptyMap(), - MethodSpec( - path = "com.connectrpc.SomeService/Service", - String::class, - String::class, - streamType = StreamType.UNARY, - ), + createMethodSpec(StreamType.UNARY), ) {} - - // Use HTTP client to determine and verify the final URL. val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } + + @Test + fun finalUrlRelativeBaseURI() { + whenever(codec.encodingName()).thenReturn("testing") + whenever(codec.serialize(any())).thenReturn(Buffer()) + whenever(serializationStrategy.codec(any())).thenReturn(codec) + val client = createClient("https://connectrpc.com/api") + client.unary( + "", + emptyMap(), + createMethodSpec(StreamType.UNARY), + ) {} + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") + } + + @Test + fun finalUrlAbsoluteBaseURI() { + whenever(codec.encodingName()).thenReturn("testing") + whenever(codec.serialize(any())).thenReturn(Buffer()) + whenever(serializationStrategy.codec(any())).thenReturn(codec) + val client = createClient("https://connectrpc.com/api/") + client.unary( + "", + emptyMap(), + createMethodSpec(StreamType.UNARY), + ) {} + val captor = argumentCaptor() + verify(httpClient).unary(captor.capture(), any()) + assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") + } + + private fun createClient(host: String): ProtocolClient { + return ProtocolClient( + httpClient = httpClient, + config = ProtocolClientConfig( + host = host, + serializationStrategy = serializationStrategy, + ), + ) + } + + private fun createMethodSpec(streamType: StreamType): MethodSpec { + return MethodSpec( + path = "com.connectrpc.SomeService/Service", + String::class, + String::class, + streamType, + ) + } }