Skip to content

Commit

Permalink
Redirect
Browse files Browse the repository at this point in the history
  • Loading branch information
spenes committed Aug 17, 2023
1 parent 96b0761 commit 84736c6
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.http4s.dsl.Http4sDsl
import org.http4s.implicits._
import com.comcast.ip4s.Dns

class Routes[F[_]: Sync](service: IService[F]) extends Http4sDsl[F] {
class Routes[F[_]: Sync](enableDefaultRedirect: Boolean, service: IService[F]) extends Http4sDsl[F] {

implicit val dns: Dns[F] = Dns.forSync[F]

Expand Down Expand Up @@ -55,5 +55,14 @@ class Routes[F[_]: Sync](service: IService[F]) extends Http4sDsl[F] {
)
}

val value: HttpApp[F] = (healthRoutes <+> corsRoute <+> cookieRoutes).orNotFound
def rejectRedirect = HttpRoutes.of[F] {
case _ -> Root / "r" / _ =>
NotFound("redirects disabled")
}

val value: HttpApp[F] = {
val routes = healthRoutes <+> corsRoute <+> cookieRoutes
val res = if (enableDefaultRedirect) routes else rejectRedirect <+> routes
res.orNotFound
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object Run {
appInfo
)
httpServer = HttpServer.build[F](
new Routes[F](collectorService).value,
new Routes[F](config.enableDefaultRedirect, collectorService).value,
config.interface,
config.port
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Service[F[_]: Sync](
): F[Response[F]] =
for {
body <- body
redirect = path.startsWith("/r/")
hostname = extractHostname(request)
userAgent = extractHeader(request, "User-Agent")
refererUri = extractHeader(request, "Referer")
Expand Down Expand Up @@ -104,7 +105,12 @@ class Service[F[_]: Sync](
).flatten
responseHeaders = Headers(headerList)
_ <- sinkEvent(event, partitionKey)
resp = buildHttpResponse(responseHeaders, pixelExpected)
resp = buildHttpResponse(
queryParams = request.uri.query.params,
headers = responseHeaders,
redirect = redirect,
pixelExpected = pixelExpected
)
} yield resp

override def determinePath(vendor: String, version: String): String = {
Expand Down Expand Up @@ -170,11 +176,19 @@ class Service[F[_]: Sync](
e
}

// TODO: Handle necessary cases to build http response in here
def buildHttpResponse(
queryParams: Map[String, String],
headers: Headers,
redirect: Boolean,
pixelExpected: Boolean
): Response[F] =
if (redirect)
buildRedirectHttpResponse(queryParams, headers)
else
buildUsualHttpResponse(pixelExpected, headers)

/** Builds the appropriate http response when not dealing with click redirects. */
def buildUsualHttpResponse(pixelExpected: Boolean, headers: Headers): Response[F] =
pixelExpected match {
case true =>
Response[F](
Expand All @@ -190,6 +204,32 @@ class Service[F[_]: Sync](
)
}

/** Builds the appropriate http response when dealing with click redirects. */
def buildRedirectHttpResponse(queryParams: Map[String, String], headers: Headers): Response[F] = {
val targetUri = for {
target <- queryParams.get("u")
uri <- Uri.fromString(target).toOption
if redirectTargetAllowed(uri)
} yield uri

targetUri match {
case Some(t) =>
Response[F](
status = Found,
headers = headers.put(Location(t))
)
case _ =>
Response[F](
status = BadRequest,
headers = headers
)
}
}

private def redirectTargetAllowed(target: Uri): Boolean =
if (config.redirectDomains.isEmpty) true
else config.redirectDomains.contains(target.host.map(_.renderString).getOrElse(""))

// TODO: Since Remote-Address and Raw-Request-URI is akka-specific headers,
// they aren't included in here. It might be good to search for counterparts in Http4s.
/** If the SP-Anonymous header is not present, retrieves all headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ class RoutesSpec extends Specification {
override def determinePath(vendor: String, version: String): String = "/p1/p2"
}

def createTestServices = {
def createTestServices(enabledDefaultRedirect: Boolean = true) = {
val service = new TestService()
val routes = new Routes(service).value
val routes = new Routes(enabledDefaultRedirect, service).value
(service, routes)
}

"The collector route" should {
"respond to the health route with an ok response" in {
val (_, routes) = createTestServices
val (_, routes) = createTestServices()
val request = Request[IO](method = Method.GET, uri = uri"/health")
val response = routes.run(request).unsafeRunSync()

Expand All @@ -74,7 +74,7 @@ class RoutesSpec extends Specification {
}

"respond to the cors route with a preflight response" in {
val (_, routes) = createTestServices
val (_, routes) = createTestServices()
def test(uri: Uri) = {
val request = Request[IO](method = Method.OPTIONS, uri = uri)
val response = routes.run(request).unsafeRunSync()
Expand All @@ -86,7 +86,7 @@ class RoutesSpec extends Specification {
}

"respond to the post cookie route with the cookie response" in {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = Method.POST, uri = uri"/p3/p4")
.withEntity("testBody")
Expand All @@ -106,7 +106,7 @@ class RoutesSpec extends Specification {

"respond to the get or head cookie route with the cookie response" in {
def test(method: Method) = {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = method, uri = uri"/p3/p4").withEntity("testBody")
val response = routes.run(request).unsafeRunSync()
Expand All @@ -128,7 +128,7 @@ class RoutesSpec extends Specification {

"respond to the get or head pixel route with the cookie response" in {
def test(method: Method, uri: String) = {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = method, uri = Uri.unsafeFromString(uri)).withEntity("testBody")
val response = routes.run(request).unsafeRunSync()
Expand All @@ -149,6 +149,36 @@ class RoutesSpec extends Specification {
test(Method.GET, "/ice.png")
test(Method.HEAD, "/ice.png")
}

"allow redirect routes when redirects enabled" in {
def test(method: Method) = {
val (_, routes) = createTestServices()

val request = Request[IO](method = method, uri = uri"/r/abc")
val response = routes.run(request).unsafeRunSync()

response.status must beEqualTo(Status.Ok)
response.bodyText.compile.string.unsafeRunSync() must beEqualTo("cookie")
}

test(Method.GET)
test(Method.POST)
}

"disallow redirect routes when redirects disabled" in {
def test(method: Method) = {
val (_, routes) = createTestServices(enabledDefaultRedirect = false)

val request = Request[IO](method = method, uri = uri"/r/abc")
val response = routes.run(request).unsafeRunSync()

response.status must beEqualTo(Status.NotFound)
response.bodyText.compile.string.unsafeRunSync() must beEqualTo("redirects disabled")
}

test(Method.GET)
test(Method.POST)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class ServiceSpec extends Specification {
secure = false
)

def probeService(): ProbeService = {
def probeService(config: Config[Any] = TestUtils.testConfig): ProbeService = {
val good = new TestSink
val bad = new TestSink
val service = new Service(
config = TestUtils.testConfig,
config = config,
sinks = Sinks(good, bad),
appInfo = TestUtils.appInfo
)
Expand Down Expand Up @@ -365,6 +365,35 @@ class ServiceSpec extends Specification {
Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
)
}

"redirect if path starts with '/r/'" in {
val testConf = TestUtils
.testConfig
.copy(
redirectDomains = Set("snowplow.acme.com", "example.com")
)
val testPath = "/r/example?u=https://snowplow.acme.com/12"
val ProbeService(service, good, bad) = probeService(config = testConf)
val req = Request[IO](
method = Method.GET,
uri = Uri.unsafeFromString(testPath)
)
val r = service
.cookie(
body = IO.pure(Some("b")),
path = testPath,
request = req,
pixelExpected = false,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()

r.status mustEqual Status.Found
r.headers.get[Location] must beSome(Location(Uri.unsafeFromString("https://snowplow.acme.com/12")))
good.storedRawEvents must have size 1
bad.storedRawEvents must have size 0
}
}

"preflightResponse" in {
Expand Down Expand Up @@ -451,20 +480,126 @@ class ServiceSpec extends Specification {
}

"buildHttpResponse" in {
"rely on buildRedirectHttpResponse if redirect is true" in {
val testConfig = TestUtils
.testConfig
.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildHttpResponse(
queryParams = Map("u" -> "https://example1.com/12"),
headers = testHeaders,
redirect = true,
pixelExpected = true
)
res.status shouldEqual Status.Found
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://example1.com/12")))
}
"send back a gif if pixelExpected is true" in {
val res = service.buildHttpResponse(
queryParams = Map.empty,
headers = testHeaders,
redirect = false,
pixelExpected = true
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders.put(`Content-Type`(MediaType.image.gif))
res.body.compile.toList.unsafeRunSync().toArray shouldEqual Service.pixel
}
"send back ok otherwise" in {
val res = service.buildHttpResponse(
queryParams = Map.empty,
headers = testHeaders,
redirect = false,
pixelExpected = false
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders
res.bodyText.compile.toList.unsafeRunSync() shouldEqual List("ok")
}
}

"buildUsualHttpResponse" in {
"send back a gif if pixelExpected is true" in {
val res = service.buildHttpResponse(testHeaders, pixelExpected = true)
val res = service.buildUsualHttpResponse(
headers = testHeaders,
pixelExpected = true
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders.put(`Content-Type`(MediaType.image.gif))
res.body.compile.toList.unsafeRunSync().toArray shouldEqual Service.pixel
}
"send back ok otherwise" in {
val res = service.buildHttpResponse(testHeaders, pixelExpected = false)
val res = service.buildUsualHttpResponse(
headers = testHeaders,
pixelExpected = false
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders
res.bodyText.compile.toList.unsafeRunSync() shouldEqual List("ok")
}
}

"buildRedirectHttpResponse" in {
"give back a 302 if redirecting and there is a u query param" in {
val testConfig = TestUtils
.testConfig
.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://example1.com/12"),
headers = testHeaders
)
res.status shouldEqual Status.Found
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://example1.com/12")))
}
"give back a 400 if redirecting and there are no u query params" in {
val testConfig = TestUtils
.testConfig
.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map.empty,
headers = testHeaders
)
res.status shouldEqual Status.BadRequest
res.headers shouldEqual testHeaders
}
"give back a 400 if redirecting to a disallowed domain" in {
val testConfig = TestUtils
.testConfig
.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://invalidexample1.com/12"),
headers = testHeaders
)
res.status shouldEqual Status.BadRequest
res.headers shouldEqual testHeaders
}
"give back a 302 if redirecting to an unknown domain, with no restrictions on domains" in {
val testConfig = TestUtils
.testConfig
.copy(
redirectDomains = Set.empty
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://unknown.example.com/12"),
headers = testHeaders
)
res.status shouldEqual Status.Found
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://unknown.example.com/12")))
}
}

"ipAndPartitionkey" in {
"give back the ip and partition key as ip if remote address is defined" in {
val address = Some("127.0.0.1")
Expand Down

0 comments on commit 84736c6

Please sign in to comment.