Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multipart body support in sttp stub #4117

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@ import sttp.tapir.RawBodyType
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, InputStream}
import java.nio.ByteBuffer
import scala.annotation.tailrec
import sttp.client3
import sttp.model.Part
import sttp.model.MediaType
import sttp.tapir.FileRange
import java.nio.file.Files
import java.io.FileInputStream


class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, AnyStreams] {
override val streams: AnyStreams = AnyStreams
Expand All @@ -26,7 +33,12 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
case RawBodyType.InputStreamRangeBody => ME.unit(RawValue(InputStreamRange(() => new ByteArrayInputStream(bytes))))
case _: RawBodyType.MultipartBody => ME.error(new UnsupportedOperationException)
}
case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
case Right(value) =>
bodyType match {
case mp: RawBodyType.MultipartBody =>
ME.unit(RawValue(extractMultipartParts(value.asInstanceOf[Seq[Part[client3.RequestBody[_]]]], mp)))
case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = body(serverRequest) match {
Expand All @@ -36,7 +48,6 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A

private def sttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request[_, _]]

/** Either bytes or any stream */
private def body(serverRequest: ServerRequest): Either[Array[Byte], Any] = sttpRequest(serverRequest).body match {
case NoBody => Left(Array.emptyByteArray)
case StringBody(s, encoding, _) => Left(s.getBytes(encoding))
Expand All @@ -45,8 +56,7 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
case InputStreamBody(b, _) => Left(toByteArray(b))
case FileBody(f, _) => Left(f.readAsByteArray)
case StreamBody(s) => Right(s)
case MultipartBody(_) =>
throw new IllegalArgumentException("Stub cannot handle multipart bodies")
case MultipartBody(parts) => Right(parts)
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}

private def toByteArray(is: InputStream): Array[Byte] = {
Expand All @@ -66,4 +76,55 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
transfer()
os.toByteArray
}

private def extractMultipartParts(parts: Seq[Part[client3.RequestBody[_]]], bodyType: RawBodyType.MultipartBody): List[Part[Any]] = {
parts.flatMap { part =>
bodyType.partType(part.name).flatMap { partType =>
extractPartBody(part, partType).map { body =>
Part(
name = part.name,
body = body,
contentType = part.contentType.flatMap(ct => MediaType.parse(ct).toOption),
fileName = part.fileName
)
}
}
}.toList
}

private def extractPartBody[B](part: Part[client3.RequestBody[_]], bodyType: RawBodyType[B]): Option[Any] = {
part.body match {
case ByteArrayBody(b, _) =>
bodyType match {
case RawBodyType.StringBody(charset) => Some(b)
case RawBodyType.ByteArrayBody => Some(b)
case RawBodyType.ByteBufferBody => Some(ByteBuffer.wrap(b))
case RawBodyType.InputStreamBody => Some(new ByteArrayInputStream(b))
case RawBodyType.InputStreamRangeBody => Some(InputStreamRange(() => new ByteArrayInputStream(b)))
case RawBodyType.FileBody => throw new IllegalArgumentException("ByteArray body provided while endpoint accepts FileBody")
case _: RawBodyType.MultipartBody => None
}
case FileBody(f, _) =>
bodyType match {
case RawBodyType.FileBody => Some(FileRange(f.toFile))
case RawBodyType.ByteArrayBody => Some(Files.readAllBytes(f.toPath))
case RawBodyType.ByteBufferBody => Some(ByteBuffer.wrap(Files.readAllBytes(f.toPath)))
case RawBodyType.InputStreamBody => Some(new FileInputStream(f.toFile))
case _ => None
}
case StringBody(s, charset, _) =>
bodyType match {
case RawBodyType.StringBody(_) => Some(s)
case RawBodyType.ByteArrayBody => Some(s.getBytes(charset))
case RawBodyType.ByteBufferBody => Some(ByteBuffer.wrap(s.getBytes(charset)))
case _ => None
}
case InputStreamBody(is, _) =>
bodyType match {
case RawBodyType.InputStreamBody => Some(is)
case _ => None
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}
case _ => None
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import sttp.tapir.server.interceptor.exception.ExceptionHandler
import sttp.tapir.server.interceptor.reject.RejectHandler
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor}
import sttp.tapir.server.model.ValuedEndpointOutput
import sttp.tapir.generic.auto._
import sttp.tapir.tests.TestUtil.{readFromFile, writeToFile}
import sttp.model.Part
import sttp.tapir.TapirFile
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt

class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {

Expand Down Expand Up @@ -204,8 +210,136 @@ class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {
response.body shouldBe Left("Internal server error")
response.code shouldBe StatusCode.InternalServerError
}

it should "handle multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenEndpoint(e)
.thenRespond("success")
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("success")
}

it should "correctly process a multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub.synchronous)
.whenServerEndpointRunLogic(e.serverLogic((multipartData) => {
val partOpt = multipartData.find(_.name == "name")
val fileOpt = multipartData.find(_.name == "file")

(partOpt, fileOpt) match {
case (Some(part), Some(filePart)) =>
val partData = new String(part.body)
val fileData = new String(filePart.body)
IdMonad.unit(Right("name: " + partData + " file: " + fileData))

case (Some(_), None) =>
IdMonad.unit(Right("File part not found"))

case (None, Some(_)) =>
IdMonad.unit(Right("Part not found"))

case (None, None) =>
IdMonad.unit(Right("Both parts not found"))
}
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("name: abc file: file_content")
}

it should "correctly handle derived multipart body" in {
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody[MultipartData])
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenServerEndpointRunLogic(e.serverLogic(multipartData => {
val fileContent = Await.result(readFromFile(multipartData.file.body), 3.seconds)
IdMonad.unit(Right("name: " + multipartData.name + " year: " + multipartData.year + " file: " + fileContent))
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipart("year", "2024"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("name: abc year: 2024 file: file_content")
}

it should "throw exception when bytearray body provided while endpoint accepts fileBody" in {
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody[MultipartData])
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenEndpoint(e)
.thenRespond("success")
.backend()

// when
val response = the[IllegalArgumentException] thrownBy sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipart("year", "2024"),
multipart("file", "file_content".getBytes())
)
.send(server)

// then
response.getMessage shouldBe "ByteArray body provided while endpoint accepts FileBody"
}
}

case class MultipartData(name: String, year: Int, file: Part[TapirFile])

object ProductsApi {

val getProduct: Endpoint[Unit, Unit, String, String, Any] = endpoint.get
Expand Down
Loading