diff --git a/zio-http/src/main/scala/zio/http/Middleware.scala b/zio-http/src/main/scala/zio/http/Middleware.scala index a6daeec390..71570f9e14 100644 --- a/zio-http/src/main/scala/zio/http/Middleware.scala +++ b/zio-http/src/main/scala/zio/http/Middleware.scala @@ -15,10 +15,14 @@ */ package zio.http +import java.io.File + import zio._ import zio.metrics._ import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.http.codec.{PathCodec, SegmentCodec} + trait Middleware[-UpperEnv] { self => def apply[Env1 <: UpperEnv, Err]( routes: Routes[Env1, Err], @@ -244,10 +248,84 @@ object Middleware extends HandlerAspects { } } + private sealed trait StaticServe[-R, +E] { self => + def run(path: Path, req: Request): Handler[R, E, Request, Response] + + } + + private object StaticServe { + def make[R, E](f: (Path, Request) => Handler[R, E, Request, Response]): StaticServe[R, E] = + new StaticServe[R, E] { + override def run(path: Path, request: Request) = f(path, request) + } + + def fromDirectory(docRoot: File)(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) => + val target = new File(docRoot.getAbsolutePath() + path.encode) + if (target.getCanonicalPath.startsWith(docRoot.getCanonicalPath)) Handler.fromFile(target) + else { + Handler.fromZIO( + ZIO.logWarning(s"attempt to access file outside of docRoot: ${target.getAbsolutePath}"), + ) *> Handler.badRequest + } + } + + def fromResource(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) => + Handler.fromResource(path.dropLeadingSlash.encode) + } + + } + + private def toMiddleware[E](path: Path, staticServe: StaticServe[Any, E])(implicit trace: Trace): Middleware[Any] = + new Middleware[Any] { + + private def checkFishy(acc: Boolean, segment: String): Boolean = { + val stop = segment.indexOf('/') >= 0 || segment.indexOf('\\') >= 0 || segment == ".." + acc || stop + } + + override def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = { + val mountpoint = Method.GET / path.segments.map(PathCodec.literal).reduceLeft(_ / _) + val pattern = mountpoint / trailing + val other = Routes( + pattern -> Handler + .identity[Request] + .flatMap { request => + val isFishy = request.path.segments.foldLeft(false)(checkFishy) + if (isFishy) { + Handler.fromZIO(ZIO.logWarning(s"fishy request detected: ${request.path.encode}")) *> Handler.badRequest + } else { + val segs = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v, _) => + v + } + val unnest = segs.foldLeft(Path.empty)(_ / _).addLeadingSlash + val path = request.path.unnest(unnest).addLeadingSlash + staticServe.run(path, request).sandbox + } + }, + ) + routes ++ other + } + } + + /** + * Creates a middleware for serving static files from the directory `docRoot` + * at the path `path`. + */ + def serveDirectory(path: Path, docRoot: File)(implicit trace: Trace): Middleware[Any] = + toMiddleware(path, StaticServe.fromDirectory(docRoot)) + + /** + * Creates a middleware for serving static files from resources at the path + * `path`. + */ + def serveResources(path: Path)(implicit trace: Trace): Middleware[Any] = + toMiddleware(path, StaticServe.fromResource) + /** * Creates a middleware for managing the flash scope. */ def flashScopeHandling: HandlerAspect[Any, Unit] = Middleware.intercept { (req, resp) => req.cookie("zio-http-flash").fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name))) } + }