Skip to content

Commit

Permalink
Fix URL resolution for base URI with non-empty path (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkwarren authored Dec 8, 2023
1 parent f834eed commit f95a73f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 87 deletions.
22 changes: 17 additions & 5 deletions library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import com.connectrpc.protocols.GETConfiguration
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.suspendCancellableCoroutine
import java.net.URL
import java.net.URI
import java.util.concurrent.CountDownLatch
import kotlin.coroutines.resume

Expand All @@ -48,6 +48,21 @@ class ProtocolClient(
private val config: ProtocolClientConfig,
) : ProtocolClientInterface {

private val baseURIWithTrailingSlash = if (config.baseUri.path != null && config.baseUri.path.endsWith('/')) {
config.baseUri
} else {
val path = config.baseUri.path ?: ""
URI(
config.baseUri.scheme,
config.baseUri.userInfo,
config.baseUri.host,
config.baseUri.port,
"$path/",
config.baseUri.query,
config.baseUri.fragment,
)
}

override fun <Input : Any, Output : Any> unary(
request: Input,
headers: Headers,
Expand Down Expand Up @@ -268,8 +283,5 @@ class ProtocolClient(
)
}

private fun <Input : Any, Output : Any> urlFromMethodSpec(methodSpec: MethodSpec<Input, Output>): URL {
val host = config.baseUri.resolve("/${methodSpec.path}")
return host.toURL()
}
private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = baseURIWithTrailingSlash.resolve(methodSpec.path).toURL()
}
157 changes: 75 additions & 82 deletions library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)

val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com/",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com/")
client.unary(
"input",
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.UNARY,
),
createMethodSpec(StreamType.UNARY),
) { _ -> }
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}

@Test
Expand All @@ -69,23 +61,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)

val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com")
client.unary(
"input",
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.UNARY,
),
createMethodSpec(StreamType.UNARY),
) { _ -> }
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}

@Test
Expand All @@ -94,23 +78,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)

val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com/",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com/")
CoroutineScope(Dispatchers.IO).launch {
client.stream(
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.BIDI,
),
createMethodSpec(StreamType.BIDI),
)
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).stream(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
}

Expand All @@ -120,23 +96,15 @@ class ProtocolClientTest {
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)

val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com")
CoroutineScope(Dispatchers.IO).launch {
client.stream(
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.BIDI,
),
createMethodSpec(StreamType.BIDI),
)
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).stream(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}
}

Expand All @@ -145,25 +113,12 @@ class ProtocolClientTest {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)
val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com")
client.unary(
"",
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.UNARY,
),
createMethodSpec(StreamType.UNARY),
) {}

// Use HTTP client to determine and verify the final URL.
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
Expand All @@ -174,27 +129,65 @@ class ProtocolClientTest {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)
val client = ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = "https://connectrpc.com/",
serializationStrategy = serializationStrategy,
),
)
val client = createClient("https://connectrpc.com/")
client.unary(
"",
emptyMap(),
MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType = StreamType.UNARY,
),
createMethodSpec(StreamType.UNARY),
) {}

// Use HTTP client to determine and verify the final URL.
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service")
}

@Test
fun finalUrlRelativeBaseURI() {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)
val client = createClient("https://connectrpc.com/api")
client.unary(
"",
emptyMap(),
createMethodSpec(StreamType.UNARY),
) {}
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service")
}

@Test
fun finalUrlAbsoluteBaseURI() {
whenever(codec.encodingName()).thenReturn("testing")
whenever(codec.serialize(any())).thenReturn(Buffer())
whenever(serializationStrategy.codec<String>(any())).thenReturn(codec)
val client = createClient("https://connectrpc.com/api/")
client.unary(
"",
emptyMap(),
createMethodSpec(StreamType.UNARY),
) {}
val captor = argumentCaptor<HTTPRequest>()
verify(httpClient).unary(captor.capture(), any())
assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service")
}

private fun createClient(host: String): ProtocolClient {
return ProtocolClient(
httpClient = httpClient,
config = ProtocolClientConfig(
host = host,
serializationStrategy = serializationStrategy,
),
)
}

private fun createMethodSpec(streamType: StreamType): MethodSpec<String, String> {
return MethodSpec(
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
streamType,
)
}
}

0 comments on commit f95a73f

Please sign in to comment.