diff --git a/changelog.d/pr-1800 b/changelog.d/pr-1800 new file mode 100644 index 000000000..06f346c06 --- /dev/null +++ b/changelog.d/pr-1800 @@ -0,0 +1,7 @@ +synopsis: Add Host API combinator +packages: servant servant-client-core servant-client servant-server +prs: #1800 +description: { + Adding a Host combinator allows servant users to select APIs according + to the Host header provided by clients. +} diff --git a/servant-client-core/src/Servant/Client/Core/HasClient.hs b/servant-client-core/src/Servant/Client/Core/HasClient.hs index 69b97541c..b3b40e40a 100644 --- a/servant-client-core/src/Servant/Client/Core/HasClient.hs +++ b/servant-client-core/src/Servant/Client/Core/HasClient.hs @@ -66,7 +66,7 @@ import Servant.API ReflectMethod (..), StreamBody', Verb, - getResponse, AuthProtect, BasicAuth, BasicAuthData, Capture', CaptureAll, DeepQuery, Description, Fragment, FramingRender (..), FramingUnrender (..), Header', Headers (..), HttpVersion, MimeRender (mimeRender), NoContent (NoContent), QueryFlag, QueryParam', QueryParams, QueryString, Raw, RawM, RemoteHost, ReqBody', SBoolI, Stream, Summary, ToHttpApiData, ToSourceIO (..), Vault, WithNamedContext, WithResource, WithStatus (..), contentType, getHeadersHList, toEncodedUrlPiece, NamedRoutes) + getResponse, AuthProtect, BasicAuth, BasicAuthData, Capture', CaptureAll, DeepQuery, Description, Fragment, FramingRender (..), FramingUnrender (..), Header', Headers (..), HttpVersion, MimeRender (mimeRender), NoContent (NoContent), QueryFlag, QueryParam', QueryParams, QueryString, Raw, RawM, RemoteHost, ReqBody', SBoolI, Stream, Summary, ToHttpApiData, ToSourceIO (..), Vault, WithNamedContext, WithResource, WithStatus (..), contentType, getHeadersHList, toEncodedUrlPiece, NamedRoutes, Host) import Servant.API.Generic (GenericMode(..), ToServant, ToServantApi , GenericServant, toServant, fromServant) @@ -494,6 +494,15 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire hoistClientMonad pm _ f cl = \arg -> hoistClientMonad pm (Proxy :: Proxy api) f (cl arg) +instance (KnownSymbol sym, HasClient m api) => HasClient m (Host sym :> api) where + type Client m (Host sym :> api) = Client m api + + clientWithRoute pm Proxy req = + clientWithRoute pm (Proxy :: Proxy api) $ + addHeader "Host" (symbolVal (Proxy :: Proxy sym)) req + + hoistClientMonad pm _ = hoistClientMonad pm (Proxy :: Proxy api) + -- | Using a 'HttpVersion' combinator in your API doesn't affect the client -- functions. instance HasClient m api diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index dbde8c193..1639527f9 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -68,7 +68,7 @@ import Servant.API JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender), NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam, QueryParams, QueryString, Raw, ReqBody, StdMethod (GET), ToHttpApiData (..), - UVerb, Union, Verb, WithStatus (WithStatus), NamedRoutes, addHeader) + UVerb, Union, Verb, WithStatus (WithStatus), NamedRoutes, addHeader, Host) import Servant.API.Generic ((:-)) import Servant.API.QueryString (FromDeepQuery(..), ToDeepQuery(..)) import Servant.Client @@ -221,6 +221,7 @@ type Api = :<|> NamedRoutes RecordRoutes :<|> "multiple-choices-int" :> MultipleChoicesInt :<|> "captureVerbatim" :> Capture "someString" Verbatim :> Get '[PlainText] Text + :<|> "host-test" :> Host "servant.example" :> Get '[JSON] Bool api :: Proxy Api api = Proxy @@ -256,6 +257,7 @@ uverbGetCreated :: ClientM (Union '[WithStatus 201 Person]) recordRoutes :: RecordRoutes (AsClientT ClientM) multiChoicesInt :: Int -> ClientM MultipleChoicesIntResult captureVerbatim :: Verbatim -> ClientM Text +getHost :: ClientM Bool getRoot :<|> getGet @@ -285,7 +287,8 @@ getRoot :<|> uverbGetCreated :<|> recordRoutes :<|> multiChoicesInt - :<|> captureVerbatim = client api + :<|> captureVerbatim + :<|> getHost = client api server :: Application server = serve api ( @@ -349,6 +352,7 @@ server = serve api ( ) :<|> pure . decodeUtf8 . unVerbatim + :<|> pure True ) -- * api for testing failures diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index a8e0e5834..48dfcd9fe 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -16,7 +16,7 @@ module Servant.Server.Internal ) where import Control.Monad - (join, when) + (join, when, unless) import Control.Monad.Trans (liftIO, lift) import Control.Monad.Trans.Resource @@ -48,13 +48,13 @@ import Network.Socket (SockAddr) import Network.Wai (Application, Request, Response, ResponseReceived, httpVersion, isSecure, lazyRequestBody, - queryString, remoteHost, getRequestBodyChunk, requestHeaders, + queryString, remoteHost, getRequestBodyChunk, requestHeaders, requestHeaderHost, requestMethod, responseLBS, responseStream, vault) import Servant.API ((:<|>) (..), (:>), Accept (..), BasicAuth, Capture', CaptureAll, DeepQuery, Description, EmptyAPI, Fragment, FramingRender (..), FramingUnrender (..), FromSourceIO (..), - Header', If, IsSecure (..), NoContentVerb, QueryFlag, + Host, Header', If, IsSecure (..), NoContentVerb, QueryFlag, QueryParam', QueryParams, QueryString, Raw, RawM, ReflectMethod (reflectMethod), RemoteHost, ReqBody', SBool (..), SBoolI (..), SourceIO, Stream, StreamBody', Summary, ToSourceIO (..), Vault, Verb, @@ -461,6 +461,30 @@ instance <> headerName <> " failed: " <> e +instance + ( KnownSymbol sym + , HasServer api context + , HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters + ) => HasServer (Host sym :> api) context where + type ServerT (Host sym :> api) m = ServerT api m + + hoistServerWithContext _ = hoistServerWithContext (Proxy :: Proxy api) + + route _ context (Delayed {..}) = route (Proxy :: Proxy api) context $ + let formatError = + headerParseErrorFormatter $ getContextEntry $ mkContextWithErrorFormatter context + rep = typeRep (Proxy :: Proxy Host) + targetHost = symbolVal (Proxy :: Proxy sym) + hostCheck :: DelayedIO () + hostCheck = withRequest $ \req -> + case requestHeaderHost req of + Just hostBytes -> + let host = BC8.unpack hostBytes + in unless (host == targetHost) $ + delayedFail $ formatError rep req $ "Invalid host: " ++ host + _ -> delayedFail $ formatError rep req "Host header missing" + in Delayed { headersD = headersD <* hostCheck, .. } + -- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function -- that takes an argument of type @'Maybe' 'Text'@. diff --git a/servant/servant.cabal b/servant/servant.cabal index 477957212..e9c545032 100644 --- a/servant/servant.cabal +++ b/servant/servant.cabal @@ -89,6 +89,7 @@ library Servant.API.Fragment Servant.API.Generic Servant.API.Header + Servant.API.Host Servant.API.HttpVersion Servant.API.IsSecure Servant.API.Modifiers diff --git a/servant/src/Servant/API.hs b/servant/src/Servant/API.hs index 47fa1fadf..de6476b7f 100644 --- a/servant/src/Servant/API.hs +++ b/servant/src/Servant/API.hs @@ -14,6 +14,8 @@ module Servant.API ( module Servant.API.Capture, -- | Capturing parts of the url path as parsed values: @'Capture'@ and @'CaptureAll'@ module Servant.API.Header, + -- | Matching the @Host@ header. + module Servant.API.Host, -- | Retrieving specific headers from the request module Servant.API.HttpVersion, -- | Retrieving the HTTP version of the request @@ -110,6 +112,7 @@ import Servant.API.Generic ToServant, ToServantApi, fromServant, genericApi, toServant) import Servant.API.Header (Header, Header') +import Servant.API.Host (Host) import Servant.API.HttpVersion (HttpVersion (..)) import Servant.API.IsSecure diff --git a/servant/src/Servant/API/Host.hs b/servant/src/Servant/API/Host.hs new file mode 100644 index 000000000..029d528a8 --- /dev/null +++ b/servant/src/Servant/API/Host.hs @@ -0,0 +1,13 @@ +module Servant.API.Host (Host) where + +import Data.Typeable (Typeable) +import GHC.TypeLits (Symbol) + +-- | Match against the given host. +-- +-- This allows you to define APIs over multiple domains. For example: +-- +-- > type API = Host "api1.example" :> API1 +-- > :<|> Host "api2.example" :> API2 +-- +data Host (sym :: Symbol) deriving Typeable