diff --git a/scarlet-protocol-socketio-client/src/main/java/com/tinder/scarlet/socketio/client/SocketIoClient.kt b/scarlet-protocol-socketio-client/src/main/java/com/tinder/scarlet/socketio/client/SocketIoClient.kt index 83a0cdbf..d8be6745 100644 --- a/scarlet-protocol-socketio-client/src/main/java/com/tinder/scarlet/socketio/client/SocketIoClient.kt +++ b/scarlet-protocol-socketio-client/src/main/java/com/tinder/scarlet/socketio/client/SocketIoClient.kt @@ -13,12 +13,15 @@ import com.tinder.scarlet.socketio.SocketIoEvent import com.tinder.scarlet.utils.SimpleChannelFactory import com.tinder.scarlet.utils.SimpleProtocolOpenRequestFactory import io.socket.client.IO +import io.socket.client.Manager import io.socket.client.Socket +import io.socket.engineio.client.Transport import org.json.JSONObject class SocketIoClient( - private val url: () -> String, - private val options: IO.Options = IO.Options() + private val url: () -> String, + private val requestHeaders: () -> RequestHeaders = { RequestHeaders(mapOf()) }, + private val options: IO.Options = IO.Options() ) : Protocol { override fun createChannelFactory(): Channel.Factory { @@ -32,7 +35,7 @@ class SocketIoClient( override fun createOpenRequestFactory(channel: Channel): Protocol.OpenRequest.Factory { return SimpleProtocolOpenRequestFactory { - MainChannelOpenRequest(url()) + MainChannelOpenRequest(url(), requestHeaders()) } } @@ -40,7 +43,10 @@ class SocketIoClient( return SocketIoEvent.Adapter.Factory() } - data class MainChannelOpenRequest(val url: String) : Protocol.OpenRequest + data class MainChannelOpenRequest(val url: String, + val requestHeaders: RequestHeaders) : Protocol.OpenRequest + + data class RequestHeaders(val headers: Map) } class SocketIoEventName( @@ -71,7 +77,10 @@ internal class SocketIoMainChannel( override fun open(openRequest: Protocol.OpenRequest) { val mainChannelOpenRequest = openRequest as SocketIoClient.MainChannelOpenRequest - val socket = IO.socket(mainChannelOpenRequest.url, options) + val socket = IO.socket(mainChannelOpenRequest.url, options).apply { + addRequestHeaders(mainChannelOpenRequest.requestHeaders) + } + socket .on(Socket.EVENT_CONNECT) { listener.onOpened(this) @@ -168,4 +177,19 @@ internal class SocketIoMessageChannel( } return true } +} + +fun Socket.addRequestHeaders(requestHeaders: SocketIoClient.RequestHeaders) { + io().on(Manager.EVENT_TRANSPORT) { transportData -> + + val transport = transportData[0] as Transport + + transport.on(Transport.EVENT_REQUEST_HEADERS) { headersData -> + val headersOut = headersData[0] as MutableMap> + + requestHeaders.headers.forEach { (key, value) -> + headersOut[key] = listOf(value) + } + } + } } \ No newline at end of file