Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multipart body support in sttp stub #4117

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer
import scala.annotation.tailrec
import sttp.client3
import sttp.model.Part
import sttp.model.MediaType
import sttp.tapir.FileRange
import java.nio.file.Files
import java.io.FileInputStream

class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, AnyStreams] {
override val streams: AnyStreams = AnyStreams
Expand All @@ -26,27 +32,32 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
case RawBodyType.InputStreamRangeBody => ME.unit(RawValue(InputStreamRange(() => new ByteArrayInputStream(bytes))))
case _: RawBodyType.MultipartBody => ME.error(new UnsupportedOperationException)
}
case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
case Right(parts) =>
bodyType match {
case mp: RawBodyType.MultipartBody => ME.unit(RawValue(extractMultipartParts(parts, mp)))
case _ => throw new IllegalArgumentException(s"Multipart body provided while endpoint accepts raw body type: ${bodyType}")
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = body(serverRequest) match {
case Right(stream) => stream
case _ => throw new IllegalArgumentException("Raw body provided while endpoint accepts stream body")
}
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream =
sttpRequest(serverRequest).body match {
case StreamBody(s) => s
case _ => throw new IllegalArgumentException("Raw body provided while endpoint accepts stream body")
}

private def sttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request[_, _]]

/** Either bytes or any stream */
private def body(serverRequest: ServerRequest): Either[Array[Byte], Any] = sttpRequest(serverRequest).body match {
private def body(serverRequest: ServerRequest): Either[Array[Byte], Seq[Part[client3.RequestBody[_]]]] = sttpRequest(
serverRequest
).body match {
case NoBody => Left(Array.emptyByteArray)
case StringBody(s, encoding, _) => Left(s.getBytes(encoding))
case ByteArrayBody(b, _) => Left(b)
case ByteBufferBody(b, _) => Left(b.array())
case InputStreamBody(b, _) => Left(toByteArray(b))
case FileBody(f, _) => Left(f.readAsByteArray)
case StreamBody(s) => Right(s)
case MultipartBody(_) =>
throw new IllegalArgumentException("Stub cannot handle multipart bodies")
case StreamBody(_) => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
case MultipartBody(parts) => Right(parts)
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}

private def toByteArray(is: InputStream): Array[Byte] = {
Expand All @@ -66,4 +77,56 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
transfer()
os.toByteArray
}

private def extractMultipartParts(parts: Seq[Part[client3.RequestBody[_]]], bodyType: RawBodyType.MultipartBody): List[Part[Any]] = {
parts.flatMap { part =>
bodyType.partType(part.name).flatMap { partType =>
val body = extractPartBody(part, partType)
Some(
Part(
name = part.name,
body = body,
contentType = part.contentType.flatMap(ct => MediaType.parse(ct).toOption),
fileName = part.fileName
)
)
}
}.toList
}

private def extractPartBody[B](part: Part[client3.RequestBody[_]], bodyType: RawBodyType[B]): Any = {
part.body match {
case ByteArrayBody(b, _) =>
bodyType match {
case RawBodyType.StringBody(_) => b
case RawBodyType.ByteArrayBody => b
case RawBodyType.ByteBufferBody => ByteBuffer.wrap(b)
case RawBodyType.InputStreamBody => new ByteArrayInputStream(b)
case RawBodyType.InputStreamRangeBody => InputStreamRange(() => new ByteArrayInputStream(b))
case RawBodyType.FileBody => throw new IllegalArgumentException("ByteArray part provided while expecting a File part")
case _: RawBodyType.MultipartBody => throw new IllegalArgumentException("Nested multipart bodies are not allowed")
}
case FileBody(f, _) =>
bodyType match {
case RawBodyType.FileBody => FileRange(f.toFile)
case RawBodyType.ByteArrayBody => Files.readAllBytes(f.toPath)
case RawBodyType.ByteBufferBody => ByteBuffer.wrap(Files.readAllBytes(f.toPath))
case RawBodyType.InputStreamBody => new FileInputStream(f.toFile)
case _ => throw new IllegalArgumentException(s"File part provided, while expecting $bodyType")
}
case StringBody(s, charset, _) =>
bodyType match {
case RawBodyType.StringBody(_) => s
case RawBodyType.ByteArrayBody => s.getBytes(charset)
case RawBodyType.ByteBufferBody => ByteBuffer.wrap(s.getBytes(charset))
case _ => throw new IllegalArgumentException(s"String part provided, while expecting $bodyType")
}
case InputStreamBody(is, _) =>
bodyType match {
case RawBodyType.InputStreamBody => is
case _ => throw new IllegalArgumentException(s"InputStream part provided, while expecting $bodyType")
}
case _ => throw new IllegalArgumentException(s"Unsupported part body type provided: ${part.body}")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import sttp.tapir.server.interceptor.exception.ExceptionHandler
import sttp.tapir.server.interceptor.reject.RejectHandler
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor}
import sttp.tapir.server.model.ValuedEndpointOutput
import sttp.tapir.generic.auto._
import sttp.tapir.tests.TestUtil.{readFromFile, writeToFile}
import sttp.model.Part
import sttp.tapir.TapirFile
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt

class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {

Expand Down Expand Up @@ -204,8 +210,136 @@ class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {
response.body shouldBe Left("Internal server error")
response.code shouldBe StatusCode.InternalServerError
}

it should "handle multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenEndpoint(e)
.thenRespond("success")
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("success")
}

it should "correctly process a multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub.synchronous)
.whenServerEndpointRunLogic(e.serverLogic((multipartData) => {
val partOpt = multipartData.find(_.name == "name")
val fileOpt = multipartData.find(_.name == "file")

(partOpt, fileOpt) match {
case (Some(part), Some(filePart)) =>
val partData = new String(part.body)
val fileData = new String(filePart.body)
IdMonad.unit(Right("name: " + partData + " file: " + fileData))

case (Some(_), None) =>
IdMonad.unit(Right("File part not found"))

case (None, Some(_)) =>
IdMonad.unit(Right("Part not found"))

case (None, None) =>
IdMonad.unit(Right("Both parts not found"))
}
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("name: abc file: file_content")
}

it should "correctly handle derived multipart body" in {
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody[MultipartData])
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenServerEndpointRunLogic(e.serverLogic(multipartData => {
val fileContent = Await.result(readFromFile(multipartData.file.body), 3.seconds)
IdMonad.unit(Right("name: " + multipartData.name + " year: " + multipartData.year + " file: " + fileContent))
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipart("year", "2024"),
multipartFile("file", writeToFile("file_content"))
)
.send(server)

// then
response.body shouldBe Right("name: abc year: 2024 file: file_content")
}

it should "throw exception when bytearray body provided while endpoint accepts fileBody" in {
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody[MultipartData])
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenEndpoint(e)
.thenRespond("success")
.backend()

// when
val response = the[IllegalArgumentException] thrownBy sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc"),
multipart("year", "2024"),
multipart("file", "file_content".getBytes())
)
.send(server)

// then
response.getMessage shouldBe "ByteArray part provided while expecting a File part"
}
}

case class MultipartData(name: String, year: Int, file: Part[TapirFile])

object ProductsApi {

val getProduct: Endpoint[Unit, Unit, String, String, Any] = endpoint.get
Expand Down
Loading