Skip to content

Commit

Permalink
cleanly separate HTTPRequest (base type, applies to streams and unary…
Browse files Browse the repository at this point in the history
…) and UnaryHTTPRequest (only for unary RPCs, includes request message contents)
  • Loading branch information
jhump committed Jan 30, 2024
1 parent d0b8507 commit 77ae245
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 133 deletions.
3 changes: 2 additions & 1 deletion library/src/main/kotlin/com/connectrpc/Interceptor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package com.connectrpc

import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.UnaryHTTPRequest
import okio.Buffer

/**
Expand Down Expand Up @@ -52,7 +53,7 @@ interface Interceptor {
}

class UnaryFunction(
val requestFunction: (HTTPRequest) -> HTTPRequest = { it },
val requestFunction: (UnaryHTTPRequest) -> UnaryHTTPRequest = { it },
val responseFunction: (HTTPResponse) -> HTTPResponse = { it },
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ interface HTTPClientInterface {
*
* @return A function to cancel the underlying network call.
*/
fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable
fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable

/**
* Initialize a new HTTP stream.
Expand Down
108 changes: 74 additions & 34 deletions library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package com.connectrpc.http

import com.connectrpc.Headers
import com.connectrpc.MethodSpec
import okio.Buffer
import java.net.URL

internal object HTTPMethod {
Expand All @@ -24,50 +25,89 @@ internal object HTTPMethod {
}

/**
* HTTP request used for sending primitive data to the server.
* HTTP request used to initiate RPCs.
*/
class HTTPRequest internal constructor(
open class HTTPRequest internal constructor(
// The URL for the request.
val url: URL,
// Value to assign to the `content-type` header.
val contentType: String,
// Additional outbound headers for the request.
val headers: Headers,
// Body data to send with the request.
val message: ByteArray? = null,
// The method spec associated with the request.
val methodSpec: MethodSpec<*, *>,
// HTTP method to use with the request.
// Almost always POST, but side effect free unary RPCs may be made with GET.
val httpMethod: String = HTTPMethod.POST,
) {
/**
* Clones the [HTTPRequest] with override values.
*
* Intended to make mutations for [HTTPRequest] safe for
* [com.connectrpc.Interceptor] implementation.
*/
fun clone(
// The URL for the request.
url: URL = this.url,
// Value to assign to the `content-type` header.
contentType: String = this.contentType,
// Additional outbound headers for the request.
headers: Headers = this.headers,
// Body data to send with the request.
message: ByteArray? = this.message,
// The method spec associated with the request.
methodSpec: MethodSpec<*, *> = this.methodSpec,
// The HTTP method to use with the request.
httpMethod: String = this.httpMethod,
): HTTPRequest {
return HTTPRequest(
url,
contentType,
headers,
message,
methodSpec,
httpMethod,
)
}
)

/**
* Clones the [HTTPRequest] with override values.
*
* Intended to make mutations for [HTTPRequest] safe for
* [com.connectrpc.Interceptor] implementation.
*/
fun HTTPRequest.clone(
// The URL for the request.
url: URL = this.url,
// Value to assign to the `content-type` header.
contentType: String = this.contentType,
// Additional outbound headers for the request.
headers: Headers = this.headers,
// The method spec associated with the request.
methodSpec: MethodSpec<*, *> = this.methodSpec,
// The HTTP method to use with the request.
httpMethod: String = this.httpMethod,
): HTTPRequest {
return HTTPRequest(
url,
contentType,
headers,
methodSpec,
httpMethod,
)
}

/**
* HTTP request used to initiate unary RPCs. In addition
* to RPC metadata, this also includes the request data.
*/
class UnaryHTTPRequest(
// The URL for the request.
url: URL,
// Value to assign to the `content-type` header.
contentType: String,
// Additional outbound headers for the request.
headers: Headers,
// The method spec associated with the request.
methodSpec: MethodSpec<*, *>,
// Body data for the request.
val message: Buffer,
// HTTP method to use with the request.
// Almost always POST, but side effect free unary RPCs may be made with GET.
httpMethod: String = HTTPMethod.POST,
) : HTTPRequest(url, contentType, headers, methodSpec, httpMethod)

fun UnaryHTTPRequest.clone(
// The URL for the request.
url: URL = this.url,
// Value to assign to the `content-type` header.
contentType: String = this.contentType,
// Additional outbound headers for the request.
headers: Headers = this.headers,
// The method spec associated with the request.
methodSpec: MethodSpec<*, *> = this.methodSpec,
// Body data for the request.
message: Buffer = this.message,
// The HTTP method to use with the request.
httpMethod: String = this.httpMethod,
): UnaryHTTPRequest {
return UnaryHTTPRequest(
url,
contentType,
headers,
methodSpec,
message,
httpMethod,
)
}
5 changes: 3 additions & 2 deletions library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.connectrpc.UnaryBlockingCall
import com.connectrpc.http.Cancelable
import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.UnaryHTTPRequest
import com.connectrpc.http.transform
import com.connectrpc.protocols.GETConfiguration
import kotlinx.coroutines.CompletableDeferred
Expand Down Expand Up @@ -79,12 +80,12 @@ class ProtocolClient(
} else {
requestCodec.serialize(request)
}
val unaryRequest = HTTPRequest(
val unaryRequest = UnaryHTTPRequest(
url = urlFromMethodSpec(methodSpec),
contentType = "application/${requestCodec.encodingName()}",
headers = headers,
message = requestMessage.readByteArray(),
methodSpec = methodSpec,
message = requestMessage,
)
val unaryFunc = config.createInterceptorChain()
val finalRequest = unaryFunc.requestFunction(unaryRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import com.connectrpc.compression.CompressionPool
import com.connectrpc.http.HTTPMethod
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.UnaryHTTPRequest
import com.connectrpc.http.clone
import com.connectrpc.toLowercase
import com.squareup.moshi.Moshi
import okio.Buffer
Expand Down Expand Up @@ -67,15 +69,11 @@ internal class ConnectInterceptor(
requestHeaders[USER_AGENT] = listOf("connect-kotlin/${ConnectConstants.VERSION}")
}
val requestCompression = clientConfig.requestCompression
val requestMessage = Buffer()
if (request.message != null) {
requestMessage.write(request.message)
}
val finalRequestBody = if (requestCompression?.shouldCompress(requestMessage) == true) {
val finalRequestBody = if (requestCompression?.shouldCompress(request.message) == true) {
requestHeaders.put(CONTENT_ENCODING, listOf(requestCompression.compressionPool.name()))
requestCompression.compressionPool.compress(requestMessage)
requestCompression.compressionPool.compress(request.message)
} else {
requestMessage
request.message
}
if (shouldUseGETRequest(request, finalRequestBody)) {
constructGETRequest(request, finalRequestBody, requestCompression)
Expand All @@ -84,8 +82,8 @@ internal class ConnectInterceptor(
url = request.url,
contentType = request.contentType,
headers = requestHeaders,
message = finalRequestBody.readByteArray(),
methodSpec = request.methodSpec,
message = finalRequestBody,
)
}
},
Expand Down Expand Up @@ -153,7 +151,6 @@ internal class ConnectInterceptor(
url = request.url,
contentType = request.contentType,
headers = requestHeaders,
message = request.message,
methodSpec = request.methodSpec,
)
},
Expand Down Expand Up @@ -196,10 +193,10 @@ internal class ConnectInterceptor(
}

private fun constructGETRequest(
request: HTTPRequest,
request: UnaryHTTPRequest,
finalRequestBody: Buffer,
requestCompression: RequestCompression?,
): HTTPRequest {
): UnaryHTTPRequest {
val serializationStrategy = clientConfig.serializationStrategy
val requestCodec = serializationStrategy.codec(request.methodSpec.requestClass)
val url = getUrlFromMethodSpec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.connectrpc.StreamResult
import com.connectrpc.UnaryFunction
import com.connectrpc.compression.CompressionPool
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.clone
import okio.Buffer

/**
Expand All @@ -46,16 +47,10 @@ internal class GRPCInterceptor(
requestHeaders[GRPC_ACCEPT_ENCODING] = clientConfig.compressionPools()
.map { compressionPool -> compressionPool.name() }
}
val requestMessage = Buffer().use { buffer ->
if (request.message != null) {
buffer.write(request.message)
}
buffer
}
val requestCompression = clientConfig.requestCompression
// GRPC unary payloads are enveloped.
val envelopedMessage = Envelope.pack(
requestMessage,
request.message,
requestCompression?.compressionPool,
requestCompression?.minBytes,
)
Expand All @@ -64,7 +59,7 @@ internal class GRPCInterceptor(
// The underlying content type is overridden here.
contentType = "application/grpc+${serializationStrategy.serializationName()}",
headers = requestHeaders.withGRPCRequestHeaders(),
message = envelopedMessage.readByteArray(),
message = envelopedMessage,
)
},
responseFunction = { response ->
Expand Down Expand Up @@ -128,7 +123,6 @@ internal class GRPCInterceptor(
url = request.url,
contentType = "application/grpc+${serializationStrategy.serializationName()}",
headers = request.headers.withGRPCRequestHeaders(),
message = request.message,
)
},
requestBodyFunction = { buffer ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.connectrpc.Trailers
import com.connectrpc.UnaryFunction
import com.connectrpc.compression.CompressionPool
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.clone
import okio.Buffer

internal const val TRAILERS_BIT = 0b10000000
Expand All @@ -50,15 +51,9 @@ internal class GRPCWebInterceptor(
.map { compressionPool -> compressionPool.name() }
}
val requestCompressionPool = clientConfig.requestCompression
val requestMessage = Buffer().use { buffer ->
if (request.message != null) {
buffer.write(request.message)
}
buffer
}
// GRPC unary payloads are enveloped.
val envelopedMessage = Envelope.pack(
requestMessage,
request.message,
requestCompressionPool?.compressionPool,
requestCompressionPool?.minBytes,
)
Expand All @@ -68,7 +63,7 @@ internal class GRPCWebInterceptor(
// The underlying content type is overridden here.
contentType = "application/grpc-web+${serializationStrategy.serializationName()}",
headers = requestHeaders.withGRPCRequestHeaders(),
message = envelopedMessage.readByteArray(),
message = envelopedMessage,
)
},
responseFunction = { response ->
Expand Down Expand Up @@ -175,7 +170,6 @@ internal class GRPCWebInterceptor(
url = request.url,
contentType = "application/grpc-web+${serializationStrategy.serializationName()}",
headers = request.headers.withGRPCRequestHeaders(),
message = request.message,
)
},
requestBodyFunction = { buffer ->
Expand Down
22 changes: 6 additions & 16 deletions library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package com.connectrpc
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.TracingInfo
import com.connectrpc.http.UnaryHTTPRequest
import com.connectrpc.http.clone
import com.connectrpc.protocols.Envelope
import com.connectrpc.protocols.NetworkProtocol
import okio.Buffer
Expand Down Expand Up @@ -72,7 +74,7 @@ class InterceptorChainTest {

@Test
fun fifo_request_unary() {
val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC))
val response = unaryChain.requestFunction(UnaryHTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), UNARY_METHOD_SPEC, Buffer()))
assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4")
}

Expand All @@ -84,7 +86,7 @@ class InterceptorChainTest {

@Test
fun fifo_request_stream() {
val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC))
val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), STREAM_METHOD_SPEC))
assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4")
}

Expand Down Expand Up @@ -115,13 +117,7 @@ class InterceptorChainTest {
val sequence = headers.get("id")?.toMutableList() ?: mutableListOf()
sequence.add(id)
headers.put("id", sequence)
HTTPRequest(
it.url,
it.contentType,
headers,
it.message,
UNARY_METHOD_SPEC,
)
it.clone(headers = headers)
},
responseFunction = {
val headers = it.headers.toMutableMap()
Expand All @@ -147,13 +143,7 @@ class InterceptorChainTest {
val sequence = headers.get("id")?.toMutableList() ?: mutableListOf()
sequence.add(id)
headers.put("id", sequence)
HTTPRequest(
it.url,
it.contentType,
headers,
it.message,
STREAM_METHOD_SPEC,
)
it.clone(headers = headers)
},
requestBodyFunction = {
it.writeString(id, Charsets.UTF_8)
Expand Down
Loading

0 comments on commit 77ae245

Please sign in to comment.