diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.scalafmt.conf b/.scalafmt.conf index d417d12..86c9f3a 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,6 +1,22 @@ version = "3.8.0" runner.dialect = scala213 maxColumn = 120 + +rewrite { + rules = [ + ExpandImportSelectors, + Imports + ] + + imports { + groups = [ + ["[a-z].*"], + ["java\\..*", "scala\\..*"] + ] + sort = original + } +} + fileOverride { "glob:**/fs2/src/**" { runner.dialect = scala213source3 diff --git a/README.md b/README.md index 64acc58..29e23ec 100644 --- a/README.md +++ b/README.md @@ -39,3 +39,27 @@ override def ivyDeps = super.ivyDeps() ++ Agg(ivy"tech.neander::jsonrpclib-fs2:: **/!\ Please be aware that this library is in its early days and offers strictly no guarantee with regards to backward compatibility** See the modules/examples folder. + +## Smithy Integration + +You can now use `jsonrpclib` directly with [Smithy](https://smithy.io/) and [smithy4s](https://disneystreaming.github.io/smithy4s/), enabling type-safe, +schema-first JSON-RPC APIs with minimal boilerplate. + +This integration is supported by the following modules: + +```scala +// Defines the Smithy protocol for JSON-RPC +libraryDependencies += "tech.neander" % "jsonrpclib-smithy" % + +// Provides smithy4s client/server bindings for JSON-RPC +libraryDependencies += "tech.neander" %%% "jsonrpclib-smithy4s" % +``` + +With these modules, you can: + +- Annotate your Smithy operations with `@jsonRpcRequest` or `@jsonRpcNotification` +- Generate client and server interfaces using smithy4s +- Use ClientStub to invoke remote services over JSON-RPC +- Use ServerEndpoints to expose service implementations via a Channel + +This allows you to define your API once in Smithy and interact with it as a fully typed JSON-RPC service—without writing manual encoders, decoders, or dispatch logic. diff --git a/build.sbt b/build.sbt index f13de53..ed731eb 100644 --- a/build.sbt +++ b/build.sbt @@ -16,7 +16,8 @@ inThisBuild( ) val scala213 = "2.13.16" -val scala3 = "3.3.5" +val scala3 = "3.3.6" +val jdkVersion = 11 val allScalaVersions = List(scala213, scala3) val jvmScalaVersions = allScalaVersions val jsScalaVersions = allScalaVersions @@ -24,6 +25,7 @@ val nativeScalaVersions = allScalaVersions val fs2Version = "3.12.0" +ThisBuild / versionScheme := Some("early-semver") ThisBuild / tpolecatOptionsMode := DevMode val commonSettings = Seq( @@ -31,18 +33,29 @@ val commonSettings = Seq( "com.disneystreaming" %%% "weaver-cats" % "0.8.4" % Test ), mimaPreviousArtifacts := Set( - organization.value %%% name.value % "0.0.7" + // organization.value %%% name.value % "0.0.7" ), - scalacOptions += "-java-output-version:8" + scalacOptions ++= { + CrossVersion.partialVersion(scalaVersion.value) match { + case Some((2, _)) => Seq(s"-release:$jdkVersion") + case _ => Seq(s"-java-output-version:$jdkVersion") + } + }, +) + +val commonJvmSettings = Seq( + javacOptions ++= Seq("--release", jdkVersion.toString) ) val core = projectMatrix .in(file("modules") / "core") .jvmPlatform( jvmScalaVersions, - Test / unmanagedSourceDirectories ++= Seq( - (projectMatrixBaseDirectory.value / "src" / "test" / "scalajvm-native").getAbsoluteFile - ) + Seq( + Test / unmanagedSourceDirectories ++= Seq( + (projectMatrixBaseDirectory.value / "src" / "test" / "scalajvm-native").getAbsoluteFile + ) + ) ++ commonJvmSettings ) .jsPlatform(jsScalaVersions) .nativePlatform( @@ -56,13 +69,13 @@ val core = projectMatrix name := "jsonrpclib-core", commonSettings, libraryDependencies ++= Seq( - "com.github.plokhotnyuk.jsoniter-scala" %%% "jsoniter-scala-macros" % "2.30.2" + "com.github.plokhotnyuk.jsoniter-scala" %%% "jsoniter-scala-circe" % "2.30.2" ) ) val fs2 = projectMatrix .in(file("modules") / "fs2") - .jvmPlatform(jvmScalaVersions) + .jvmPlatform(jvmScalaVersions, commonJvmSettings) .jsPlatform(jsScalaVersions) .nativePlatform(nativeScalaVersions) .disablePlugins(AssemblyPlugin) @@ -71,19 +84,97 @@ val fs2 = projectMatrix name := "jsonrpclib-fs2", commonSettings, libraryDependencies ++= Seq( - "co.fs2" %%% "fs2-core" % fs2Version + "co.fs2" %%% "fs2-core" % fs2Version, + "io.circe" %%% "circe-generic" % "0.14.7" % Test ) ) +val smithy = projectMatrix + .in(file("modules") / "smithy") + .jvmPlatform(false) + .disablePlugins(AssemblyPlugin, MimaPlugin) + .enablePlugins(SmithyTraitCodegenPlugin) + .settings( + name := "jsonrpclib-smithy", + commonJvmSettings, + smithyTraitCodegenDependencies := List(Dependencies.alloy.core), + smithyTraitCodegenJavaPackage := "jsonrpclib", + smithyTraitCodegenNamespace := "jsonrpclib" + ) + +val smithyTests = projectMatrix + .in(file("modules/smithy-tests")) + .jvmPlatform(Seq(scala213)) + .dependsOn(smithy) + .settings( + publish / skip := true, + libraryDependencies ++= Seq( + "com.disneystreaming" %%% "weaver-cats" % "0.8.4" % Test + ) + ) + .disablePlugins(MimaPlugin) + +lazy val buildTimeProtocolDependency = + /** By default, smithy4sInternalDependenciesAsJars doesn't contain the jars in the "smithy4s" configuration. We have + * to add them manually - this is the equivalent of a "% Smithy4s"-scoped dependency. + * + * Ideally, this would be + * {{{ + * (Compile / smithy4sInternalDependenciesAsJars) ++= + * Smithy4s / smithy4sInternalDependenciesAsJars).value.map(_.data) + * }}} + * + * but that doesn't work because the Smithy4s configuration doesn't extend from Compile so it doesn't have the + * `internalDependencyAsJars` setting. + */ + Compile / smithy4sInternalDependenciesAsJars ++= + (smithy.jvm(autoScalaLibrary = false) / Compile / fullClasspathAsJars).value.map(_.data) + +val smithy4s = projectMatrix + .in(file("modules") / "smithy4s") + .jvmPlatform(jvmScalaVersions, commonJvmSettings) + .jsPlatform(jsScalaVersions) + .nativePlatform(Seq(scala3)) + .disablePlugins(AssemblyPlugin) + .enablePlugins(Smithy4sCodegenPlugin) + .dependsOn(core) + .settings( + name := "jsonrpclib-smithy4s", + commonSettings, + mimaPreviousArtifacts := Set.empty, + libraryDependencies ++= Seq( + "com.disneystreaming.smithy4s" %%% "smithy4s-json" % smithy4sVersion.value + ), + buildTimeProtocolDependency + ) + +val smithy4sTests = projectMatrix + .in(file("modules") / "smithy4s-tests") + .jvmPlatform(jvmScalaVersions, commonJvmSettings) + .jsPlatform(jsScalaVersions) + .nativePlatform(Seq(scala3)) + .disablePlugins(AssemblyPlugin) + .enablePlugins(Smithy4sCodegenPlugin) + .dependsOn(smithy4s, fs2 % Test) + .settings( + commonSettings, + publish / skip := true, + libraryDependencies ++= Seq( + "io.circe" %%% "circe-generic" % "0.14.7" + ), + buildTimeProtocolDependency + ) + val exampleServer = projectMatrix .in(file("modules") / "examples/server") - .jvmPlatform(List(scala213)) + .jvmPlatform(List(scala213), commonJvmSettings) .dependsOn(fs2) .settings( commonSettings, publish / skip := true, libraryDependencies ++= Seq( - "co.fs2" %%% "fs2-io" % fs2Version + "co.fs2" %%% "fs2-io" % fs2Version, + "io.circe" %%% "circe-generic" % "0.14.7" ) ) .disablePlugins(MimaPlugin) @@ -95,7 +186,7 @@ val exampleClient = projectMatrix Seq( fork := true, envVars += "SERVER_JAR" -> (exampleServer.jvm(scala213) / assembly).value.toString - ) + ) ++ commonJvmSettings ) .disablePlugins(AssemblyPlugin) .dependsOn(fs2) @@ -103,30 +194,85 @@ val exampleClient = projectMatrix commonSettings, publish / skip := true, libraryDependencies ++= Seq( - "co.fs2" %%% "fs2-io" % fs2Version + "co.fs2" %%% "fs2-io" % fs2Version, + "io.circe" %%% "circe-generic" % "0.14.7" ) ) .disablePlugins(MimaPlugin) +val exampleSmithyShared = projectMatrix + .in(file("modules") / "examples/smithyShared") + .jvmPlatform(List(scala213), commonJvmSettings) + .dependsOn(smithy4s, fs2) + .enablePlugins(Smithy4sCodegenPlugin) + .settings( + commonSettings, + publish / skip := true, + buildTimeProtocolDependency + ) + .disablePlugins(MimaPlugin) + +val exampleSmithyServer = projectMatrix + .in(file("modules") / "examples/smithyServer") + .jvmPlatform(List(scala213), commonJvmSettings) + .dependsOn(exampleSmithyShared) + .settings( + commonSettings, + publish / skip := true, + libraryDependencies ++= Seq( + "co.fs2" %%% "fs2-io" % fs2Version + ), + assembly / assemblyMergeStrategy := { + case PathList("META-INF", "smithy", _*) => MergeStrategy.concat + case PathList("jsonrpclib", "package.class") => MergeStrategy.first + case PathList("META-INF", xs @ _*) if xs.nonEmpty => MergeStrategy.discard + case x => MergeStrategy.first + } + ) + .disablePlugins(MimaPlugin) + +val exampleSmithyClient = projectMatrix + .in(file("modules") / "examples/smithyClient") + .jvmPlatform( + List(scala213), + Seq( + fork := true, + envVars += "SERVER_JAR" -> (exampleSmithyServer.jvm(scala213) / assembly).value.toString + ) ++ commonJvmSettings + ) + .dependsOn(exampleSmithyShared) + .settings( + commonSettings, + publish / skip := true, + libraryDependencies ++= Seq( + "co.fs2" %%% "fs2-io" % fs2Version + ) + ) + .disablePlugins(MimaPlugin, AssemblyPlugin) + val root = project .in(file(".")) .settings( publish / skip := true ) .disablePlugins(MimaPlugin, AssemblyPlugin) - .aggregate(List(core, fs2, exampleServer, exampleClient).flatMap(_.projectRefs): _*) - -// The core compiles are a workaround for https://github.com/plokhotnyuk/jsoniter-scala/issues/564 -// when we switch to SN 0.5, we can use `makeWithSkipNestedOptionValues` instead: https://github.com/plokhotnyuk/jsoniter-scala/issues/564#issuecomment-2787096068 -val compileCoreModules = { - for { - scalaVersionSuffix <- List("", "3") - platformSuffix <- List("", "JS", "Native") - task <- List("compile", "package") - } yield s"core$platformSuffix$scalaVersionSuffix/$task" -}.mkString(";") + .aggregate( + List( + core, + fs2, + exampleServer, + exampleClient, + smithy, + smithyTests, + smithy4s, + smithy4sTests, + exampleSmithyShared, + exampleSmithyServer, + exampleSmithyClient + ).flatMap(_.projectRefs): _* + ) addCommandAlias( "ci", - s"$compileCoreModules;test;scalafmtCheckAll;mimaReportBinaryIssues" + s"compile;test;scalafmtCheckAll;mimaReportBinaryIssues" ) diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..330a04d --- /dev/null +++ b/flake.lock @@ -0,0 +1,82 @@ +{ + "nodes": { + "flake-parts": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib" + }, + "locked": { + "lastModified": 1743550720, + "narHash": "sha256-hIshGgKZCgWh6AYJpJmRgFdR3WUbkY04o82X05xqQiY=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "c621e8422220273271f52058f618c94e405bb0f5", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1746422338, + "narHash": "sha256-NTtKOTLQv6dPfRe00OGSywg37A1FYqldS6xiNmqBUYc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "5b35d248e9206c1f3baf8de6a7683fee126364aa", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-lib": { + "locked": { + "lastModified": 1743296961, + "narHash": "sha256-b1EdN3cULCqtorQ4QeWgLMrd5ZGOjLSLemfa00heasc=", + "owner": "nix-community", + "repo": "nixpkgs.lib", + "rev": "e4822aea2a6d1cdd36653c134cacfd64c97ff4fa", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixpkgs.lib", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-parts": "flake-parts", + "nixpkgs": "nixpkgs", + "treefmt-nix": "treefmt-nix" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1746216483, + "narHash": "sha256-4h3s1L/kKqt3gMDcVfN8/4v2jqHrgLIe4qok4ApH5x4=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "29ec5026372e0dec56f890e50dbe4f45930320fd", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..a7e5619 --- /dev/null +++ b/flake.nix @@ -0,0 +1,40 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + flake-parts = { + url = "github:hercules-ci/flake-parts"; + }; + treefmt-nix = { + url = "github:numtide/treefmt-nix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = inputs@{ nixpkgs, flake-parts, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { + systems = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin" ]; + imports = [ + inputs.treefmt-nix.flakeModule + ]; + perSystem = { system, config, pkgs, ... }: + { + devShells.default = pkgs.mkShell { + packages = [ pkgs.openjdk21 pkgs.scalafmt pkgs.sbt pkgs.clang pkgs.glibc.dev pkgs.nodejs_23]; + inputsFrom = [ + config.treefmt.build.devShell + ]; + }; + + treefmt.config = { + projectRootFile = "flake.nix"; + + programs = { + nixpkgs-fmt.enable = true; + scalafmt.enable = true; + }; + settings.formatter.scalafmt.include = [ "*.scala" "*.sc" ]; + }; + }; + }; +} + diff --git a/modules/core/src/main/scala/jsonrpclib/CallId.scala b/modules/core/src/main/scala/jsonrpclib/CallId.scala index c70fd9f..918145a 100644 --- a/modules/core/src/main/scala/jsonrpclib/CallId.scala +++ b/modules/core/src/main/scala/jsonrpclib/CallId.scala @@ -1,7 +1,8 @@ package jsonrpclib -import com.github.plokhotnyuk.jsoniter_scala.core._ -import scala.annotation.switch +import io.circe.Codec +import io.circe.Decoder +import io.circe.Json sealed trait CallId object CallId { @@ -9,24 +10,17 @@ object CallId { final case class StringId(string: String) extends CallId case object NullId extends CallId - implicit val callIdRW: JsonValueCodec[CallId] = new JsonValueCodec[CallId] { - def decodeValue(in: JsonReader, default: CallId): CallId = { - val nt = in.nextToken() - - (nt: @switch) match { - case 'n' => in.readNullOrError(default, "expected null") - case '"' => in.rollbackToken(); StringId(in.readString(null)) - case _ => in.rollbackToken(); NumberId(in.readLong()) - - } + implicit val codec: Codec[CallId] = Codec.from( + Decoder + .decodeOption(Decoder.decodeString.map(StringId(_): CallId).or(Decoder.decodeLong.map(NumberId(_): CallId))) + .map { + case None => NullId + case Some(v) => v + }, + { + case NumberId(n) => Json.fromLong(n) + case StringId(str) => Json.fromString(str) + case NullId => Json.Null } - - def encodeValue(x: CallId, out: JsonWriter): Unit = x match { - case NumberId(long) => out.writeVal(long) - case StringId(string) => out.writeVal(string) - case NullId => out.writeNull() - } - - def nullValue: CallId = CallId.NullId - } + ) } diff --git a/modules/core/src/main/scala/jsonrpclib/Channel.scala b/modules/core/src/main/scala/jsonrpclib/Channel.scala index 6efcda6..3d75d96 100644 --- a/modules/core/src/main/scala/jsonrpclib/Channel.scala +++ b/modules/core/src/main/scala/jsonrpclib/Channel.scala @@ -1,12 +1,16 @@ package jsonrpclib +import io.circe.Decoder +import io.circe.Encoder +import jsonrpclib.ErrorCodec.errorPayloadCodec + trait Channel[F[_]] { def mountEndpoint(endpoint: Endpoint[F]): F[Unit] def unmountEndpoint(method: String): F[Unit] - def notificationStub[In: Codec](method: String): In => F[Unit] - def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out] - def stub[In: Codec, Err: ErrorCodec, Out: Codec](method: String): In => F[Either[Err, Out]] + def notificationStub[In: Encoder](method: String): In => F[Unit] + def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out] + def stub[In: Encoder, Err: ErrorDecoder, Out: Decoder](method: String): In => F[Either[Err, Out]] def stub[In, Err, Out](template: StubTemplate[In, Err, Out]): In => F[Either[Err, Out]] } @@ -25,7 +29,7 @@ object Channel { (in: In) => F.doFlatMap(stub(in))(unit => F.doPure(Right(unit))) } - final def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out] = { + final def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out] = { val s = stub[In, ErrorPayload, Out](method) (in: In) => F.doFlatMap(s(in)) { diff --git a/modules/core/src/main/scala/jsonrpclib/Codec.scala b/modules/core/src/main/scala/jsonrpclib/Codec.scala deleted file mode 100644 index 1cc3059..0000000 --- a/modules/core/src/main/scala/jsonrpclib/Codec.scala +++ /dev/null @@ -1,35 +0,0 @@ -package jsonrpclib - -import com.github.plokhotnyuk.jsoniter_scala.core._ - -trait Codec[A] { - - def encode(a: A): Payload - def decode(payload: Option[Payload]): Either[ProtocolError, A] - -} - -object Codec { - - def encode[A](a: A)(implicit codec: Codec[A]): Payload = codec.encode(a) - def decode[A](payload: Option[Payload])(implicit codec: Codec[A]): Either[ProtocolError, A] = codec.decode(payload) - - implicit def fromJsonCodec[A](implicit jsonCodec: JsonValueCodec[A]): Codec[A] = new Codec[A] { - def encode(a: A): Payload = { - Payload(writeToArray(a)) - } - - def decode(payload: Option[Payload]): Either[ProtocolError, A] = { - try { - payload match { - case Some(Payload.Data(payload)) => Right(readFromArray(payload)) - case Some(Payload.NullPayload) => Right(readFromArray(nullArray)) - case None => Left(ProtocolError.ParseError("Expected to decode a payload")) - } - } catch { case e: JsonReaderException => Left(ProtocolError.ParseError(e.getMessage())) } - } - } - - private val nullArray = "null".getBytes() - -} diff --git a/modules/core/src/main/scala/jsonrpclib/Endpoint.scala b/modules/core/src/main/scala/jsonrpclib/Endpoint.scala index f46267c..9a228c9 100644 --- a/modules/core/src/main/scala/jsonrpclib/Endpoint.scala +++ b/modules/core/src/main/scala/jsonrpclib/Endpoint.scala @@ -1,7 +1,32 @@ package jsonrpclib +import io.circe.Decoder +import io.circe.Encoder +import jsonrpclib.ErrorCodec.errorPayloadCodec + +/** Represents a JSON-RPC method handler that can be invoked by the server. + * + * An `Endpoint[F]` defines how to decode input from a JSON-RPC message, execute some effectful logic, and optionally + * return a response. + * + * The endpoint's `method` field is used to match incoming JSON-RPC requests. + */ sealed trait Endpoint[F[_]] { + + /** The JSON-RPC method name this endpoint responds to. Used for dispatching incoming requests. */ def method: String + + /** Transforms the effect type of this endpoint using the provided `PolyFunction`. + * + * This allows reinterpreting the endpoint’s logic in a different effect context (e.g., from `IO` to `Kleisli[IO, + * Ctx, *]`, or from `F` to `EitherT[F, E, *]`). + * + * @param f + * A polymorphic function that transforms `F[_]` into `G[_]` + * @return + * A new `Endpoint[G]` with the same behavior but in a new effect type + */ + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] } object Endpoint { @@ -14,17 +39,17 @@ object Endpoint { class PartiallyAppliedEndpoint[F[_]](method: MethodPattern) { def apply[In, Err, Out]( run: In => F[Either[Err, Out]] - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] = - RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec) + )(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] = + RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errEncoder, outCodec) def full[In, Err, Out]( run: (InputMessage, In) => F[Either[Err, Out]] - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] = - RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec) + )(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] = + RequestResponseEndpoint(method, run, inCodec, errEncoder, outCodec) def simple[In, Out]( run: In => F[Out] - )(implicit F: Monadic[F], inCodec: Codec[In], outCodec: Codec[Out]) = + )(implicit F: Monadic[F], inCodec: Decoder[In], outCodec: Encoder[Out]) = apply[In, ErrorPayload, Out](in => F.doFlatMap(F.doAttempt(run(in))) { case Left(error) => F.doPure(Left(ErrorPayload(0, error.getMessage(), None))) @@ -32,26 +57,33 @@ object Endpoint { } ) - def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] = + def notification[In](run: In => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] = NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec) - def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] = + def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] = NotificationEndpoint(method, run, inCodec) } - final case class NotificationEndpoint[F[_], In]( + private[jsonrpclib] final case class NotificationEndpoint[F[_], In]( method: MethodPattern, run: (InputMessage, In) => F[Unit], - inCodec: Codec[In] - ) extends Endpoint[F] + inCodec: Decoder[In] + ) extends Endpoint[F] { - final case class RequestResponseEndpoint[F[_], In, Err, Out]( + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] = + NotificationEndpoint[G, In](method, (msg, in) => f(run(msg, in)), inCodec) + } + + private[jsonrpclib] final case class RequestResponseEndpoint[F[_], In, Err, Out]( method: Method, run: (InputMessage, In) => F[Either[Err, Out]], - inCodec: Codec[In], - errCodec: ErrorCodec[Err], - outCodec: Codec[Out] - ) extends Endpoint[F] + inCodec: Decoder[In], + errEncoder: ErrorEncoder[Err], + outCodec: Encoder[Out] + ) extends Endpoint[F] { + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] = + RequestResponseEndpoint[G, In, Err, Out](method, (msg, in) => f(run(msg, in)), inCodec, errEncoder, outCodec) + } } diff --git a/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala b/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala index f150c1f..8af58e9 100644 --- a/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala +++ b/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala @@ -1,12 +1,15 @@ package jsonrpclib -trait ErrorCodec[E] { - +trait ErrorEncoder[E] { def encode(a: E): ErrorPayload - def decode(error: ErrorPayload): Either[ProtocolError, E] +} +trait ErrorDecoder[E] { + def decode(error: ErrorPayload): Either[ProtocolError, E] } +trait ErrorCodec[E] extends ErrorDecoder[E] with ErrorEncoder[E] + object ErrorCodec { implicit val errorPayloadCodec: ErrorCodec[ErrorPayload] = new ErrorCodec[ErrorPayload] { def encode(a: ErrorPayload): ErrorPayload = a diff --git a/modules/core/src/main/scala/jsonrpclib/ErrorPayload.scala b/modules/core/src/main/scala/jsonrpclib/ErrorPayload.scala index b0a2cc3..b40ab72 100644 --- a/modules/core/src/main/scala/jsonrpclib/ErrorPayload.scala +++ b/modules/core/src/main/scala/jsonrpclib/ErrorPayload.scala @@ -1,7 +1,7 @@ package jsonrpclib -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker +import io.circe.Decoder +import io.circe.Encoder case class ErrorPayload(code: Int, message: String, data: Option[Payload]) extends Throwable { override def getMessage(): String = s"JsonRPC Error $code: $message" @@ -9,7 +9,9 @@ case class ErrorPayload(code: Int, message: String, data: Option[Payload]) exten object ErrorPayload { - implicit val rawMessageStubJsonValueCodecs: JsonValueCodec[ErrorPayload] = - JsonCodecMaker.make + implicit val errorPayloadEncoder: Encoder[ErrorPayload] = + Encoder.forProduct3("code", "message", "data")(e => (e.code, e.message, e.data)) + implicit val errorPayloadDecoder: Decoder[ErrorPayload] = + Decoder.forProduct3("code", "message", "data")(ErrorPayload.apply) } diff --git a/modules/core/src/main/scala/jsonrpclib/Message.scala b/modules/core/src/main/scala/jsonrpclib/Message.scala index 10d50fa..e4364b4 100644 --- a/modules/core/src/main/scala/jsonrpclib/Message.scala +++ b/modules/core/src/main/scala/jsonrpclib/Message.scala @@ -1,42 +1,42 @@ package jsonrpclib -import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter +import io.circe.syntax._ +import io.circe.Codec sealed trait Message { def maybeCallId: Option[CallId] } sealed trait InputMessage extends Message { def method: String } sealed trait OutputMessage extends Message { - def callId: CallId; final override def maybeCallId: Option[CallId] = Some(callId) + def callId: CallId + final override def maybeCallId: Option[CallId] = Some(callId) } object InputMessage { case class RequestMessage(method: String, callId: CallId, params: Option[Payload]) extends InputMessage { def maybeCallId: Option[CallId] = Some(callId) } + case class NotificationMessage(method: String, params: Option[Payload]) extends InputMessage { def maybeCallId: Option[CallId] = None } + } + object OutputMessage { def errorFrom(callId: CallId, protocolError: ProtocolError): OutputMessage = ErrorMessage(callId, ErrorPayload(protocolError.code, protocolError.getMessage(), None)) case class ErrorMessage(callId: CallId, payload: ErrorPayload) extends OutputMessage case class ResponseMessage(callId: CallId, data: Payload) extends OutputMessage + } object Message { + import jsonrpclib.internals.RawMessage - implicit val messageJsonValueCodecs: JsonValueCodec[Message] = new JsonValueCodec[Message] { - val rawMessageCodec = implicitly[JsonValueCodec[internals.RawMessage]] - def decodeValue(in: JsonReader, default: Message): Message = - rawMessageCodec.decodeValue(in, null).toMessage match { - case Left(error) => throw error - case Right(value) => value - } - def encodeValue(x: Message, out: JsonWriter): Unit = - rawMessageCodec.encodeValue(internals.RawMessage.from(x), out) - def nullValue: Message = null - } + implicit val codec: Codec[Message] = Codec.from( + { c => + c.as[RawMessage].flatMap(_.toMessage.left.map(e => io.circe.DecodingFailure(e.getMessage, c.history))) + }, + RawMessage.from(_).asJson + ) } diff --git a/modules/core/src/main/scala/jsonrpclib/Monadic.scala b/modules/core/src/main/scala/jsonrpclib/Monadic.scala index 5168dd5..2acf020 100644 --- a/modules/core/src/main/scala/jsonrpclib/Monadic.scala +++ b/modules/core/src/main/scala/jsonrpclib/Monadic.scala @@ -1,16 +1,20 @@ package jsonrpclib -import scala.concurrent.Future import scala.concurrent.ExecutionContext +import scala.concurrent.Future trait Monadic[F[_]] { def doFlatMap[A, B](fa: F[A])(f: A => F[B]): F[B] def doPure[A](a: A): F[A] def doAttempt[A](fa: F[A]): F[Either[Throwable, A]] def doRaiseError[A](e: Throwable): F[A] + def doMap[A, B](fa: F[A])(f: A => B): F[B] = doFlatMap(fa)(a => doPure(f(a))) + def doVoid[A](fa: F[A]): F[Unit] = doMap(fa)(_ => ()) } object Monadic { + def apply[F[_]](implicit F: Monadic[F]): Monadic[F] = F + implicit def monadicFuture(implicit ec: ExecutionContext): Monadic[Future] = new Monadic[Future] { def doFlatMap[A, B](fa: Future[A])(f: A => Future[B]): Future[B] = fa.flatMap(f) @@ -19,5 +23,22 @@ object Monadic { def doAttempt[A](fa: Future[A]): Future[Either[Throwable, A]] = fa.map(Right(_)).recover(Left(_)) def doRaiseError[A](e: Throwable): Future[A] = Future.failed(e) + + override def doMap[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f) + } + + object syntax { + implicit class MonadicOps[F[_], A](private val fa: F[A]) extends AnyVal { + def flatMap[B](f: A => F[B])(implicit m: Monadic[F]): F[B] = m.doFlatMap(fa)(f) + def map[B](f: A => B)(implicit m: Monadic[F]): F[B] = m.doMap(fa)(f) + def attempt(implicit m: Monadic[F]): F[Either[Throwable, A]] = m.doAttempt(fa) + def void(implicit m: Monadic[F]): F[Unit] = m.doVoid(fa) + } + implicit class MonadicOpsPure[A](private val a: A) extends AnyVal { + def pure[F[_]](implicit m: Monadic[F]): F[A] = m.doPure(a) + } + implicit class MonadicOpsThrowable(private val t: Throwable) extends AnyVal { + def raiseError[F[_], A](implicit m: Monadic[F]): F[A] = m.doRaiseError(t) + } } } diff --git a/modules/core/src/main/scala/jsonrpclib/Payload.scala b/modules/core/src/main/scala/jsonrpclib/Payload.scala index a423c2b..ba2c19a 100644 --- a/modules/core/src/main/scala/jsonrpclib/Payload.scala +++ b/modules/core/src/main/scala/jsonrpclib/Payload.scala @@ -1,50 +1,17 @@ package jsonrpclib -import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter +import io.circe.Decoder +import io.circe.Encoder +import io.circe.Json -import java.util.Base64 -import jsonrpclib.Payload.Data -import jsonrpclib.Payload.NullPayload - -sealed trait Payload extends Product with Serializable { - def stripNull: Option[Payload.Data] = this match { - case d @ Data(_) => Some(d) - case NullPayload => None - } +case class Payload(data: Json) { + def stripNull: Option[Payload] = Option(Payload(data)).filter(p => !p.data.isNull) } object Payload { - def apply(value: Array[Byte]) = { - if (value == null) NullPayload - else Data(value) - } - final case class Data(array: Array[Byte]) extends Payload { - override def equals(other: Any) = other match { - case bytes: Data => java.util.Arrays.equals(array, bytes.array) - case _ => false - } - - override lazy val hashCode: Int = java.util.Arrays.hashCode(array) - - override def toString = Base64.getEncoder.encodeToString(array) - } - - case object NullPayload extends Payload - - implicit val payloadJsonValueCodec: JsonValueCodec[Payload] = new JsonValueCodec[Payload] { - def decodeValue(in: JsonReader, default: Payload): Payload = { - Data(in.readRawValAsBytes()) - } - - def encodeValue(bytes: Payload, out: JsonWriter): Unit = - bytes match { - case Data(array) => out.writeRawVal(array) - case NullPayload => out.writeNull() - } + val NullPayload: Payload = Payload(Json.Null) - def nullValue: Payload = null - } + implicit val payloadEncoder: Encoder[Payload] = Encoder[Json].contramap(_.data) + implicit val payloadDecoder: Decoder[Payload] = Decoder[Json].map(Payload(_)) } diff --git a/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala b/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala new file mode 100644 index 0000000..637c50f --- /dev/null +++ b/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala @@ -0,0 +1,12 @@ +package jsonrpclib + +/** A polymorphic natural transformation from `F[_]` to `G[_]`. + * + * @tparam F + * Source effect type + * @tparam G + * Target effect type + */ +trait PolyFunction[F[_], G[_]] { self => + def apply[A0](fa: => F[A0]): G[A0] +} diff --git a/modules/core/src/main/scala/jsonrpclib/StubTemplate.scala b/modules/core/src/main/scala/jsonrpclib/StubTemplate.scala index 36f0a17..17491e3 100644 --- a/modules/core/src/main/scala/jsonrpclib/StubTemplate.scala +++ b/modules/core/src/main/scala/jsonrpclib/StubTemplate.scala @@ -1,5 +1,7 @@ package jsonrpclib +import io.circe.Codec + sealed trait StubTemplate[In, Err, Out] object StubTemplate { def notification[In](method: String)(implicit inCodec: Codec[In]): StubTemplate[In, Nothing, Unit] = diff --git a/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala b/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala index cb73e08..c5e3df3 100644 --- a/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala +++ b/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala @@ -1,5 +1,6 @@ package jsonrpclib +import io.circe.Encoder import jsonrpclib.internals._ import java.util.concurrent.atomic.AtomicLong @@ -25,7 +26,7 @@ abstract class FutureBasedChannel(endpoints: List[Endpoint[Future]])(implicit ec protected def getEndpoint(method: String): Future[Option[Endpoint[Future]]] = Future.successful(endpointsMap.get(method)) protected def sendMessage(message: Message): Future[Unit] = { - sendPayload(Codec.encode(message)).map(_ => ()) + sendPayload(Payload(Encoder[Message].apply(message))).map(_ => ()) } protected def nextCallId(): Future[CallId] = Future.successful(CallId.NumberId(nextID.incrementAndGet())) diff --git a/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala b/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala index 6042597..fdda32b 100644 --- a/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala +++ b/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala @@ -1,10 +1,14 @@ package jsonrpclib package internals +import io.circe.Decoder +import io.circe.Encoder +import io.circe.HCursor import jsonrpclib.Endpoint.NotificationEndpoint import jsonrpclib.Endpoint.RequestResponseEndpoint import jsonrpclib.OutputMessage.ErrorMessage import jsonrpclib.OutputMessage.ResponseMessage + import scala.util.Try private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] { @@ -20,21 +24,21 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F protected def storePendingCall(callId: CallId, handle: OutputMessage => F[Unit]): F[Unit] protected def removePendingCall(callId: CallId): F[Option[OutputMessage => F[Unit]]] - def notificationStub[In](method: String)(implicit inCodec: Codec[In]): In => F[Unit] = { (input: In) => - val encoded = inCodec.encode(input) - val message = InputMessage.NotificationMessage(method, Some(encoded)) + def notificationStub[In](method: String)(implicit inCodec: Encoder[In]): In => F[Unit] = { (input: In) => + val encoded = inCodec(input) + val message = InputMessage.NotificationMessage(method, Some(Payload(encoded))) sendMessage(message) } def stub[In, Err, Out]( method: String - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): In => F[Either[Err, Out]] = { + )(implicit inCodec: Encoder[In], errDecoder: ErrorDecoder[Err], outCodec: Decoder[Out]): In => F[Either[Err, Out]] = { (input: In) => - val encoded = inCodec.encode(input) + val encoded = inCodec(input) doFlatMap(nextCallId()) { callId => - val message = InputMessage.RequestMessage(method, callId, Some(encoded)) + val message = InputMessage.RequestMessage(method, callId, Some(Payload(encoded))) doFlatMap(createPromise[Either[Err, Out]](callId)) { case (fulfill, future) => - val pc = createPendingCall(errCodec, outCodec, fulfill) + val pc = createPendingCall(errDecoder, outCodec, fulfill) doFlatMap(storePendingCall(callId, pc))(_ => doFlatMap(sendMessage(message))(_ => future())) } } @@ -70,25 +74,37 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F private def executeInputMessage(input: InputMessage, endpoint: Endpoint[F]): F[Unit] = { (input, endpoint) match { - case (InputMessage.NotificationMessage(_, params), ep: NotificationEndpoint[F, in]) => - ep.inCodec.decode(params) match { + case (InputMessage.NotificationMessage(_, Some(params)), ep: NotificationEndpoint[F, in]) => + ep.inCodec(HCursor.fromJson(params.data)) match { case Right(value) => ep.run(input, value) - case Left(value) => reportError(params, value, ep.method) + case Left(value) => reportError(Some(params), ProtocolError.ParseError(value.getMessage), ep.method) } - case (InputMessage.RequestMessage(_, callId, params), ep: RequestResponseEndpoint[F, in, err, out]) => - ep.inCodec.decode(params) match { + case (InputMessage.RequestMessage(_, callId, Some(params)), ep: RequestResponseEndpoint[F, in, err, out]) => + ep.inCodec(HCursor.fromJson(params.data)) match { case Right(value) => - doFlatMap(ep.run(input, value)) { - case Right(data) => - val responseData = ep.outCodec.encode(data) - sendMessage(OutputMessage.ResponseMessage(callId, responseData)) - case Left(error) => - val errorPayload = ep.errCodec.encode(error) + doFlatMap(doAttempt(ep.run(input, value))) { + case Right(Right(data)) => + val responseData = ep.outCodec(data) + sendMessage(OutputMessage.ResponseMessage(callId, Payload(responseData))) + case Right(Left(error)) => + val errorPayload = ep.errEncoder.encode(error) sendMessage(OutputMessage.ErrorMessage(callId, errorPayload)) + case Left(err) => + sendMessage( + OutputMessage.ErrorMessage(callId, ErrorPayload(0, s"ServerInternalError: ${err.getMessage}", None)) + ) } case Left(pError) => - sendProtocolError(callId, pError) + sendProtocolError(callId, ProtocolError.ParseError(pError.getMessage)) } + case (InputMessage.NotificationMessage(_, None), _: NotificationEndpoint[F, in]) => + val message = "Missing payload" + val pError = ProtocolError.InvalidRequest(message) + sendProtocolError(pError) + case (InputMessage.RequestMessage(_, _, None), _: RequestResponseEndpoint[F, in, err, out]) => + val message = "Missing payload" + val pError = ProtocolError.InvalidRequest(message) + sendProtocolError(pError) case (InputMessage.NotificationMessage(_, _), ep: RequestResponseEndpoint[F, in, err, out]) => val message = s"This ${ep.method} endpoint cannot process notifications, request is missing callId" val pError = ProtocolError.InvalidRequest(message) @@ -101,18 +117,18 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F } private def createPendingCall[Err, Out]( - errCodec: ErrorCodec[Err], - outCodec: Codec[Out], + errDecoder: ErrorDecoder[Err], + outCodec: Decoder[Out], fulfill: Try[Either[Err, Out]] => F[Unit] ): OutputMessage => F[Unit] = { (message: OutputMessage) => message match { case ErrorMessage(_, errorPayload) => - errCodec.decode(errorPayload) match { + errDecoder.decode(errorPayload) match { case Left(_) => fulfill(scala.util.Failure(errorPayload)) case Right(value) => fulfill(scala.util.Success(Left(value))) } - case ResponseMessage(_, data) => - outCodec.decode(Some(data)) match { + case ResponseMessage(_, payload) => + outCodec(HCursor.fromJson(payload.data)) match { case Left(decodeError) => fulfill(scala.util.Failure(decodeError)) case Right(value) => fulfill(scala.util.Success(Right(value))) } diff --git a/modules/core/src/main/scala/jsonrpclib/internals/RawMessage.scala b/modules/core/src/main/scala/jsonrpclib/internals/RawMessage.scala index 7738dd3..1b40daf 100644 --- a/modules/core/src/main/scala/jsonrpclib/internals/RawMessage.scala +++ b/modules/core/src/main/scala/jsonrpclib/internals/RawMessage.scala @@ -1,9 +1,10 @@ package jsonrpclib package internals -import com.github.plokhotnyuk.jsoniter_scala.core._ -import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker -import com.github.plokhotnyuk.jsoniter_scala.macros.CodecMakerConfig +import io.circe.syntax._ +import io.circe.Decoder +import io.circe.Encoder +import io.circe.Json private[jsonrpclib] case class RawMessage( jsonrpc: String, @@ -44,7 +45,8 @@ private[jsonrpclib] object RawMessage { val `2.0` = "2.0" def from(message: Message): RawMessage = message match { - case InputMessage.NotificationMessage(method, params) => RawMessage(`2.0`, method = Some(method), params = params) + case InputMessage.NotificationMessage(method, params) => + RawMessage(`2.0`, method = Some(method), params = params) case InputMessage.RequestMessage(method, callId, params) => RawMessage(`2.0`, method = Some(method), params = params, id = Some(callId)) case OutputMessage.ErrorMessage(callId, errorPayload) => @@ -53,7 +55,39 @@ private[jsonrpclib] object RawMessage { RawMessage(`2.0`, result = Some(data.stripNull), id = Some(callId)) } - implicit val rawMessageJsonValueCodecs: JsonValueCodec[RawMessage] = - JsonCodecMaker.make(CodecMakerConfig.withSkipNestedOptionValues(true)) + // Custom encoder to flatten nested Option[Option[Payload]] + implicit val rawMessageEncoder: Encoder[RawMessage] = { msg => + Json + .obj( + List( + "jsonrpc" -> msg.jsonrpc.asJson, + "method" -> msg.method.asJson, + "params" -> msg.params.asJson, + "error" -> msg.error.asJson, + "id" -> msg.id.asJson + ).filterNot(_._2.isNull) ++ { + msg.result match { + case Some(Some(payload)) => List("result" -> payload.asJson) + case Some(None) => List("result" -> Json.Null) + case None => Nil + } + }: _* + ) + } + // Custom decoder to wrap result into Option[Option[Payload]] + implicit val rawMessageDecoder: Decoder[RawMessage] = Decoder.instance { c => + for { + jsonrpc <- c.downField("jsonrpc").as[String] + method <- c.downField("method").as[Option[String]] + params <- c.downField("params").as[Option[Payload]] + error <- c.downField("error").as[Option[ErrorPayload]] + id <- c.downField("id").as[Option[CallId]] + resultOpt <- + c.downField("result") + .success + .map(_.as[Option[Payload]].map(Some(_))) + .getOrElse(Right(None)) + } yield RawMessage(jsonrpc, method, resultOpt, error, params, id) + } } diff --git a/modules/core/src/main/scala/jsonrpclib/package.scala b/modules/core/src/main/scala/jsonrpclib/package.scala index 9093575..5c0f070 100644 --- a/modules/core/src/main/scala/jsonrpclib/package.scala +++ b/modules/core/src/main/scala/jsonrpclib/package.scala @@ -2,5 +2,4 @@ package object jsonrpclib { type ErrorCode = Int type ErrorMessage = String - } diff --git a/modules/core/src/test/scala/jsonrpclib/CallIdSpec.scala b/modules/core/src/test/scala/jsonrpclib/CallIdSpec.scala index b227173..2116150 100644 --- a/modules/core/src/test/scala/jsonrpclib/CallIdSpec.scala +++ b/modules/core/src/test/scala/jsonrpclib/CallIdSpec.scala @@ -1,7 +1,10 @@ package jsonrpclib -import weaver._ +import cats.syntax.all._ +import com.github.plokhotnyuk.jsoniter_scala.circe.JsoniterScalaCodec._ import com.github.plokhotnyuk.jsoniter_scala.core._ +import io.circe.Json +import weaver._ object CallIdSpec extends FunSuite { test("json parsing") { @@ -12,9 +15,9 @@ object CallIdSpec extends FunSuite { val longJson = Long.MaxValue.toString val nullJson = "null" - assert.same(readFromString[CallId](strJson), CallId.StringId("25")) && - assert.same(readFromString[CallId](intJson), CallId.NumberId(25)) && - assert.same(readFromString[CallId](longJson), CallId.NumberId(Long.MaxValue)) && - assert.same(readFromString[CallId](nullJson), CallId.NullId) + assert.same(readFromString[Json](strJson).as[CallId], CallId.StringId("25").asRight) && + assert.same(readFromString[Json](intJson).as[CallId], CallId.NumberId(25).asRight) && + assert.same(readFromString[Json](longJson).as[CallId], CallId.NumberId(Long.MaxValue).asRight) && + assert.same(readFromString[Json](nullJson).as[CallId], CallId.NullId.asRight) } } diff --git a/modules/core/src/test/scala/jsonrpclib/RawMessageSpec.scala b/modules/core/src/test/scala/jsonrpclib/RawMessageSpec.scala index be5e41a..0771b79 100644 --- a/modules/core/src/test/scala/jsonrpclib/RawMessageSpec.scala +++ b/modules/core/src/test/scala/jsonrpclib/RawMessageSpec.scala @@ -1,19 +1,25 @@ package jsonrpclib -import weaver._ -import jsonrpclib.internals._ +import com.github.plokhotnyuk.jsoniter_scala.circe.JsoniterScalaCodec._ import com.github.plokhotnyuk.jsoniter_scala.core._ +import io.circe.syntax._ +import io.circe.Json +import jsonrpclib.internals._ import jsonrpclib.CallId.NumberId import jsonrpclib.OutputMessage.ResponseMessage +import weaver._ object RawMessageSpec extends FunSuite { test("json parsing with null result") { // This is a perfectly valid response object, as result field has to be present, // but can be null: https://www.jsonrpc.org/specification#response_object - val rawMessage = readFromString[RawMessage](""" {"jsonrpc":"2.0","result":null,"id":3} """.trim) + val rawMessage = readFromString[Json](""" {"jsonrpc":"2.0","id":3,"result":null}""".trim) + .as[RawMessage] + .fold(throw _, identity) // This, on the other hand, is an invalid response message, as result field is missing - val invalidRawMessage = readFromString[RawMessage](""" {"jsonrpc":"2.0","id":3} """.trim) + val invalidRawMessage = + readFromString[Json](""" {"jsonrpc":"2.0","id":3} """.trim).as[RawMessage].fold(throw _, identity) assert.same( rawMessage, @@ -26,4 +32,49 @@ object RawMessageSpec extends FunSuite { ) && assert(invalidRawMessage.toMessage.isLeft, invalidRawMessage.toMessage.toString) } + + test("request message serialization") { + val input: Message = InputMessage.RequestMessage("my/method", CallId.NumberId(1), None) + val expected = """{"jsonrpc":"2.0","method":"my/method","id":1}""" + val result = writeToString(input.asJson) + + assert(result == expected, s"Expected: $expected, got: $result") + } + + test("notification message serialization") { + val input: Message = InputMessage.NotificationMessage("my/method", None) + val expected = """{"jsonrpc":"2.0","method":"my/method"}""" + val result = writeToString(input.asJson) + + assert(result == expected, s"Expected: $expected, got: $result") + } + + test("response message serialization") { + val input: Message = OutputMessage.ResponseMessage(CallId.NumberId(1), Payload.NullPayload) + val expected = """{"jsonrpc":"2.0","id":1,"result":null}""" + val result = writeToString(input.asJson) + + assert(result == expected, s"Expected: $expected, got: $result") + } + + test("response message serialization with nested results") { + val input: Message = + OutputMessage.ResponseMessage(CallId.NumberId(1), Payload(Json.obj("result" -> Json.fromInt(1)))) + val expected = """{"jsonrpc":"2.0","id":1,"result":{"result":1}}""" + val result = writeToString(input.asJson) + + assert(result == expected, s"Expected: $expected, got: $result") + } + + test("error message serialization") { + val input: Message = OutputMessage.ErrorMessage( + CallId.NumberId(1), + ErrorPayload(-32603, "Internal error", None) + ) + val expected = """{"jsonrpc":"2.0","error":{"code":-32603,"message":"Internal error","data":null},"id":1}""" + val result = writeToString(input.asJson) + + assert(result == expected, s"Expected: $expected, got: $result") + } + } diff --git a/modules/core/src/test/scalajvm-native/jsonrpclib/internals/HeaderSpec.scala b/modules/core/src/test/scalajvm-native/jsonrpclib/internals/HeaderSpec.scala index e58bd59..88d629e 100644 --- a/modules/core/src/test/scalajvm-native/jsonrpclib/internals/HeaderSpec.scala +++ b/modules/core/src/test/scalajvm-native/jsonrpclib/internals/HeaderSpec.scala @@ -1,11 +1,12 @@ package jsonrpclib.internals +import jsonrpclib.ProtocolError import weaver._ -import java.io.ByteArrayInputStream + import java.io.BufferedReader -import java.io.InputStreamReader -import jsonrpclib.ProtocolError +import java.io.ByteArrayInputStream import java.io.IOException +import java.io.InputStreamReader import java.io.UncheckedIOException object HeaderSpec extends FunSuite { diff --git a/modules/examples/client/src/main/scala/examples/client/ClientMain.scala b/modules/examples/client/src/main/scala/examples/client/ClientMain.scala index 5097f2d..a424ad5 100644 --- a/modules/examples/client/src/main/scala/examples/client/ClientMain.scala +++ b/modules/examples/client/src/main/scala/examples/client/ClientMain.scala @@ -2,13 +2,13 @@ package examples.client import cats.effect._ import cats.syntax.all._ -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker -import fs2.Stream import fs2.io._ import fs2.io.process.Processes -import jsonrpclib.CallId +import fs2.Stream +import io.circe.generic.semiauto._ +import io.circe.Codec import jsonrpclib.fs2._ +import jsonrpclib.CallId object ClientMain extends IOApp.Simple { @@ -18,7 +18,7 @@ object ClientMain extends IOApp.Simple { // Creating a datatype that'll serve as a request (and response) of an endpoint case class IntWrapper(value: Int) object IntWrapper { - implicit val jcodec: JsonValueCodec[IntWrapper] = JsonCodecMaker.make + implicit val codec: Codec[IntWrapper] = deriveCodec } type IOStream[A] = fs2.Stream[IO, A] @@ -32,7 +32,7 @@ object ClientMain extends IOApp.Simple { // Starting the server rp <- fs2.Stream.resource(Processes[IO].spawn(process.ProcessBuilder("java", "-jar", serverJar))) // Creating a channel that will be used to communicate to the server - fs2Channel <- FS2Channel[IO](cancelTemplate = cancelEndpoint.some) + fs2Channel <- FS2Channel.stream[IO](cancelTemplate = cancelEndpoint.some) _ <- Stream(()) .concurrently(fs2Channel.output.through(lsp.encodeMessages).through(rp.stdin)) .concurrently(rp.stdout.through(lsp.decodeMessages).through(fs2Channel.inputOrBounce)) diff --git a/modules/examples/server/src/main/scala/examples/server/ServerMain.scala b/modules/examples/server/src/main/scala/examples/server/ServerMain.scala index 72c9804..2704b5b 100644 --- a/modules/examples/server/src/main/scala/examples/server/ServerMain.scala +++ b/modules/examples/server/src/main/scala/examples/server/ServerMain.scala @@ -1,11 +1,11 @@ package examples.server -import jsonrpclib.CallId -import jsonrpclib.fs2._ import cats.effect._ import fs2.io._ -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker +import io.circe.generic.semiauto._ +import io.circe.Codec +import jsonrpclib.fs2._ +import jsonrpclib.CallId import jsonrpclib.Endpoint object ServerMain extends IOApp.Simple { @@ -16,7 +16,7 @@ object ServerMain extends IOApp.Simple { // Creating a datatype that'll serve as a request (and response) of an endpoint case class IntWrapper(value: Int) object IntWrapper { - implicit val jcodec: JsonValueCodec[IntWrapper] = JsonCodecMaker.make + implicit val codec: Codec[IntWrapper] = deriveCodec } // Implementing an incrementation endpoint @@ -28,7 +28,8 @@ object ServerMain extends IOApp.Simple { def run: IO[Unit] = { // Using errorln as stdout is used by the RPC channel IO.consoleForIO.errorln("Starting server") >> - FS2Channel[IO](cancelTemplate = Some(cancelEndpoint)) + FS2Channel + .stream[IO](cancelTemplate = Some(cancelEndpoint)) .flatMap(_.withEndpointStream(increment)) // mounting an endpoint onto the channel .flatMap(channel => fs2.Stream diff --git a/modules/examples/smithyClient/src/main/scala/examples/smithy/client/ClientMain.scala b/modules/examples/smithyClient/src/main/scala/examples/smithy/client/ClientMain.scala new file mode 100644 index 0000000..06016f4 --- /dev/null +++ b/modules/examples/smithyClient/src/main/scala/examples/smithy/client/ClientMain.scala @@ -0,0 +1,61 @@ +package examples.smithy.client + +import cats.effect._ +import cats.syntax.all._ +import fs2.io.process.Processes +import fs2.Stream +import jsonrpclib.fs2._ +import jsonrpclib.smithy4sinterop.ClientStub +import jsonrpclib.smithy4sinterop.ServerEndpoints +import jsonrpclib.CallId +import test._ + +object SmithyClientMain extends IOApp.Simple { + + // Reserving a method for cancelation. + val cancelEndpoint = CancelTemplate.make[CallId]("$/cancel", identity, identity) + + type IOStream[A] = fs2.Stream[IO, A] + def log(str: String): IOStream[Unit] = Stream.eval(IO.consoleForIO.errorln(str)) + + // Implementing the generated interface + object Client extends TestClient[IO] { + def pong(pong: String): IO[Unit] = IO.consoleForIO.errorln(s"Client received pong: $pong") + } + + def run: IO[Unit] = { + val run = for { + //////////////////////////////////////////////////////// + /////// BOOTSTRAPPING + //////////////////////////////////////////////////////// + _ <- log("Starting client") + serverJar <- sys.env.get("SERVER_JAR").liftTo[IOStream](new Exception("SERVER_JAR env var does not exist")) + // Starting the server + rp <- Stream.resource( + Processes[IO] + .spawn( + fs2.io.process.ProcessBuilder("java", "-jar", serverJar) + ) + ) + // Creating a channel that will be used to communicate to the server + fs2Channel <- FS2Channel.stream[IO](cancelTemplate = cancelEndpoint.some) + // Mounting our implementation of the generated interface onto the channel + _ <- fs2Channel.withEndpointsStream(ServerEndpoints(Client)) + // Creating stubs to talk to the remote server + server: TestServer[IO] = ClientStub(test.TestServer, fs2Channel) + _ <- Stream(()) + .concurrently(fs2Channel.output.through(lsp.encodeMessages).through(rp.stdin)) + .concurrently(rp.stdout.through(lsp.decodeMessages).through(fs2Channel.inputOrBounce)) + .concurrently(rp.stderr.through(fs2.io.stderr[IO])) + + //////////////////////////////////////////////////////// + /////// INTERACTION + //////////////////////////////////////////////////////// + result1 <- Stream.eval(server.greet("Client")) + _ <- log(s"Client received $result1") + _ <- Stream.eval(server.ping("Ping")) + } yield () + run.compile.drain.guarantee(IO.consoleForIO.errorln("Terminating client")) + } + +} diff --git a/modules/examples/smithyServer/src/main/scala/examples/smithy/server/ServerMain.scala b/modules/examples/smithyServer/src/main/scala/examples/smithy/server/ServerMain.scala new file mode 100644 index 0000000..d410ad3 --- /dev/null +++ b/modules/examples/smithyServer/src/main/scala/examples/smithy/server/ServerMain.scala @@ -0,0 +1,44 @@ +package examples.smithy.server + +import cats.effect._ +import fs2.io._ +import jsonrpclib.fs2._ +import jsonrpclib.smithy4sinterop.ClientStub +import jsonrpclib.smithy4sinterop.ServerEndpoints +import jsonrpclib.CallId +import test._ // smithy4s-generated package + +object ServerMain extends IOApp.Simple { + + // Reserving a method for cancelation. + val cancelEndpoint = CancelTemplate.make[CallId]("$/cancel", identity, identity) + + // Implementing the generated interface + class ServerImpl(client: TestClient[IO]) extends TestServer[IO] { + def greet(name: String): IO[GreetOutput] = IO.pure(GreetOutput(s"Server says: hello $name !")) + + def ping(ping: String): IO[Unit] = client.pong(s"Returned to sender: $ping") + } + + def printErr(s: String): IO[Unit] = IO.consoleForIO.errorln(s) + + def run: IO[Unit] = { + val run = + FS2Channel + .stream[IO](cancelTemplate = Some(cancelEndpoint)) + .flatMap { channel => + val testClient = ClientStub(TestClient, channel) + channel.withEndpointsStream(ServerEndpoints(new ServerImpl(testClient))) + } + .flatMap { channel => + fs2.Stream + .eval(IO.never) // running the server forever + .concurrently(stdin[IO](512).through(lsp.decodeMessages).through(channel.inputOrBounce)) + .concurrently(channel.output.through(lsp.encodeMessages).through(stdout[IO])) + } + + // Using errorln as stdout is used by the RPC channel + printErr("Starting server") >> run.compile.drain.guarantee(printErr("Terminating server")) + } + +} diff --git a/modules/examples/smithyShared/src/main/smithy/spec.smithy b/modules/examples/smithyShared/src/main/smithy/spec.smithy new file mode 100644 index 0000000..905eb1d --- /dev/null +++ b/modules/examples/smithyShared/src/main/smithy/spec.smithy @@ -0,0 +1,45 @@ +$version: "2.0" + +namespace test + +use jsonrpclib#jsonRpcRequest +use jsonrpclib#jsonRpc +use jsonrpclib#jsonRpcNotification + +@jsonRpc +service TestServer { + operations: [Greet, Ping] +} + +@jsonRpc +service TestClient { + operations: [Pong] +} + +@jsonRpcRequest("greet") +operation Greet { + input := { + @required + name: String + } + output := { + @required + message: String + } +} + +@jsonRpcNotification("ping") +operation Ping { + input := { + @required + ping: String + } +} + +@jsonRpcNotification("pong") +operation Pong { + input := { + @required + pong: String + } +} diff --git a/modules/fs2/src/main/scala/jsonrpclib/fs2/CancelTemplate.scala b/modules/fs2/src/main/scala/jsonrpclib/fs2/CancelTemplate.scala index ed0c426..cf904e1 100644 --- a/modules/fs2/src/main/scala/jsonrpclib/fs2/CancelTemplate.scala +++ b/modules/fs2/src/main/scala/jsonrpclib/fs2/CancelTemplate.scala @@ -1,6 +1,6 @@ package jsonrpclib.fs2 -import jsonrpclib.Codec +import io.circe.Codec import jsonrpclib.CallId /** A cancelation template that represents the RPC method by which cancelation diff --git a/modules/fs2/src/main/scala/jsonrpclib/fs2/FS2Channel.scala b/modules/fs2/src/main/scala/jsonrpclib/fs2/FS2Channel.scala index 811ced9..7468826 100644 --- a/modules/fs2/src/main/scala/jsonrpclib/fs2/FS2Channel.scala +++ b/modules/fs2/src/main/scala/jsonrpclib/fs2/FS2Channel.scala @@ -1,22 +1,31 @@ package jsonrpclib package fs2 -import _root_.fs2.Pipe -import _root_.fs2.Stream +import cats.effect.kernel._ +import cats.effect.std.Supervisor +import cats.effect.syntax.all._ +import cats.effect.Fiber +import cats.syntax.all._ import cats.Applicative import cats.Functor import cats.Monad import cats.MonadThrow -import cats.effect.Fiber -import cats.effect.kernel._ -import cats.effect.std.Supervisor -import cats.syntax.all._ -import cats.effect.syntax.all._ +import io.circe.Codec import jsonrpclib.internals.MessageDispatcher -import scala.util.Try import java.util.regex.Pattern +import scala.util.Try + +import _root_.fs2.Pipe +import _root_.fs2.Stream +/** A JSON-RPC communication channel built on top of `fs2.Stream`. + * + * `FS2Channel[F]` enables streaming JSON-RPC messages into and out of an effectful system. It provides methods to + * register handlers (`Endpoint[F]`) for specific method names. + * + * This is the primary server-side integration point for using JSON-RPC over FS2. + */ trait FS2Channel[F[_]] extends Channel[F] { def input: Pipe[F, Message, Unit] @@ -47,14 +56,28 @@ trait FS2Channel[F[_]] extends Channel[F] { object FS2Channel { - def apply[F[_]: Concurrent]( + /** Creates a new `FS2Channel[F]` as a managed resource with a configurable buffer size for bidirectional message + * processing. + * + * Optionally, a `CancelTemplate` can be provided to support client-initiated cancellation of inflight requests via a + * dedicated cancellation endpoint. + * + * @param bufferSize + * Size of the internal outbound message queue (default: 2048) + * @param cancelTemplate + * Optional handler that defines how to decode and handle cancellation requests + * + * @return + * A `Resource` that manages the lifecycle of the channel and its internal supervisor + */ + def resource[F[_]: Concurrent]( bufferSize: Int = 2048, cancelTemplate: Option[CancelTemplate] = None - ): Stream[F, FS2Channel[F]] = { + ): Resource[F, FS2Channel[F]] = { for { - supervisor <- Stream.resource(Supervisor[F]) - ref <- Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0)).toStream - queue <- cats.effect.std.Queue.bounded[F, Message](bufferSize).toStream + supervisor <- Supervisor[F] + ref <- Resource.eval(Ref[F].of(State[F](Map.empty, Map.empty, Map.empty, Vector.empty, 0))) + queue <- Resource.eval(cats.effect.std.Queue.bounded[F, Message](bufferSize)) impl = new Impl(queue, ref, supervisor, cancelTemplate) // Creating a bespoke endpoint to receive cancelation requests @@ -66,10 +89,21 @@ object FS2Channel { } } // mounting the cancelation endpoint - _ <- maybeCancelEndpoint.traverse_(ep => impl.mountEndpoint(ep)).toStream + _ <- Resource.eval(maybeCancelEndpoint.traverse_(ep => impl.mountEndpoint(ep))) } yield impl } + @deprecated("use stream or resource", "0.0.9") + def apply[F[_]: Concurrent]( + bufferSize: Int = 2048, + cancelTemplate: Option[CancelTemplate] = None + ): Stream[F, FS2Channel[F]] = stream(bufferSize, cancelTemplate) + + def stream[F[_]: Concurrent]( + bufferSize: Int = 2048, + cancelTemplate: Option[CancelTemplate] = None + ): Stream[F, FS2Channel[F]] = Stream.resource(resource(bufferSize, cancelTemplate)) + private case class State[F[_]]( runningCalls: Map[CallId, Fiber[F, Throwable, Unit]], pendingCalls: Map[CallId, OutputMessage => F[Unit]], diff --git a/modules/fs2/src/main/scala/jsonrpclib/fs2/lsp.scala b/modules/fs2/src/main/scala/jsonrpclib/fs2/lsp.scala index 29963a7..03d7f48 100644 --- a/modules/fs2/src/main/scala/jsonrpclib/fs2/lsp.scala +++ b/modules/fs2/src/main/scala/jsonrpclib/fs2/lsp.scala @@ -1,31 +1,34 @@ package jsonrpclib.fs2 import cats.MonadThrow +import com.github.plokhotnyuk.jsoniter_scala.circe.JsoniterScalaCodec._ +import com.github.plokhotnyuk.jsoniter_scala.core._ import fs2.Chunk -import fs2.Stream import fs2.Pipe +import fs2.Stream +import io.circe.Decoder +import io.circe.Encoder +import io.circe.HCursor +import io.circe.Json +import jsonrpclib.Message import jsonrpclib.Payload -import jsonrpclib.Codec +import jsonrpclib.ProtocolError import java.nio.charset.Charset import java.nio.charset.StandardCharsets -import jsonrpclib.Message -import jsonrpclib.ProtocolError -import jsonrpclib.Payload.Data -import jsonrpclib.Payload.NullPayload import scala.annotation.tailrec object lsp { def encodeMessages[F[_]]: Pipe[F, Message, Byte] = - (_: Stream[F, Message]).map(Codec.encode(_)).through(encodePayloads) + (_: Stream[F, Message]).map(Encoder[Message].apply(_)).map(Payload(_)).through(encodePayloads) def encodePayloads[F[_]]: Pipe[F, Payload, Byte] = (_: Stream[F, Payload]).map(writeChunk).flatMap(Stream.chunk(_)) def decodeMessages[F[_]: MonadThrow]: Pipe[F, Byte, Either[ProtocolError, Message]] = (_: Stream[F, Byte]).through(decodePayloads).map { payload => - Codec.decode[Message](Some(payload)) + Decoder[Message].apply(HCursor.fromJson(payload.data)).left.map(e => ProtocolError.ParseError(e.getMessage)) } /** Split a stream of bytes into payloads by extracting each frame based on information contained in the headers. @@ -39,20 +42,16 @@ object lsp { (ns, Chunk(maybeResult)) } .flatMap { - case Right(acc) => Stream.iterable(acc).map(c => Payload(c.toArray)) + case Right(acc) => Stream.iterable(acc).map(c => Payload(readFromArray[Json](c.toArray))) case Left(error) => Stream.raiseError[F](error) } private def writeChunk(payload: Payload): Chunk[Byte] = { - val bytes = payload match { - case Data(array) => array - case NullPayload => nullArray - } + val bytes = writeToArray(payload.data) val header = s"Content-Length: ${bytes.size}" + "\r\n" * 2 Chunk.array(header.getBytes()) ++ Chunk.array(bytes) } - private val nullArray = "null".getBytes() private val returnByte = '\r'.toByte private val newlineByte = '\n'.toByte @@ -138,7 +137,7 @@ object lsp { } continue = false } else { - bb.put(byte) + val _ = bb.put(byte) } } if (newState != null) { diff --git a/modules/fs2/src/main/scala/jsonrpclib/fs2/package.scala b/modules/fs2/src/main/scala/jsonrpclib/fs2/package.scala index c77c114..cdee83e 100644 --- a/modules/fs2/src/main/scala/jsonrpclib/fs2/package.scala +++ b/modules/fs2/src/main/scala/jsonrpclib/fs2/package.scala @@ -1,10 +1,11 @@ package jsonrpclib -import _root_.fs2.Stream -import cats.MonadThrow -import cats.Monad -import cats.effect.kernel.Resource import cats.effect.kernel.MonadCancel +import cats.effect.kernel.Resource +import cats.Monad +import cats.MonadThrow + +import _root_.fs2.Stream package object fs2 { @@ -24,6 +25,10 @@ package object fs2 { def doAttempt[A](fa: F[A]): F[Either[Throwable, A]] = MonadThrow[F].attempt(fa) def doRaiseError[A](e: Throwable): F[A] = MonadThrow[F].raiseError(e) + + override def doMap[A, B](fa: F[A])(f: A => B): F[B] = Monad[F].map(fa)(f) + + override def doVoid[A](fa: F[A]): F[Unit] = Monad[F].void(fa) } } diff --git a/modules/fs2/src/test/scala/jsonrpclib/fs2/FS2ChannelSpec.scala b/modules/fs2/src/test/scala/jsonrpclib/fs2/FS2ChannelSpec.scala index 43b7c60..ffffce5 100644 --- a/modules/fs2/src/test/scala/jsonrpclib/fs2/FS2ChannelSpec.scala +++ b/modules/fs2/src/test/scala/jsonrpclib/fs2/FS2ChannelSpec.scala @@ -2,9 +2,9 @@ package jsonrpclib.fs2 import cats.effect.IO import cats.syntax.all._ -import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec -import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker import fs2.Stream +import io.circe.generic.semiauto._ +import io.circe.Codec import jsonrpclib._ import weaver._ @@ -14,12 +14,12 @@ object FS2ChannelSpec extends SimpleIOSuite { case class IntWrapper(int: Int) object IntWrapper { - implicit val jcodec: JsonValueCodec[IntWrapper] = JsonCodecMaker.make + implicit val codec: Codec[IntWrapper] = deriveCodec } case class CancelRequest(callId: CallId) object CancelRequest { - implicit val jcodec: JsonValueCodec[CancelRequest] = JsonCodecMaker.make + implicit val codec: Codec[CancelRequest] = deriveCodec } def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = @@ -30,8 +30,8 @@ object FS2ChannelSpec extends SimpleIOSuite { def setup(cancelTemplate: CancelTemplate, endpoints: Endpoint[IO]*) = setupAux(endpoints, Some(cancelTemplate)) def setupAux(endpoints: Seq[Endpoint[IO]], cancelTemplate: Option[CancelTemplate]): Stream[IO, ClientSideChannel] = { for { - serverSideChannel <- FS2Channel[IO](cancelTemplate = cancelTemplate) - clientSideChannel <- FS2Channel[IO](cancelTemplate = cancelTemplate) + serverSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + clientSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) _ <- serverSideChannel.withEndpointsStream(endpoints) _ <- Stream(()) .concurrently(clientSideChannel.output.through(serverSideChannel.input)) diff --git a/modules/smithy-tests/src/test/scala/jsonrpclib/JsonNotificationOutputValidatorSpec.scala b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonNotificationOutputValidatorSpec.scala new file mode 100644 index 0000000..b0c3148 --- /dev/null +++ b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonNotificationOutputValidatorSpec.scala @@ -0,0 +1,57 @@ +package jsonrpclib + +import jsonrpclib.ModelUtils.assembleModel +import jsonrpclib.ModelUtils.eventsWithoutLocations +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.validation.Severity +import software.amazon.smithy.model.validation.ValidationEvent +import weaver._ + +object JsonNotificationOutputValidatorSpec extends FunSuite { + test("no error when a @jsonNotification operation has unit output") { + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpcNotification("notify") + |operation NotifySomething { + |} + |""".stripMargin + ) + success + } + test("return an error when a @jsonNotification operation does not have unit output") { + val events = eventsWithoutLocations( + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpcNotification("notify") + |operation NotifySomething { + | output:={ + | message: String + | } + |} + | + |""".stripMargin + ) + ) + + val expected = ValidationEvent + .builder() + .id("JsonNotificationOutput") + .shapeId(ShapeId.fromParts("test", "NotifySomething")) + .severity(Severity.ERROR) + .message( + "Operation marked as @jsonRpcNotification must not return anything, but found `test#NotifySomethingOutput`." + ) + .build() + + assert(events.contains(expected)) + } + +} diff --git a/modules/smithy-tests/src/test/scala/jsonrpclib/JsonPayloadValidatorSpec.scala b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonPayloadValidatorSpec.scala new file mode 100644 index 0000000..22b4265 --- /dev/null +++ b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonPayloadValidatorSpec.scala @@ -0,0 +1,99 @@ +package jsonrpclib + +import jsonrpclib.ModelUtils.assembleModel +import jsonrpclib.ModelUtils.eventsWithoutLocations +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.validation.Severity +import software.amazon.smithy.model.validation.ValidationEvent +import weaver._ + +object JsonPayloadValidatorSpec extends FunSuite { + test("no error when jsonRpcPayload is used on the input, output or error structure's member") { + + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcPayload + | + |@jsonRpc + |service MyService { + | operations: [OpA] + |} + | + |@jsonRpcRequest("foo") + |operation OpA { + | input: OpInput + | output: OpOutput + | errors: [OpError] + |} + | + |structure OpInput { + | @jsonRpcPayload + | data: String + |} + | + |structure OpOutput { + | @jsonRpcPayload + | data: String + |} + | + |@error("client") + |structure OpError { + | @jsonRpcPayload + | data: String + |} + | + |""".stripMargin + ).unwrap() + + success + } + test("return an error when jsonRpcPayload is used in a nested structure") { + val events = eventsWithoutLocations( + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcPayload + | + |@jsonRpc + |service MyService { + | operations: [OpA] + |} + | + |@jsonRpcRequest("foo") + |operation OpA { + | input: OpInput + |} + | + |structure OpInput { + | data: NestedStructure + |} + | + |structure NestedStructure { + | @jsonRpcPayload + | data: String + |} + |""".stripMargin + ) + ) + + val expected = ValidationEvent + .builder() + .id("jsonRpcPayload.OnlyTopLevel") + .shapeId(ShapeId.fromParts("test", "NestedStructure", "data")) + .severity(Severity.ERROR) + .message( + "Found an incompatible shape when validating the constraints of the `jsonrpclib#jsonRpcPayload` trait attached to `test#NestedStructure$data`: jsonRpcPayload can only be used on the top level of an operation input/output/error." + ) + .build() + + assert(events.contains(expected)) + } + +} diff --git a/modules/smithy-tests/src/test/scala/jsonrpclib/JsonRpcOperationValidatorSpec.scala b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonRpcOperationValidatorSpec.scala new file mode 100644 index 0000000..9a3e213 --- /dev/null +++ b/modules/smithy-tests/src/test/scala/jsonrpclib/JsonRpcOperationValidatorSpec.scala @@ -0,0 +1,72 @@ +package jsonrpclib + +import jsonrpclib.ModelUtils.assembleModel +import jsonrpclib.ModelUtils.eventsWithoutLocations +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.validation.Severity +import software.amazon.smithy.model.validation.ValidationEvent +import weaver._ + +object JsonRpcOperationValidatorSpec extends FunSuite { + test("no error when all operations in @jsonRpc service are properly annotated") { + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpc + |service MyService { + | operations: [OpA, OpB] + |} + | + |@jsonRpcRequest("methodA") + |operation OpA {} + | + |@jsonRpcNotification("methodB") + |operation OpB { + | output: unit + |} + |""".stripMargin + ) + success + } + + test("return an error when a @jsonRpc service has an operation without @jsonRpcRequest or @jsonRpcNotification") { + val events = eventsWithoutLocations( + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + | + |@jsonRpc + |service MyService { + | operations: [GoodOp, BadOp] + |} + | + |@jsonRpcRequest("good") + |operation GoodOp {} + | + |operation BadOp {} // ❌ missing jsonRpcRequest or jsonRpcNotification + |""".stripMargin + ) + ) + + val expected = + ValidationEvent + .builder() + .id("JsonRpcOperation") + .shapeId(ShapeId.fromParts("test", "BadOp")) + .severity(Severity.ERROR) + .message( + "Operation is part of service `test#MyService` marked with @jsonRpc but is missing @jsonRpcRequest or @jsonRpcNotification." + ) + .build() + + assert(events.contains(expected)) + } +} diff --git a/modules/smithy-tests/src/test/scala/jsonrpclib/ModelUtils.scala b/modules/smithy-tests/src/test/scala/jsonrpclib/ModelUtils.scala new file mode 100644 index 0000000..07ca7a4 --- /dev/null +++ b/modules/smithy-tests/src/test/scala/jsonrpclib/ModelUtils.scala @@ -0,0 +1,27 @@ +package jsonrpclib + +import software.amazon.smithy.model.validation.ValidatedResult +import software.amazon.smithy.model.validation.ValidationEvent +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.SourceLocation + +import scala.jdk.CollectionConverters._ + +private object ModelUtils { + + def assembleModel(text: String): ValidatedResult[Model] = { + Model + .assembler() + .discoverModels() + .addUnparsedModel( + "test.smithy", + text + ) + .assemble() + } + + def eventsWithoutLocations(result: ValidatedResult[?]): List[ValidationEvent] = { + if (!result.isBroken) sys.error("Expected a broken result") + result.getValidationEvents.asScala.toList.map(e => e.toBuilder.sourceLocation(SourceLocation.NONE).build()) + } +} diff --git a/modules/smithy-tests/src/test/scala/jsonrpclib/UniqueJsonRpcMethodNamesValidatorSpec.scala b/modules/smithy-tests/src/test/scala/jsonrpclib/UniqueJsonRpcMethodNamesValidatorSpec.scala new file mode 100644 index 0000000..327d811 --- /dev/null +++ b/modules/smithy-tests/src/test/scala/jsonrpclib/UniqueJsonRpcMethodNamesValidatorSpec.scala @@ -0,0 +1,123 @@ +package jsonrpclib + +import jsonrpclib.ModelUtils.assembleModel +import jsonrpclib.ModelUtils.eventsWithoutLocations +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.validation.Severity +import software.amazon.smithy.model.validation.ValidationEvent +import weaver._ + +object UniqueJsonRpcMethodNamesValidatorSpec extends FunSuite { + test("no error when all jsonRpc method names are unique within a service") { + + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpc + |service MyService { + | operations: [OpA, OpB] + |} + | + |@jsonRpcRequest("foo") + |operation OpA {} + | + |@jsonRpcNotification("bar") + |operation OpB {} + |""".stripMargin + ).unwrap() + + success + } + test("return an error when two operations use the same jsonRpc method name in a service") { + val events = eventsWithoutLocations( + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpc + |service MyService { + | operations: [OpA, OpB] + |} + | + |@jsonRpcRequest("foo") + |operation OpA {} + | + |@jsonRpcNotification("foo") + |operation OpB {} // duplicate method name "foo" + |""".stripMargin + ) + ) + + val expected = ValidationEvent + .builder() + .id("UniqueJsonRpcMethodNames") + .shapeId(ShapeId.fromParts("test", "MyService")) + .severity(Severity.ERROR) + .message( + "Duplicate JSON-RPC method name `foo` in service `test#MyService`. It is used by: test#OpA, test#OpB" + ) + .build() + + assert(events.contains(expected)) + } + + test("no error if two services use the same operation") { + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpc + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcNotification + | + |@jsonRpc + |service MyService { + | operations: [OpA] + |} + | + |@jsonRpc + |service MyOtherService { + | operations: [OpA] + |} + | + |@jsonRpcRequest("foo") + |operation OpA {} + | + |""".stripMargin + ).unwrap() + success + } + + test("no error if two services use the same operation") { + assembleModel( + """$version: "2" + |namespace test + | + |use jsonrpclib#jsonRpcRequest + |use jsonrpclib#jsonRpcNotification + | + | + |service NonJsonRpcService { + | operations: [OpA] + |} + | + |@jsonRpcRequest("foo") + |operation OpA {} + | + |@jsonRpcNotification("foo") + |operation OpB {} // duplicate method name "foo" + |""".stripMargin + ).unwrap() + success + } + +} diff --git a/modules/smithy/src/main/java/jsonrpclib/validation/JsonNotificationOutputValidator.java b/modules/smithy/src/main/java/jsonrpclib/validation/JsonNotificationOutputValidator.java new file mode 100644 index 0000000..6fffed0 --- /dev/null +++ b/modules/smithy/src/main/java/jsonrpclib/validation/JsonNotificationOutputValidator.java @@ -0,0 +1,32 @@ +package jsonrpclib.validation; + +import jsonrpclib.JsonRpcNotificationTrait; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Validates that operations marked with @jsonNotification don't have any + * output. + */ +public class JsonNotificationOutputValidator extends AbstractValidator { + + @Override + public List validate(Model model) { + return model.getShapesWithTrait(JsonRpcNotificationTrait.ID).stream().flatMap(op -> { + ShapeId outputShapeId = op.asOperationShape().orElseThrow().getOutputShape(); + var outputShape = model.expectShape(outputShapeId); + if (outputShape.asStructureShape().map(s -> !s.members().isEmpty()).orElse(true)) { + return Stream.of(error(op, String.format( + "Operation marked as @jsonRpcNotification must not return anything, but found `%s`.", outputShapeId))); + } else { + return Stream.empty(); + } + }).collect(Collectors.toUnmodifiableList()); + } +} diff --git a/modules/smithy/src/main/java/jsonrpclib/validation/JsonRpcOperationValidator.java b/modules/smithy/src/main/java/jsonrpclib/validation/JsonRpcOperationValidator.java new file mode 100644 index 0000000..d4f594f --- /dev/null +++ b/modules/smithy/src/main/java/jsonrpclib/validation/JsonRpcOperationValidator.java @@ -0,0 +1,37 @@ +package jsonrpclib.validation; + +import jsonrpclib.JsonRpcNotificationTrait; +import jsonrpclib.JsonRpcTrait; +import jsonrpclib.JsonRpcRequestTrait; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class JsonRpcOperationValidator extends AbstractValidator { + + @Override + public List validate(Model model) { + return model.getServiceShapesWithTrait(JsonRpcTrait.class).stream() + .flatMap(service -> validateService(model, service)) + .collect(Collectors.toList()); + } + + private Stream validateService(Model model, ServiceShape service) { + return service.getAllOperations().stream() + .map(model::expectShape) + .filter(op -> !hasJsonRpcMethod(op)) + .map(op -> error(op, String.format( + "Operation is part of service `%s` marked with @jsonRpc but is missing @jsonRpcRequest or @jsonRpcNotification.", service.getId()))); + } + + private boolean hasJsonRpcMethod(Shape op) { + return op.hasTrait(JsonRpcRequestTrait.ID) || op.hasTrait(JsonRpcNotificationTrait.ID); + } +} + diff --git a/modules/smithy/src/main/java/jsonrpclib/validation/UniqueJsonRpcMethodNamesValidator.java b/modules/smithy/src/main/java/jsonrpclib/validation/UniqueJsonRpcMethodNamesValidator.java new file mode 100644 index 0000000..64179ed --- /dev/null +++ b/modules/smithy/src/main/java/jsonrpclib/validation/UniqueJsonRpcMethodNamesValidator.java @@ -0,0 +1,61 @@ +package jsonrpclib.validation; + +import jsonrpclib.JsonRpcNotificationTrait; +import jsonrpclib.JsonRpcTrait; +import jsonrpclib.JsonRpcRequestTrait; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.traits.StringTrait; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class UniqueJsonRpcMethodNamesValidator extends AbstractValidator { + + @Override + public List validate(Model model) { + return model.getServiceShapesWithTrait(JsonRpcTrait.class).stream() + .flatMap(service -> validateService(service, model)) + .collect(Collectors.toList()); + } + + private Stream validateService(ServiceShape service, Model model) { + Map> methodsToOps = service.getAllOperations().stream() + .map(model::expectShape) + .map(shape -> shape.asOperationShape().orElseThrow()) + .flatMap(op -> getJsonRpcMethodName(op).map(name -> Map.entry(name, op)).stream()) + .collect(Collectors.groupingBy( + Map.Entry::getKey, + Collectors.mapping(Map.Entry::getValue, Collectors.toList()) + )); + + // Emit a validation error for each method name that occurs more than once + return methodsToOps.entrySet().stream() + .filter(entry -> entry.getValue().size() > 1) + .flatMap(entry -> entry.getValue().stream() + .map(op -> + error(service, String.format( + "Duplicate JSON-RPC method name `%s` in service `%s`. It is used by: %s", + entry.getKey(), + service.getId(), + entry.getValue().stream() + .map(OperationShape::getId) + .map(Object::toString) + .collect(Collectors.joining(", ")) + ))) + ); + } + + private Optional getJsonRpcMethodName(OperationShape operation) { + return operation.getTrait(JsonRpcRequestTrait.class) + .map(StringTrait::getValue) + .or(() -> operation.getTrait(JsonRpcNotificationTrait.class).map(StringTrait::getValue)); + } +} + diff --git a/modules/smithy/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator b/modules/smithy/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator new file mode 100644 index 0000000..712ec39 --- /dev/null +++ b/modules/smithy/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator @@ -0,0 +1,3 @@ +jsonrpclib.validation.JsonNotificationOutputValidator +jsonrpclib.validation.UniqueJsonRpcMethodNamesValidator +jsonrpclib.validation.JsonRpcOperationValidator diff --git a/modules/smithy/src/main/resources/META-INF/smithy/jsonrpclib.smithy b/modules/smithy/src/main/resources/META-INF/smithy/jsonrpclib.smithy new file mode 100644 index 0000000..6303ce1 --- /dev/null +++ b/modules/smithy/src/main/resources/META-INF/smithy/jsonrpclib.smithy @@ -0,0 +1,47 @@ +$version: "2.0" + +namespace jsonrpclib + +/// the JSON-RPC protocol, +/// see https://www.jsonrpc.org/specification +@protocolDefinition(traits: [ + jsonRpcRequest + jsonRpcNotification + jsonRpcPayload + smithy.api#jsonName + smithy.api#length + smithy.api#pattern + smithy.api#range + smithy.api#required + smithy.api#timestampFormat + alloy#uuidFormat + alloy#discriminated + alloy#nullable + alloy#untagged +]) +@trait(selector: "service") +structure jsonRpc { +} + +/// Identifies an operation that abides by request/response semantics +/// https://www.jsonrpc.org/specification#request_object +@trait(selector: "operation", conflicts: [jsonRpcNotification]) +string jsonRpcRequest + +/// Identifies an operation that abides by fire-and-forget semantics +/// see https://www.jsonrpc.org/specification#notification +@trait(selector: "operation", conflicts: [jsonRpcRequest]) +string jsonRpcNotification + + +/// Binds a single structure member to the payload of a jsonrpc message. +/// Just like @httpPayload, but for jsonRpc. +@trait(selector: "structure > member", structurallyExclusive: "member") +@traitValidators({ + "jsonRpcPayload.OnlyTopLevel": { + message: "jsonRpcPayload can only be used on the top level of an operation input/output/error.", + severity: "ERROR", + selector: "$allowedShapes(:root(operation -[input, output, error]-> structure > member)) :not(:in(${allowedShapes}))" + } +}) +structure jsonRpcPayload {} diff --git a/modules/smithy/src/main/resources/META-INF/smithy/manifest b/modules/smithy/src/main/resources/META-INF/smithy/manifest new file mode 100644 index 0000000..94839e2 --- /dev/null +++ b/modules/smithy/src/main/resources/META-INF/smithy/manifest @@ -0,0 +1 @@ +jsonrpclib.smithy diff --git a/modules/smithy4s-tests/src/main/smithy/spec.smithy b/modules/smithy4s-tests/src/main/smithy/spec.smithy new file mode 100644 index 0000000..9d55f5b --- /dev/null +++ b/modules/smithy4s-tests/src/main/smithy/spec.smithy @@ -0,0 +1,100 @@ +$version: "2.0" + +namespace test + +use jsonrpclib#jsonRpcNotification +use jsonrpclib#jsonRpc +use jsonrpclib#jsonRpcRequest +use jsonrpclib#jsonRpcPayload + +@jsonRpc +service TestServer { + operations: [Greet, Ping] +} + +@jsonRpc +service TestClient { + operations: [Pong] +} + +@jsonRpcRequest("greet") +operation Greet { + input := { + @required + name: String + } + output := { + @required + message: String + } + errors: [NotWelcomeError] +} + + +@jsonRpc +service TestServerWithPayload { + operations: [GreetWithPayload] +} + +@jsonRpcRequest("greetWithPayload") +operation GreetWithPayload { + input := { + @required + @jsonRpcPayload + payload: GreetInputPayload + } + output := { + @required + @jsonRpcPayload + payload: GreetOutputPayload + } +} + +structure GreetInputPayload { + @required + name: String +} + +structure GreetOutputPayload { + @required + message: String +} + +@error("client") +structure NotWelcomeError { + @required + msg: String +} + +@jsonRpcNotification("ping") +operation Ping { + input := { + @required + ping: String + } +} + +@jsonRpcNotification("pong") +operation Pong { + input := { + @required + pong: String + } +} + +@jsonRpc +service WeatherService { + operations: [GetWeather] +} + +@jsonRpcRequest("getWeather") +operation GetWeather { + input := { + @required + city: String + } + output := { + @required + weather: String + } +} diff --git a/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala b/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala new file mode 100644 index 0000000..eddedb0 --- /dev/null +++ b/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala @@ -0,0 +1,125 @@ +package jsonrpclib.smithy4sinterop + +import cats.effect.IO +import cats.syntax.all._ +import fs2.Stream +import io.circe.Decoder +import io.circe.Encoder +import jsonrpclib._ +import jsonrpclib.fs2._ +import test._ +import test.TestServerOperation.GreetError +import weaver._ + +import scala.concurrent.duration._ + +import _root_.fs2.concurrent.SignallingRef + +object TestClientSpec extends SimpleIOSuite { + def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = + test(name)(run.compile.lastOrError.timeout(10.second)) + + type ClientSideChannel = FS2Channel[IO] + def setup(endpoints: Endpoint[IO]*) = setupAux(endpoints, None) + def setup(cancelTemplate: CancelTemplate, endpoints: Endpoint[IO]*) = setupAux(endpoints, Some(cancelTemplate)) + def setupAux(endpoints: Seq[Endpoint[IO]], cancelTemplate: Option[CancelTemplate]): Stream[IO, ClientSideChannel] = { + for { + serverSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + clientSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + _ <- serverSideChannel.withEndpointsStream(endpoints) + _ <- Stream(()) + .concurrently(clientSideChannel.output.through(serverSideChannel.input)) + .concurrently(serverSideChannel.output.through(clientSideChannel.input)) + } yield { + clientSideChannel + } + } + + testRes("Round trip") { + implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema + val endpoint: Endpoint[IO] = + Endpoint[IO]("greet").simple[GreetInput, GreetOutput](in => IO(GreetOutput(s"Hello ${in.name}"))) + + for { + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + result <- clientStub.greet("Bob").toStream + } yield { + expect.same(result.message, "Hello Bob") + } + } + + testRes("Sending notification") { + implicit val pingInputDecoder: Decoder[PingInput] = CirceJsonCodec.fromSchema + + for { + ref <- SignallingRef[IO, Option[PingInput]](none).toStream + endpoint: Endpoint[IO] = Endpoint[IO]("ping").notification[PingInput](p => ref.set(p.some)) + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + _ <- clientStub.ping("hello").toStream + result <- ref.discrete.dropWhile(_.isEmpty).take(1) + } yield { + expect.same(result, Some(PingInput("hello"))) + } + } + + testRes("Round trip with jsonPayload") { + implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema + val endpoint: Endpoint[IO] = + Endpoint[IO]("greetWithPayload").simple[GreetInput, GreetOutput](in => IO(GreetOutput(s"Hello ${in.name}"))) + + for { + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServerWithPayload, clientSideChannel) + result <- clientStub.greetWithPayload(GreetInputPayload("Bob")).toStream + } yield { + expect.same(result.payload.message, "Hello Bob") + } + } + + testRes("server returns known error") { + implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema + implicit val greetErrorEncoder: Encoder[GreetError] = CirceJsonCodec.fromSchema + implicit val errEncoder: ErrorEncoder[GreetError] = + err => ErrorPayload(-1, "error", Some(Payload(greetErrorEncoder(err)))) + + val endpoint: Endpoint[IO] = + Endpoint[IO]("greet").apply[GreetInput, GreetError, GreetOutput](in => + IO.pure(Left(GreetError.notWelcomeError(NotWelcomeError(s"${in.name} is not welcome")))) + ) + + for { + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + result <- clientStub.greet("Bob").attempt.toStream + } yield { + matches(result) { case Left(t: NotWelcomeError) => + expect.same(t.msg, s"Bob is not welcome") + } + } + } + + testRes("server returns unknown error") { + implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema + + val endpoint: Endpoint[IO] = + Endpoint[IO]("greet").simple[GreetInput, GreetOutput](_ => IO.raiseError(new RuntimeException("boom!"))) + + for { + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + result <- clientStub.greet("Bob").attempt.toStream + } yield { + matches(result) { case Left(t: ErrorPayload) => + expect.same(t.code, 0) && + expect.same(t.message, "boom!") && + expect.same(t.data, None) + } + } + } +} diff --git a/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala b/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala new file mode 100644 index 0000000..88ae345 --- /dev/null +++ b/modules/smithy4s-tests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala @@ -0,0 +1,253 @@ +package jsonrpclib.smithy4sinterop + +import cats.effect.IO +import cats.syntax.all._ +import fs2.concurrent.SignallingRef +import fs2.Stream +import io.circe.Decoder +import io.circe.Encoder +import jsonrpclib.fs2._ +import jsonrpclib.ErrorPayload +import jsonrpclib.Monadic +import jsonrpclib.Payload +import smithy4s.kinds.FunctorAlgebra +import smithy4s.Service +import test._ +import test.TestServerOperation._ +import weaver._ + +import scala.concurrent.duration._ + +object TestServerSpec extends SimpleIOSuite { + def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = + test(name)(run.compile.lastOrError.timeout(10.second)) + + type ClientSideChannel = FS2Channel[IO] + + class ServerImpl(client: TestClient[IO]) extends TestServer[IO] { + def greet(name: String): IO[GreetOutput] = IO.pure(GreetOutput(s"Hello $name")) + + def ping(ping: String): IO[Unit] = { + client.pong(s"Returned to sender: $ping") + } + } + + class Client(ref: SignallingRef[IO, Option[String]]) extends TestClient[IO] { + def pong(pong: String): IO[Unit] = ref.set(Some(pong)) + } + + trait AlgebraWrapper { + type Alg[_[_, _, _, _, _]] + + def algebra: FunctorAlgebra[Alg, IO] + def service: Service[Alg] + } + + object AlgebraWrapper { + def apply[A[_[_, _, _, _, _]]](alg: FunctorAlgebra[A, IO])(implicit srv: Service[A]): AlgebraWrapper = + new AlgebraWrapper { + type Alg[F[_, _, _, _, _]] = A[F] + + val algebra = alg + val service = srv + } + } + + def setup(mkServer: FS2Channel[IO] => AlgebraWrapper) = + setupAux(None, mkServer.andThen(Seq(_)), _ => Seq.empty) + + def setup(mkServer: FS2Channel[IO] => AlgebraWrapper, mkClient: FS2Channel[IO] => AlgebraWrapper) = + setupAux(None, mkServer.andThen(Seq(_)), mkClient.andThen(Seq(_))) + + def setup[Alg[_[_, _, _, _, _]]]( + cancelTemplate: CancelTemplate, + mkServer: FS2Channel[IO] => Seq[AlgebraWrapper], + mkClient: FS2Channel[IO] => Seq[AlgebraWrapper] + ) = setupAux(Some(cancelTemplate), mkServer, mkClient) + + def setupAux[Alg[_[_, _, _, _, _]]]( + cancelTemplate: Option[CancelTemplate], + mkServer: FS2Channel[IO] => Seq[AlgebraWrapper], + mkClient: FS2Channel[IO] => Seq[AlgebraWrapper] + ): Stream[IO, ClientSideChannel] = { + for { + serverSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + clientSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + serverChannelWithEndpoints <- serverSideChannel.withEndpointsStream(mkServer(serverSideChannel).flatMap { p => + ServerEndpoints(p.algebra)(p.service, Monadic[IO]) + }) + clientChannelWithEndpoints <- clientSideChannel.withEndpointsStream(mkClient(clientSideChannel).flatMap { p => + ServerEndpoints(p.algebra)(p.service, Monadic[IO]) + }) + _ <- Stream(()) + .concurrently(clientChannelWithEndpoints.output.through(serverChannelWithEndpoints.input)) + .concurrently(serverChannelWithEndpoints.output.through(clientChannelWithEndpoints.input)) + } yield { + clientChannelWithEndpoints + } + } + + testRes("Round trip") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(channel => { + val testClient = ClientStub(TestClient, channel) + AlgebraWrapper(new ServerImpl(testClient)) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Bob")).toStream + } yield { + expect.same(result.message, "Hello Bob") + } + } + + testRes("notification both ways") { + implicit val greetInputEncoder: Encoder[PingInput] = CirceJsonCodec.fromSchema + + for { + ref <- SignallingRef[IO, Option[String]](none).toStream + clientSideChannel <- setup( + channel => { + val testClient = ClientStub(TestClient, channel) + AlgebraWrapper(new ServerImpl(testClient)) + }, + _ => AlgebraWrapper(new Client(ref)) + ) + remoteFunction = clientSideChannel.notificationStub[PingInput]("ping") + _ <- remoteFunction(PingInput("hi server")).toStream + result <- ref.discrete.dropWhile(_.isEmpty).take(1) + } yield { + expect.same(result, "Returned to sender: hi server".some) + } + } + + testRes("internal error when processing notification should not break the server") { + implicit val greetInputEncoder: Encoder[PingInput] = CirceJsonCodec.fromSchema + + for { + ref <- SignallingRef[IO, Option[String]](none).toStream + clientSideChannel <- setup( + channel => { + val testClient = ClientStub(TestClient, channel) + AlgebraWrapper(new TestServer[IO] { + override def greet(name: String): IO[GreetOutput] = ??? + + override def ping(ping: String): IO[Unit] = { + if (ping == "fail") IO.raiseError(new RuntimeException("throwing internal error on demand")) + else testClient.pong("pong") + } + }) + }, + _ => AlgebraWrapper(new Client(ref)) + ) + remoteFunction = clientSideChannel.notificationStub[PingInput]("ping") + _ <- remoteFunction(PingInput("fail")).toStream + _ <- remoteFunction(PingInput("ping")).toStream + result <- ref.discrete.dropWhile(_.isEmpty).take(1) + } yield { + expect.same(result, "pong".some) + } + } + + testRes("server returns known error") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + implicit val greetErrorEncoder: Encoder[TestServerOperation.GreetError] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(_ => { + AlgebraWrapper(new TestServer[IO] { + override def greet(name: String): IO[GreetOutput] = IO.raiseError(NotWelcomeError(s"$name is not welcome")) + + override def ping(ping: String): IO[Unit] = ??? + }) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Alice")).attempt.toStream + } yield { + matches(result) { case Left(t: ErrorPayload) => + expect.same(t.code, 0) && + expect.same(t.message, "test.NotWelcomeError(Alice is not welcome)") && + expect.same( + t.data, + Payload(greetErrorEncoder.apply(GreetError.notWelcomeError(NotWelcomeError(s"Alice is not welcome")))).some + ) + } + } + } + + testRes("server returns unknown error") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(_ => { + AlgebraWrapper(new TestServer[IO] { + override def greet(name: String): IO[GreetOutput] = IO.raiseError(new RuntimeException("some other error")) + + override def ping(ping: String): IO[Unit] = ??? + }) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Alice")).attempt.toStream + } yield { + matches(result) { case Left(t: ErrorPayload) => + expect.same(t.code, 0) && + expect.same(t.message, "ServerInternalError: some other error") && + expect.same(t.data, none) + } + } + } + + testRes("accessing endpoints from multiple servers") { + class WeatherServiceImpl() extends WeatherService[IO] { + override def getWeather(city: String): IO[GetWeatherOutput] = IO(GetWeatherOutput("sunny")) + } + + for { + clientSideChannel <- setupAux( + None, + channel => { + val testClient = ClientStub(TestClient, channel) + Seq(AlgebraWrapper(new ServerImpl(testClient)), AlgebraWrapper(new WeatherServiceImpl())) + }, + _ => Seq.empty + ) + greetResult <- { + implicit val inputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val outputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + val remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + remoteFunction(GreetInput("Bob")).toStream + } + getWeatherResult <- { + implicit val inputEncoder: Encoder[GetWeatherInput] = CirceJsonCodec.fromSchema + implicit val outputDecoder: Decoder[GetWeatherOutput] = CirceJsonCodec.fromSchema + val remoteFunction = clientSideChannel.simpleStub[GetWeatherInput, GetWeatherOutput]("getWeather") + remoteFunction(GetWeatherInput("Warsaw")).toStream + } + } yield { + expect.same(greetResult.message, "Hello Bob") && + expect.same(getWeatherResult.weather, "sunny") + } + } + + testRes("Round trip with jsonPayload") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + + object ServerImpl extends TestServerWithPayload[IO] { + def greetWithPayload(payload: GreetInputPayload): IO[GreetWithPayloadOutput] = + IO.pure(GreetWithPayloadOutput(GreetOutputPayload(s"Hello ${payload.name}"))) + } + + for { + clientSideChannel <- setup(_ => AlgebraWrapper(ServerImpl)) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greetWithPayload") + result <- remoteFunction(GreetInput("Bob")).toStream + } yield { + expect.same(result.message, "Hello Bob") + } + } +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceDecoderImpl.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceDecoderImpl.scala new file mode 100644 index 0000000..ad6c55a --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceDecoderImpl.scala @@ -0,0 +1,47 @@ +package jsonrpclib.smithy4sinterop + +import io.circe.{Decoder => CirceDecoder, _} +import smithy4s.codecs.PayloadPath +import smithy4s.schema.CachedSchemaCompiler +import smithy4s.Document +import smithy4s.Document.{Encoder => _, _} +import smithy4s.Schema + +private[smithy4sinterop] class CirceDecoderImpl extends CachedSchemaCompiler[CirceDecoder] { + val decoder: CachedSchemaCompiler.DerivingImpl[Decoder] = Document.Decoder + + type Cache = decoder.Cache + def createCache(): Cache = decoder.createCache() + + def fromSchema[A](schema: Schema[A], cache: Cache): CirceDecoder[A] = + c => { + c.as[Json] + .map(fromJson(_)) + .flatMap { d => + decoder + .fromSchema(schema, cache) + .decode(d) + .left + .map(e => + DecodingFailure(DecodingFailure.Reason.CustomReason(e.getMessage), c.history ++ toCursorOps(e.path)) + ) + } + } + + def fromSchema[A](schema: Schema[A]): CirceDecoder[A] = fromSchema(schema, createCache()) + + private def toCursorOps(path: PayloadPath): List[CursorOp] = + path.segments.map { + case PayloadPath.Segment.Label(name) => CursorOp.DownField(name) + case PayloadPath.Segment.Index(i) => CursorOp.DownN(i) + } + + private def fromJson(json: Json): Document = json.fold( + jsonNull = DNull, + jsonBoolean = DBoolean(_), + jsonNumber = n => DNumber(n.toBigDecimal.get), + jsonString = DString(_), + jsonArray = arr => DArray(arr.map(fromJson)), + jsonObject = obj => DObject(obj.toMap.view.mapValues(fromJson).toMap) + ) +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceEncoderImpl.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceEncoderImpl.scala new file mode 100644 index 0000000..b03cf64 --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceEncoderImpl.scala @@ -0,0 +1,28 @@ +package jsonrpclib.smithy4sinterop + +import io.circe.{Encoder => CirceEncoder, _} +import smithy4s.schema.CachedSchemaCompiler +import smithy4s.Document +import smithy4s.Document._ +import smithy4s.Schema + +private[smithy4sinterop] class CirceEncoderImpl extends CachedSchemaCompiler[CirceEncoder] { + val encoder: CachedSchemaCompiler.DerivingImpl[Encoder] = Document.Encoder + + type Cache = encoder.Cache + def createCache(): Cache = encoder.createCache() + + def fromSchema[A](schema: Schema[A], cache: Cache): CirceEncoder[A] = + a => documentToJson(encoder.fromSchema(schema, cache).encode(a)) + + def fromSchema[A](schema: Schema[A]): CirceEncoder[A] = fromSchema(schema, createCache()) + + private val documentToJson: Document => Json = { + case DNull => Json.Null + case DString(value) => Json.fromString(value) + case DBoolean(value) => Json.fromBoolean(value) + case DNumber(value) => Json.fromBigDecimal(value) + case DArray(values) => Json.fromValues(values.map(documentToJson)) + case DObject(entries) => Json.fromFields(entries.view.mapValues(documentToJson)) + } +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala new file mode 100644 index 0000000..8e37725 --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala @@ -0,0 +1,33 @@ +package jsonrpclib.smithy4sinterop + +import io.circe._ +import smithy4s.schema.CachedSchemaCompiler +import smithy4s.Schema + +object CirceJsonCodec { + + object Encoder extends CirceEncoderImpl + object Decoder extends CirceDecoderImpl + + object Codec extends CachedSchemaCompiler[Codec] { + type Cache = (Encoder.Cache, Decoder.Cache) + def createCache(): Cache = (Encoder.createCache(), Decoder.createCache()) + + def fromSchema[A](schema: Schema[A]): Codec[A] = + io.circe.Codec.from(Decoder.fromSchema(schema), Encoder.fromSchema(schema)) + + def fromSchema[A](schema: Schema[A], cache: Cache): Codec[A] = + io.circe.Codec.from( + Decoder.fromSchema(schema, cache._2), + Encoder.fromSchema(schema, cache._1) + ) + } + + /** Creates a Circe `Codec[A]` from a Smithy4s `Schema[A]`. + * + * This enables encoding values of type `A` to JSON and decoding JSON back into `A`, using the structure defined by + * the Smithy schema. + */ + def fromSchema[A](implicit schema: Schema[A]): Codec[A] = + Codec.fromSchema(schema) +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala new file mode 100644 index 0000000..dbc8fde --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala @@ -0,0 +1,104 @@ +package jsonrpclib.smithy4sinterop + +import io.circe.Codec +import io.circe.HCursor +import jsonrpclib.Channel +import jsonrpclib.ErrorPayload +import jsonrpclib.Monadic +import jsonrpclib.Monadic.syntax._ +import jsonrpclib.ProtocolError +import smithy4s.~> +import smithy4s.schema._ +import smithy4s.Service +import smithy4s.ShapeId + +object ClientStub { + + /** Creates a JSON-RPC client implementation for a Smithy service. + * + * Given a Smithy `Service[Alg]` and a JSON-RPC communication `Channel[F]`, this constructs a fully functional client + * that translates method calls into JSON-RPC messages sent over the channel. + * + * Usage: + * {{{ + * val stub: MyService[IO] = ClientStub(myService, myChannel) + * val response: IO[String] = stub.hello("world") + * }}} + * + * Supports both standard request-response and fire-and-forget notification endpoints. + */ + def apply[Alg[_[_, _, _, _, _]], F[_]: Monadic](service: Service[Alg], channel: Channel[F]): service.Impl[F] = + new ClientStub(JsonRpcTransformations.apply(service), channel).compile +} + +private class ClientStub[Alg[_[_, _, _, _, _]], F[_]: Monadic](val service: Service[Alg], channel: Channel[F]) { + + def compile: service.Impl[F] = { + val codecCache = CirceJsonCodec.Codec.createCache() + val interpreter = new service.FunctorEndpointCompiler[F] { + def apply[I, E, O, SI, SO](e: service.Endpoint[I, E, O, SI, SO]): I => F[O] = { + val shapeId = e.id + val spec = EndpointSpec.fromHints(e.hints).toRight(NotJsonRPCEndpoint(shapeId)).toTry.get + + jsonRPCStub(e, spec, codecCache) + } + } + + service.impl(interpreter) + } + + def jsonRPCStub[I, E, O, SI, SO]( + smithy4sEndpoint: service.Endpoint[I, E, O, SI, SO], + endpointSpec: EndpointSpec, + codecCache: CirceJsonCodec.Codec.Cache + ): I => F[O] = { + + implicit val inputCodec: Codec[I] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.input, codecCache) + implicit val outputCodec: Codec[O] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.output, codecCache) + + def errorResponse(throwable: Throwable, errorCodec: Codec[E]): F[E] = { + throwable match { + case ErrorPayload(_, _, Some(payload)) => + errorCodec.decodeJson(payload.data) match { + case Left(err) => ProtocolError.ParseError(err.getMessage).raiseError + case Right(error) => error.pure + } + case e: Throwable => e.raiseError + } + } + + endpointSpec match { + case EndpointSpec.Notification(methodName) => + val coerce = coerceUnit[O](smithy4sEndpoint.output) + channel.notificationStub[I](methodName).andThen(f => Monadic[F].doFlatMap(f)(_ => coerce)) + case EndpointSpec.Request(methodName) => + smithy4sEndpoint.error match { + case None => channel.simpleStub[I, O](methodName) + case Some(errorSchema) => + val errorCodec = CirceJsonCodec.Codec.fromSchema(errorSchema.schema, codecCache) + val stub = channel.simpleStub[I, O](methodName) + (in: I) => + stub.apply(in).attempt.flatMap { + case Right(success) => success.pure + case Left(error) => + errorResponse(error, errorCodec) + .flatMap(e => errorSchema.unliftError(e).raiseError) + } + } + } + } + + case class NotJsonRPCEndpoint(shapeId: ShapeId) extends Throwable + case object NotUnitReturnType extends Throwable + + private object CoerceUnitVisitor extends (Schema ~> F) { + def apply[A](schema: Schema[A]): F[A] = schema match { + case s @ Schema.StructSchema(_, _, _, make) if s.isUnit => + Monadic[F].doPure(make(IndexedSeq.empty)) + case _ => Monadic[F].doRaiseError[A](NotUnitReturnType) + } + } + + private def coerceUnit[A](schema: Schema[A]): F[A] = CoerceUnitVisitor(schema) + +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/EndpointSpec.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/EndpointSpec.scala new file mode 100644 index 0000000..2c91f14 --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/EndpointSpec.scala @@ -0,0 +1,15 @@ +package jsonrpclib.smithy4sinterop + +import smithy4s.Hints + +private[smithy4sinterop] sealed trait EndpointSpec +private[smithy4sinterop] object EndpointSpec { + case class Notification(methodName: String) extends EndpointSpec + case class Request(methodName: String) extends EndpointSpec + + def fromHints(hints: Hints): Option[EndpointSpec] = hints match { + case jsonrpclib.JsonRpcRequest.hint(r) => Some(Request(r.value)) + case jsonrpclib.JsonRpcNotification.hint(r) => Some(Notification(r.value)) + case _ => None + } +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonPayloadTransformation.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonPayloadTransformation.scala new file mode 100644 index 0000000..a37092f --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonPayloadTransformation.scala @@ -0,0 +1,21 @@ +package jsonrpclib.smithy4sinterop + +import jsonrpclib.JsonRpcPayload +import smithy4s.~> +import smithy4s.Schema +import smithy4s.Schema.StructSchema + +private[jsonrpclib] object JsonPayloadTransformation extends (Schema ~> Schema) { + + def apply[A0](fa: Schema[A0]): Schema[A0] = + fa match { + case struct: StructSchema[b] => + struct.fields + .collectFirst { + case field if field.hints.has[JsonRpcPayload] => + field.schema.biject[b]((f: Any) => struct.make(Vector(f)))(field.get) + } + .getOrElse(fa) + case _ => fa + } +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonRpcTransformations.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonRpcTransformations.scala new file mode 100644 index 0000000..4ff844d --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/JsonRpcTransformations.scala @@ -0,0 +1,33 @@ +package jsonrpclib.smithy4sinterop + +import smithy4s.~> +import smithy4s.schema.ErrorSchema +import smithy4s.schema.OperationSchema +import smithy4s.Endpoint +import smithy4s.Schema +import smithy4s.Service + +private[jsonrpclib] object JsonRpcTransformations { + + def apply[Alg[_[_, _, _, _, _]]]: Service[Alg] => Service[Alg] = + _.toBuilder + .mapEndpointEach( + Endpoint.mapSchema( + OperationSchema + .mapInputK(JsonPayloadTransformation) + .andThen(OperationSchema.mapOutputK(JsonPayloadTransformation)) + .andThen(OperationSchema.mapErrorK(errorTransformation)) + ) + ) + .build + + private val payloadTransformation: Schema ~> Schema = Schema + .transformTransitivelyK(JsonPayloadTransformation) + + private val errorTransformation: ErrorSchema ~> ErrorSchema = + new smithy4s.kinds.PolyFunction[ErrorSchema, ErrorSchema] { + def apply[A](e: ErrorSchema[A]): ErrorSchema[A] = { + payloadTransformation(e.schema).error(e.unliftError)(e.liftError.unlift) + } + } +} diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala new file mode 100644 index 0000000..92843f4 --- /dev/null +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala @@ -0,0 +1,113 @@ +package jsonrpclib.smithy4sinterop + +import io.circe.Codec +import jsonrpclib.Endpoint +import jsonrpclib.ErrorEncoder +import jsonrpclib.ErrorPayload +import jsonrpclib.Monadic +import jsonrpclib.Monadic.syntax._ +import jsonrpclib.Payload +import smithy4s.kinds.FunctorAlgebra +import smithy4s.kinds.FunctorInterpreter +import smithy4s.schema.ErrorSchema +import smithy4s.Service + +import _root_.smithy4s.{Endpoint => Smithy4sEndpoint} + +object ServerEndpoints { + + /** Creates JSON-RPC server endpoints from a Smithy service implementation. + * + * Given a Smithy `FunctorAlgebra[Alg, F]`, this extracts all operations and compiles them into JSON-RPC + * `Endpoint[F]` handlers that can be mounted on a communication channel (e.g. `FS2Channel`). + * + * Supports both standard request-response and notification-style endpoints, as well as Smithy-modeled errors. + * + * Usage: + * {{{ + * val endpoints = ServerEndpoints(new ServerImpl) + * channel.withEndpoints(endpoints) + * }}} + */ + def apply[Alg[_[_, _, _, _, _]], F[_]]( + impl: FunctorAlgebra[Alg, F] + )(implicit service: Service[Alg], F: Monadic[F]): List[Endpoint[F]] = { + val transformedService = JsonRpcTransformations.apply(service) + val interpreter: transformedService.FunctorInterpreter[F] = transformedService.toPolyFunction(impl) + val codecCache = CirceJsonCodec.Codec.createCache() + transformedService.endpoints.toList.flatMap { smithy4sEndpoint => + EndpointSpec + .fromHints(smithy4sEndpoint.hints) + .map { endpointSpec => + jsonRPCEndpoint(smithy4sEndpoint, endpointSpec, interpreter, codecCache) + } + .toList + } + } + + /** Constructs a JSON-RPC endpoint from a Smithy endpoint definition. + * + * Translates a single Smithy4s endpoint into a JSON-RPC `Endpoint[F]`, based on the method name and interaction type + * described in `EndpointSpec`. + * + * @param smithy4sEndpoint + * The Smithy4s endpoint to expose over JSON-RPC + * @param endpointSpec + * JSON-RPC method name and interaction hints + * @param impl + * Interpreter that executes the Smithy operation in `F` + * @param codecCache + * Coche for the schema to codec compilation results + * @return + * A JSON-RPC-compatible `Endpoint[F]` + */ + private def jsonRPCEndpoint[F[_]: Monadic, Op[_, _, _, _, _], I, E, O, SI, SO]( + smithy4sEndpoint: Smithy4sEndpoint[Op, I, E, O, SI, SO], + endpointSpec: EndpointSpec, + impl: FunctorInterpreter[Op, F], + codecCache: CirceJsonCodec.Codec.Cache + ): Endpoint[F] = { + implicit val inputCodec: Codec[I] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.input, codecCache) + implicit val outputCodec: Codec[O] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.output, codecCache) + + def errorResponse(throwable: Throwable): F[E] = throwable match { + case smithy4sEndpoint.Error((_, e)) => e.pure + case e: Throwable => e.raiseError + } + + endpointSpec match { + case EndpointSpec.Notification(methodName) => + Endpoint[F](methodName).notification { (input: I) => + val op = smithy4sEndpoint.wrap(input) + impl(op).void + } + case EndpointSpec.Request(methodName) => + smithy4sEndpoint.error match { + case None => + Endpoint[F](methodName).simple[I, O] { (input: I) => + val op = smithy4sEndpoint.wrap(input) + impl(op) + } + case Some(errorSchema) => + implicit val errorCodec: ErrorEncoder[E] = errorCodecFromSchema(errorSchema, codecCache) + Endpoint[F](methodName).apply[I, E, O] { (input: I) => + val op = smithy4sEndpoint.wrap(input) + impl(op).attempt.flatMap { + case Left(err) => errorResponse(err).map(r => Left(r): Either[E, O]) + case Right(success) => (Right(success): Either[E, O]).pure + } + } + } + } + } + + private def errorCodecFromSchema[A](s: ErrorSchema[A], cache: CirceJsonCodec.Codec.Cache): ErrorEncoder[A] = { + val circeCodec = CirceJsonCodec.Codec.fromSchema(s.schema, cache) + (a: A) => + ErrorPayload( + 0, + Option(s.unliftError(a).getMessage()).getOrElse("JSONRPC-smithy4s application error"), + Some(Payload(circeCodec.apply(a))) + ) + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala new file mode 100644 index 0000000..47311b2 --- /dev/null +++ b/project/Dependencies.scala @@ -0,0 +1,8 @@ +import sbt.* + +object Dependencies { + val alloy = new { + val version = "0.3.20" + val core = "com.disneystreaming.alloy" % "alloy-core" % version + } +} diff --git a/project/PathRef.scala b/project/PathRef.scala new file mode 100644 index 0000000..45539c5 --- /dev/null +++ b/project/PathRef.scala @@ -0,0 +1,32 @@ +import sbt.io.Hash +import sbt.util.FileInfo +import sbt.util.HashFileInfo +import sjsonnew.* + +import java.io.File + +case class PathRef(path: os.Path) + +object PathRef { + + def apply(f: File): PathRef = PathRef(os.Path(f)) + + implicit val pathFormat: JsonFormat[PathRef] = + BasicJsonProtocol.projectFormat[PathRef, HashFileInfo]( + p => + if (os.isFile(p.path)) FileInfo.hash(p.path.toIO) + else + // If the path is a directory, we get the hashes of all files + // then hash the concatenation of the hash's bytes. + FileInfo.hash( + p.path.toIO, + Hash( + os.walk(p.path) + .map(_.toIO) + .map(Hash(_)) + .foldLeft(Array.emptyByteArray)(_ ++ _) + ) + ), + hash => PathRef(hash.file) + ) +} diff --git a/project/SmithyTraitCodegen.scala b/project/SmithyTraitCodegen.scala new file mode 100644 index 0000000..551d51d --- /dev/null +++ b/project/SmithyTraitCodegen.scala @@ -0,0 +1,127 @@ +import sbt.* +import sbt.io.IO +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.model.node.ArrayNode +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.Model +import software.amazon.smithy.traitcodegen.TraitCodegenPlugin + +import java.io.File +import java.nio.file.Paths +import java.util.UUID + +object SmithyTraitCodegen { + + import sjsonnew.* + + import BasicJsonProtocol.* + + case class Args( + javaPackage: String, + smithyNamespace: String, + targetDir: os.Path, + smithySourcesDir: PathRef, + dependencies: List[PathRef] + ) + object Args { + + // format: off + private type ArgsDeconstructed = String :*: String :*: os.Path :*: PathRef :*: List[PathRef] :*: LNil + // format: on + + private implicit val pathFormat: JsonFormat[os.Path] = + BasicJsonProtocol.projectFormat[os.Path, File](p => p.toIO, file => os.Path(file)) + + implicit val argsIso = + LList.iso[Args, ArgsDeconstructed]( + { args: Args => + ("javaPackage", args.javaPackage) :*: + ("smithyNamespace", args.smithyNamespace) :*: + ("targetDir", args.targetDir) :*: + ("smithySourcesDir", args.smithySourcesDir) :*: + ("dependencies", args.dependencies) :*: + LNil + }, + { + case (_, javaPackage) :*: + (_, smithyNamespace) :*: + (_, targetDir) :*: + (_, smithySourcesDir) :*: + (_, dependencies) :*: + LNil => + Args( + javaPackage = javaPackage, + smithyNamespace = smithyNamespace, + targetDir = targetDir, + smithySourcesDir = smithySourcesDir, + dependencies = dependencies + ) + } + ) + + } + + case class Output(metaDir: File, javaDir: File) + + object Output { + + // format: off + private type OutputDeconstructed = File :*: File :*: LNil + // format: on + + implicit val outputIso = + LList.iso[Output, OutputDeconstructed]( + { output: Output => + ("metaDir", output.metaDir) :*: + ("javaDir", output.javaDir) :*: + LNil + }, + { + case (_, metaDir) :*: + (_, javaDir) :*: + LNil => + Output( + metaDir = metaDir, + javaDir = javaDir + ) + } + ) + } + + def generate(args: Args): Output = { + val outputDir = args.targetDir / "smithy-trait-generator-output" + val genDir = outputDir / "java" + val metaDir = outputDir / "meta" + os.remove.all(outputDir) + List(outputDir, genDir, metaDir).foreach(os.makeDir.all(_)) + + val manifest = FileManifest.create(genDir.toNIO) + + val model = args.dependencies + .foldLeft(Model.assembler().addImport(args.smithySourcesDir.path.toNIO)) { case (acc, dep) => + acc.addImport(dep.path.toNIO) + } + .assemble() + .unwrap() + val context = PluginContext + .builder() + .model(model) + .fileManifest(manifest) + .settings( + ObjectNode + .builder() + .withMember("package", args.javaPackage) + .withMember("namespace", args.smithyNamespace) + .withMember("header", ArrayNode.builder.build()) + .withMember("excludeTags", ArrayNode.builder.withValue("nocodegen").build()) + .build() + ) + .build() + val plugin = new TraitCodegenPlugin() + plugin.execute(context) + os.move(genDir / "META-INF", metaDir / "META-INF") + Output(metaDir = metaDir.toIO, javaDir = genDir.toIO) + } +} diff --git a/project/SmithyTraitCodegenPlugin.scala b/project/SmithyTraitCodegenPlugin.scala new file mode 100644 index 0000000..78903f4 --- /dev/null +++ b/project/SmithyTraitCodegenPlugin.scala @@ -0,0 +1,95 @@ +import sbt.* +import sbt.plugins.JvmPlugin +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.model.node.ArrayNode +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.Model +import software.amazon.smithy.traitcodegen.TraitCodegenPlugin + +import Keys.* + +object SmithyTraitCodegenPlugin extends AutoPlugin { + override def trigger: PluginTrigger = noTrigger + override def requires: Plugins = JvmPlugin + + object autoImport { + val smithyTraitCodegenJavaPackage = + settingKey[String]("The java target package where the generated smithy traits will be created") + val smithyTraitCodegenNamespace = settingKey[String]("The smithy namespace where the traits are defined") + val smithyTraitCodegenDependencies = settingKey[List[ModuleID]]("Dependencies to be added into codegen model") + } + import autoImport.* + + override def projectSettings: Seq[Setting[?]] = + Seq( + Keys.generateSmithyTraits := Def.task { + import sbt.util.CacheImplicits.* + val s = (Compile / streams).value + val logger = sLog.value + + val report = update.value + val dependencies = smithyTraitCodegenDependencies.value + val jars = + dependencies.flatMap(m => + report.matching(moduleFilter(organization = m.organization, name = m.name, revision = m.revision)) + ) + require( + jars.size == dependencies.size, + "Not all dependencies required for smithy-trait-codegen have been found" + ) + + val args = SmithyTraitCodegen.Args( + javaPackage = smithyTraitCodegenJavaPackage.value, + smithyNamespace = smithyTraitCodegenNamespace.value, + targetDir = os.Path((Compile / target).value), + smithySourcesDir = PathRef((Compile / resourceDirectory).value / "META-INF" / "smithy"), + dependencies = jars.map(PathRef(_)).toList + ) + val cachedCodegen = + Tracked.inputChanged[SmithyTraitCodegen.Args, SmithyTraitCodegen.Output]( + s.cacheStoreFactory.make("smithy-trait-codegen-args") + ) { + Function.untupled( + Tracked + .lastOutput[(Boolean, SmithyTraitCodegen.Args), SmithyTraitCodegen.Output]( + s.cacheStoreFactory.make("smithy-trait-codegen-output") + ) { case ((inputChanged, codegenArgs), cached) => + cached + .filter(_ => !inputChanged) + .fold { + SmithyTraitCodegen.generate(codegenArgs) + } { last => + logger.info(s"Using cached result of smithy-trait-codegen") + last + } + } + ) + } + cachedCodegen(args) + }.value, + Compile / sourceGenerators += Def.task { + val codegenOutput = (Compile / Keys.generateSmithyTraits).value + cleanCopy(source = codegenOutput.javaDir, target = (Compile / sourceManaged).value / "java") + }, + Compile / resourceGenerators += Def.task { + val codegenOutput = (Compile / Keys.generateSmithyTraits).value + cleanCopy(source = codegenOutput.metaDir, target = (Compile / resourceManaged).value) + }.taskValue, + libraryDependencies ++= smithyTraitCodegenDependencies.value + ) + + private def cleanCopy(source: File, target: File) = { + val sourcePath = os.Path(source) + val targetPath = os.Path(target) + os.remove.all(targetPath) + os.copy(from = sourcePath, to = targetPath, createFolders = true) + os.walk(targetPath).map(_.toIO).filter(_.isFile()) + } + + object Keys { + val generateSmithyTraits = + taskKey[SmithyTraitCodegen.Output]("Run AWS smithy-trait-codegen on the protocol specs") + } + +} diff --git a/project/build.sbt b/project/build.sbt new file mode 100644 index 0000000..969768d --- /dev/null +++ b/project/build.sbt @@ -0,0 +1,4 @@ +libraryDependencies ++= Seq( + "software.amazon.smithy" % "smithy-trait-codegen", + "software.amazon.smithy" % "smithy-model" +).map(_ % "1.58.0") diff --git a/project/plugins.sbt b/project/plugins.sbt index 6e9bb88..1a67d0e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -14,4 +14,6 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4") +addSbtPlugin("com.disneystreaming.smithy4s" % "smithy4s-sbt-codegen" % "0.18.37") + addDependencyTreePlugin