diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
index 129a1991..5ec3f534 100644
--- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
+++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
@@ -34,7 +34,7 @@ import com.connectrpc.protocols.GETConfiguration
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.suspendCancellableCoroutine
-import java.net.URL
+import java.net.URI
import java.util.concurrent.CountDownLatch
import kotlin.coroutines.resume
@@ -48,6 +48,21 @@ class ProtocolClient(
private val config: ProtocolClientConfig,
) : ProtocolClientInterface {
+ private val baseURIWithTrailingSlash = if (config.baseUri.path != null && config.baseUri.path.endsWith('/')) {
+ config.baseUri
+ } else {
+ val path = config.baseUri.path ?: ""
+ URI(
+ config.baseUri.scheme,
+ config.baseUri.userInfo,
+ config.baseUri.host,
+ config.baseUri.port,
+ "$path/",
+ config.baseUri.query,
+ config.baseUri.fragment,
+ )
+ }
+
override fun unary(
request: Input,
headers: Headers,
@@ -268,8 +283,5 @@ class ProtocolClient(
)
}
- private fun urlFromMethodSpec(methodSpec: MethodSpec): URL {
- val host = config.baseUri.resolve("/${methodSpec.path}")
- return host.toURL()
- }
+ private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = baseURIWithTrailingSlash.resolve(methodSpec.path).toURL()
}
diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
index 9177e823..04ad2a94 100644
--- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
@@ -44,23 +44,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com/",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com/")
client.unary(
"input",
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.UNARY,
- ),
+ createMethodSpec(StreamType.UNARY),
) { _ -> }
+ val captor = argumentCaptor()
+ verify(httpClient).unary(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
@Test
@@ -69,23 +61,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com")
client.unary(
"input",
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.UNARY,
- ),
+ createMethodSpec(StreamType.UNARY),
) { _ -> }
+ val captor = argumentCaptor()
+ verify(httpClient).unary(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
@Test
@@ -94,23 +78,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com/",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com/")
CoroutineScope(Dispatchers.IO).launch {
client.stream(
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.BIDI,
- ),
+ createMethodSpec(StreamType.BIDI),
)
+ val captor = argumentCaptor()
+ verify(httpClient).stream(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
}
@@ -120,23 +96,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com")
CoroutineScope(Dispatchers.IO).launch {
client.stream(
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.BIDI,
- ),
+ createMethodSpec(StreamType.BIDI),
)
+ val captor = argumentCaptor()
+ verify(httpClient).stream(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
}
@@ -145,25 +113,12 @@ class ProtocolClientTest {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com")
client.unary(
"",
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.UNARY,
- ),
+ createMethodSpec(StreamType.UNARY),
) {}
-
- // Use HTTP client to determine and verify the final URL.
val captor = argumentCaptor()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
@@ -174,27 +129,65 @@ class ProtocolClientTest {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec(any())).thenReturn(codec)
- val client = ProtocolClient(
- httpClient = httpClient,
- config = ProtocolClientConfig(
- host = "https://connectrpc.com/",
- serializationStrategy = serializationStrategy,
- ),
- )
+ val client = createClient("https://connectrpc.com/")
client.unary(
"",
emptyMap(),
- MethodSpec(
- path = "com.connectrpc.SomeService/Service",
- String::class,
- String::class,
- streamType = StreamType.UNARY,
- ),
+ createMethodSpec(StreamType.UNARY),
) {}
-
- // Use HTTP client to determine and verify the final URL.
val captor = argumentCaptor()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
+
+ @Test
+ fun finalUrlRelativeBaseURI() {
+ whenever(codec.encodingName()).thenReturn("testing")
+ whenever(codec.serialize(any())).thenReturn(Buffer())
+ whenever(serializationStrategy.codec(any())).thenReturn(codec)
+ val client = createClient("https://connectrpc.com/api")
+ client.unary(
+ "",
+ emptyMap(),
+ createMethodSpec(StreamType.UNARY),
+ ) {}
+ val captor = argumentCaptor()
+ verify(httpClient).unary(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service")
+ }
+
+ @Test
+ fun finalUrlAbsoluteBaseURI() {
+ whenever(codec.encodingName()).thenReturn("testing")
+ whenever(codec.serialize(any())).thenReturn(Buffer())
+ whenever(serializationStrategy.codec(any())).thenReturn(codec)
+ val client = createClient("https://connectrpc.com/api/")
+ client.unary(
+ "",
+ emptyMap(),
+ createMethodSpec(StreamType.UNARY),
+ ) {}
+ val captor = argumentCaptor()
+ verify(httpClient).unary(captor.capture(), any())
+ assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service")
+ }
+
+ private fun createClient(host: String): ProtocolClient {
+ return ProtocolClient(
+ httpClient = httpClient,
+ config = ProtocolClientConfig(
+ host = host,
+ serializationStrategy = serializationStrategy,
+ ),
+ )
+ }
+
+ private fun createMethodSpec(streamType: StreamType): MethodSpec {
+ return MethodSpec(
+ path = "com.connectrpc.SomeService/Service",
+ String::class,
+ String::class,
+ streamType,
+ )
+ }
}