From 47cd16e048e19ebc0b8673bccffdade678cc0363 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 27 Nov 2023 18:00:38 -0800 Subject: [PATCH] feat: add error handler to rpc interface (#965) * adding handle error functionality * improvements and additional testcases * tests wrapped up * refactors * add documentation to readme * remove jest config for debugging in vscode * fix: update with PR feedback * fix: test case had object properties out of order --- README.markdown | 18 +- .../before-after-request-test.ts | 12 +- .../before-after-request/parameters.txt | 2 +- integration/before-after-request/simple.ts | 24 ++- .../handle-error-test.ts | 65 +++++++ .../parameters.txt | 1 + .../simple.bin | Bin 0 -> 482 bytes .../simple.proto | 14 ++ .../handle-error-in-default-service/simple.ts | 178 +++++++++++++++++ .../handle-error-test.ts | 89 +++++++++ .../parameters.txt | 1 + .../simple.bin | Bin 0 -> 482 bytes .../simple.proto | 14 ++ .../simple.ts | 183 ++++++++++++++++++ package.json | 2 +- src/generate-services.ts | 83 ++++++-- src/options.ts | 16 +- src/utils.ts | 8 + tests/options-test.ts | 27 ++- 19 files changed, 688 insertions(+), 49 deletions(-) create mode 100644 integration/handle-error-in-default-service/handle-error-test.ts create mode 100644 integration/handle-error-in-default-service/parameters.txt create mode 100644 integration/handle-error-in-default-service/simple.bin create mode 100755 integration/handle-error-in-default-service/simple.proto create mode 100644 integration/handle-error-in-default-service/simple.ts create mode 100644 integration/handle-error-with-after-response/handle-error-test.ts create mode 100644 integration/handle-error-with-after-response/parameters.txt create mode 100644 integration/handle-error-with-after-response/simple.bin create mode 100755 integration/handle-error-with-after-response/simple.proto create mode 100644 integration/handle-error-with-after-response/simple.ts diff --git a/README.markdown b/README.markdown index e8a1bdbe8..6e33b8f0b 100644 --- a/README.markdown +++ b/README.markdown @@ -19,10 +19,10 @@ - [Highlights](#highlights) - [Auto-Batching / N+1 Prevention](#auto-batching--n1-prevention) - [Usage](#usage) - - [Supported options](#supported-options) - - [NestJS Support](#nestjs-support) - - [Watch Mode](#watch-mode) - - [Basic gRPC implementation](#basic-grpc-implementation) + - [Supported options](#supported-options) + - [NestJS Support](#nestjs-support) + - [Watch Mode](#watch-mode) + - [Basic gRPC implementation](#basic-grpc-implementation) - [Sponsors](#sponsors) - [Development](#development) - [Assumptions](#assumptions) @@ -121,8 +121,8 @@ plugins: If you're using a modern TS setup with either `esModuleInterop` or running in an ESM environment, you'll need to pass `ts_proto_opt`s of: -* `esModuleInterop=true` if using `esModuleInterop` in your `tsconfig.json`, and -* `importSuffix=.js` if executing the generated ts-proto code in an ESM environment +- `esModuleInterop=true` if using `esModuleInterop` in your `tsconfig.json`, and +- `importSuffix=.js` if executing the generated ts-proto code in an ESM environment # Goals @@ -448,9 +448,11 @@ Generated code will be placed in the Gradle build directory. - With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions. -- With `--ts_proto_opt=outputBeforeRequest=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `beforeRequest(request: )`. It will will also automatically set `outputTypeRegistry=true` and `outputServices=true`. Each of the Service's methods will call `beforeRequest` before performing it's request. +- With `--ts_proto_opt=rpcBeforeRequest=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `beforeRequest(service: string, message: string, request: )`. It will will also automatically set `outputServices=default`. Each of the Service's methods will call `beforeRequest` before performing it's request. -- With `--ts_proto_opt=outputAfterResponse=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `afterResponse(response: )`. It will will also automatically set `outputTypeRegistry=true` and `outputServices=true`. Each of the Service's methods will call `afterResponse` before returning the response. +- With `--ts_proto_opt=rpcAfterResponse=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `afterResponse(service: string, message: string, response: )`. It will will also automatically set `outputServices=default`. Each of the Service's methods will call `afterResponse` before returning the response. + +- With `--ts_proto_opt=rpcErrorHandler=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `handleError(service: string, message: string, error: Error)`. It will will also automatically set `outputServices=default`. - With `--ts_proto_opt=useAbortSignal=true`, the generated services will accept an `AbortSignal` to cancel RPC calls. diff --git a/integration/before-after-request/before-after-request-test.ts b/integration/before-after-request/before-after-request-test.ts index 9d2be2dfb..b9ea417c1 100644 --- a/integration/before-after-request/before-after-request-test.ts +++ b/integration/before-after-request/before-after-request-test.ts @@ -1,5 +1,9 @@ -import { FooServiceClientImpl, FooServiceCreateRequest, FooServiceCreateResponse } from "./simple"; -import { MessageType } from "./typeRegistry"; +import { + FooServiceClientImpl, + FooServiceCreateRequest, + FooServiceCreateResponse, + FooServiceServiceName, +} from "./simple"; interface Rpc { request(service: string, method: string, data: Uint8Array): Promise; @@ -27,14 +31,14 @@ describe("before-after-request", () => { const req = FooServiceCreateRequest.create(exampleData); client = new FooServiceClientImpl({ ...rpc, beforeRequest: beforeRequest }); await client.Create(req); - expect(beforeRequest).toHaveBeenCalledWith(req); + expect(beforeRequest).toHaveBeenCalledWith(FooServiceServiceName, "Create", req); }); it("performs function after request if specified", async () => { const req = FooServiceCreateRequest.create(exampleData); client = new FooServiceClientImpl({ ...rpc, afterResponse: afterResponse }); await client.Create(req); - expect(afterResponse).toHaveBeenCalledWith(exampleData); + expect(afterResponse).toHaveBeenCalledWith(FooServiceServiceName, "Create", exampleData); }); it("doesn't perform function before or after request if they are not specified", async () => { diff --git a/integration/before-after-request/parameters.txt b/integration/before-after-request/parameters.txt index 9335148d4..736b689d7 100644 --- a/integration/before-after-request/parameters.txt +++ b/integration/before-after-request/parameters.txt @@ -1 +1 @@ -outputBeforeRequest=true,outputAfterResponse=true +rpcBeforeRequest=true,rpcAfterResponse=true,outputServices=default,outputServices=generic-definitions, diff --git a/integration/before-after-request/simple.ts b/integration/before-after-request/simple.ts index c3d318639..1da2df759 100644 --- a/integration/before-after-request/simple.ts +++ b/integration/before-after-request/simple.ts @@ -432,23 +432,39 @@ export class FooServiceClientImpl implements FooService { Create(request: FooServiceCreateRequest): Promise { const data = FooServiceCreateRequest.encode(request).finish(); if (this.rpc.beforeRequest) { - this.rpc.beforeRequest(request); + this.rpc.beforeRequest(this.service, "Create", request); } const promise = this.rpc.request(this.service, "Create", data); return promise.then((data) => { const response = FooServiceCreateResponse.decode(_m0.Reader.create(data)); if (this.rpc.afterResponse) { - this.rpc.afterResponse(response); + this.rpc.afterResponse(this.service, "Create", response); } return response; }); } } +export type FooServiceDefinition = typeof FooServiceDefinition; +export const FooServiceDefinition = { + name: "FooService", + fullName: "simple.FooService", + methods: { + create: { + name: "Create", + requestType: FooServiceCreateRequest, + requestStream: false, + responseType: FooServiceCreateResponse, + responseStream: false, + options: {}, + }, + }, +} as const; + interface Rpc { request(service: string, method: string, data: Uint8Array): Promise; - beforeRequest?(request: T): void; - afterResponse?(response: T): void; + beforeRequest?(service: string, method: string, request: T): void; + afterResponse?(service: string, method: string, response: T): void; } type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; diff --git a/integration/handle-error-in-default-service/handle-error-test.ts b/integration/handle-error-in-default-service/handle-error-test.ts new file mode 100644 index 000000000..fcec1e510 --- /dev/null +++ b/integration/handle-error-in-default-service/handle-error-test.ts @@ -0,0 +1,65 @@ +import { GetBasicResponse, GetBasicRequest, BasicServiceClientImpl, BasicServiceServiceName } from "./simple"; + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + handleError?(service: string, method: string, error: Error): Error; +} + +describe("before-after-request", () => { + const exampleData = { + name: "test-name", + }; + let rpc = { + request: jest.fn(() => Promise.resolve(new Uint8Array())), + }; + let client = new BasicServiceClientImpl(rpc); + let err = new Error("error"); + + let modifiedError = new Error("modified error"); + const handleError = jest.fn(() => modifiedError); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("doesn't perform handleError if error occurs during encode step", async () => { + const encodeSpy = jest.spyOn(GetBasicRequest, "encode").mockImplementation(() => { + throw err; + }); + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl({ ...rpc, handleError: handleError }); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(err); + expect(handleError).not.toHaveBeenCalled(); + } + encodeSpy.mockRestore(); + }); + + it("performs handleError if error occurs when decoding", async () => { + const decodeSpy = jest.spyOn(GetBasicResponse, "decode").mockImplementation(() => { + throw err; + }); + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl({ ...rpc, handleError: handleError }); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(modifiedError); + expect(handleError).toHaveBeenCalledWith(BasicServiceServiceName, "GetBasic", err); + } + decodeSpy.mockRestore(); + }); + + it("doesn't perform handleError if it is not specified", async () => { + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl(rpc); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(err); + expect(handleError).not.toHaveBeenCalled(); + } + }); +}); diff --git a/integration/handle-error-in-default-service/parameters.txt b/integration/handle-error-in-default-service/parameters.txt new file mode 100644 index 000000000..fba9f2ee7 --- /dev/null +++ b/integration/handle-error-in-default-service/parameters.txt @@ -0,0 +1 @@ +outputServices=default,rpcErrorHandler=true diff --git a/integration/handle-error-in-default-service/simple.bin b/integration/handle-error-in-default-service/simple.bin new file mode 100644 index 0000000000000000000000000000000000000000..6a80e7ca5b5395a5161b3259e832e672698a8bdf GIT binary patch literal 482 zcmZvZu};G<5QfjO<2o0&gags23n&r@b;tl45)3RzNG#@d)nK3du%0&SmUlM;^QM_8V)w(#xg9o?2_w&NNzPWZ! zglfCoxLllx%VZJe;sUduSliUw*8P{edO{mq&u;f#tsUK>8NQ>FGOR3zgMQy^`sch` zP{8A7gqjqrq?mB&h$Da4w_!a#5&XIG`CBAdv2$FDmBMmr$aj% c-T4L{_{!?ZK>R$Loeso0JHJ`!@Jp}10s0L{^#A|> literal 0 HcmV?d00001 diff --git a/integration/handle-error-in-default-service/simple.proto b/integration/handle-error-in-default-service/simple.proto new file mode 100755 index 000000000..db5c7a53d --- /dev/null +++ b/integration/handle-error-in-default-service/simple.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; +package basic; + +message GetBasicRequest { + string name = 1; +} + +message GetBasicResponse { + string name = 1; +} + +service BasicService { + rpc GetBasic (GetBasicRequest) returns (GetBasicResponse) {} +} diff --git a/integration/handle-error-in-default-service/simple.ts b/integration/handle-error-in-default-service/simple.ts new file mode 100644 index 000000000..7dda0512b --- /dev/null +++ b/integration/handle-error-in-default-service/simple.ts @@ -0,0 +1,178 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export const protobufPackage = "basic"; + +export interface GetBasicRequest { + name: string; +} + +export interface GetBasicResponse { + name: string; +} + +function createBaseGetBasicRequest(): GetBasicRequest { + return { name: "" }; +} + +export const GetBasicRequest = { + encode(message: GetBasicRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetBasicRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetBasicRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetBasicRequest { + return { name: isSet(object.name) ? globalThis.String(object.name) : "" }; + }, + + toJSON(message: GetBasicRequest): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + return obj; + }, + + create, I>>(base?: I): GetBasicRequest { + return GetBasicRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): GetBasicRequest { + const message = createBaseGetBasicRequest(); + message.name = object.name ?? ""; + return message; + }, +}; + +function createBaseGetBasicResponse(): GetBasicResponse { + return { name: "" }; +} + +export const GetBasicResponse = { + encode(message: GetBasicResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetBasicResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetBasicResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetBasicResponse { + return { name: isSet(object.name) ? globalThis.String(object.name) : "" }; + }, + + toJSON(message: GetBasicResponse): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + return obj; + }, + + create, I>>(base?: I): GetBasicResponse { + return GetBasicResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): GetBasicResponse { + const message = createBaseGetBasicResponse(); + message.name = object.name ?? ""; + return message; + }, +}; + +export interface BasicService { + GetBasic(request: GetBasicRequest): Promise; +} + +export const BasicServiceServiceName = "basic.BasicService"; +export class BasicServiceClientImpl implements BasicService { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || BasicServiceServiceName; + this.rpc = rpc; + this.GetBasic = this.GetBasic.bind(this); + } + GetBasic(request: GetBasicRequest): Promise { + const data = GetBasicRequest.encode(request).finish(); + const promise = this.rpc.request(this.service, "GetBasic", data); + return promise.then((data) => { + try { + return GetBasicResponse.decode(_m0.Reader.create(data)); + } catch (error) { + return Promise.reject(error); + } + }).catch((error) => { + if (error instanceof Error && this.rpc.handleError) { + return Promise.reject(this.rpc.handleError(this.service, "GetBasic", error)); + } + return Promise.reject(error); + }); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + handleError?(service: string, method: string, error: Error): Error; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/integration/handle-error-with-after-response/handle-error-test.ts b/integration/handle-error-with-after-response/handle-error-test.ts new file mode 100644 index 000000000..5cd51d0ff --- /dev/null +++ b/integration/handle-error-with-after-response/handle-error-test.ts @@ -0,0 +1,89 @@ +import { GetBasicResponse, GetBasicRequest, BasicServiceClientImpl, BasicServiceServiceName } from "./simple"; + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + handleError?(service: string, method: string, error: Error): Error; +} + +describe("before-after-request", () => { + const exampleData = { + name: "test-name", + }; + let rpc = { + request: jest.fn(() => Promise.resolve(new Uint8Array())), + }; + let client = new BasicServiceClientImpl(rpc); + let err = new Error("error"); + + let modifiedError = new Error("modified error"); + const handleError = jest.fn(() => modifiedError); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("performs handleError if error occurs during main request code block", async () => { + const encodeSpy = jest.spyOn(GetBasicRequest, "encode").mockImplementation(() => { + throw err; + }); + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl({ ...rpc, handleError: handleError }); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(err); + expect(handleError).not.toHaveBeenCalled(); + } + encodeSpy.mockRestore(); + }); + + it("performs handleError if error occurs when decoding", async () => { + const decodeSpy = jest.spyOn(GetBasicResponse, "decode").mockImplementation(() => { + throw err; + }); + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl({ ...rpc, handleError: handleError }); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(modifiedError); + expect(handleError).toHaveBeenCalledWith(BasicServiceServiceName, "GetBasic", err); + } + decodeSpy.mockRestore(); + }); + + it("performs handleError if error occurs when calling afterResponse", async () => { + const decodeSpy = jest.spyOn(GetBasicResponse, "decode").mockReturnValue(exampleData); + const req = GetBasicRequest.create(exampleData); + const res = GetBasicResponse.create(exampleData); + + const afterResponse = jest.fn(() => { + throw err; + }); + client = new BasicServiceClientImpl({ + ...rpc, + handleError: handleError, + afterResponse: afterResponse, + }); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(modifiedError); + expect(decodeSpy).toHaveBeenCalledTimes(1); + expect(afterResponse).toHaveBeenCalledWith(BasicServiceServiceName, "GetBasic", res); + expect(handleError).toHaveBeenCalledWith(BasicServiceServiceName, "GetBasic", err); + } + decodeSpy.mockRestore(); + }); + + it("doesn't perform handleError if it is not specified", async () => { + const req = GetBasicRequest.create(exampleData); + client = new BasicServiceClientImpl(rpc); + try { + await client.GetBasic(req); + } catch (error) { + expect(error).toBe(err); + expect(handleError).not.toHaveBeenCalled(); + } + }); +}); diff --git a/integration/handle-error-with-after-response/parameters.txt b/integration/handle-error-with-after-response/parameters.txt new file mode 100644 index 000000000..fa51884ce --- /dev/null +++ b/integration/handle-error-with-after-response/parameters.txt @@ -0,0 +1 @@ +outputServices=default,rpcAfterResponse=true,rpcErrorHandler=true diff --git a/integration/handle-error-with-after-response/simple.bin b/integration/handle-error-with-after-response/simple.bin new file mode 100644 index 0000000000000000000000000000000000000000..6a80e7ca5b5395a5161b3259e832e672698a8bdf GIT binary patch literal 482 zcmZvZu};G<5QfjO<2o0&gags23n&r@b;tl45)3RzNG#@d)nK3du%0&SmUlM;^QM_8V)w(#xg9o?2_w&NNzPWZ! zglfCoxLllx%VZJe;sUduSliUw*8P{edO{mq&u;f#tsUK>8NQ>FGOR3zgMQy^`sch` zP{8A7gqjqrq?mB&h$Da4w_!a#5&XIG`CBAdv2$FDmBMmr$aj% c-T4L{_{!?ZK>R$Loeso0JHJ`!@Jp}10s0L{^#A|> literal 0 HcmV?d00001 diff --git a/integration/handle-error-with-after-response/simple.proto b/integration/handle-error-with-after-response/simple.proto new file mode 100755 index 000000000..db5c7a53d --- /dev/null +++ b/integration/handle-error-with-after-response/simple.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; +package basic; + +message GetBasicRequest { + string name = 1; +} + +message GetBasicResponse { + string name = 1; +} + +service BasicService { + rpc GetBasic (GetBasicRequest) returns (GetBasicResponse) {} +} diff --git a/integration/handle-error-with-after-response/simple.ts b/integration/handle-error-with-after-response/simple.ts new file mode 100644 index 000000000..b0bc76c5a --- /dev/null +++ b/integration/handle-error-with-after-response/simple.ts @@ -0,0 +1,183 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export const protobufPackage = "basic"; + +export interface GetBasicRequest { + name: string; +} + +export interface GetBasicResponse { + name: string; +} + +function createBaseGetBasicRequest(): GetBasicRequest { + return { name: "" }; +} + +export const GetBasicRequest = { + encode(message: GetBasicRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetBasicRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetBasicRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetBasicRequest { + return { name: isSet(object.name) ? globalThis.String(object.name) : "" }; + }, + + toJSON(message: GetBasicRequest): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + return obj; + }, + + create, I>>(base?: I): GetBasicRequest { + return GetBasicRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): GetBasicRequest { + const message = createBaseGetBasicRequest(); + message.name = object.name ?? ""; + return message; + }, +}; + +function createBaseGetBasicResponse(): GetBasicResponse { + return { name: "" }; +} + +export const GetBasicResponse = { + encode(message: GetBasicResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetBasicResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetBasicResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetBasicResponse { + return { name: isSet(object.name) ? globalThis.String(object.name) : "" }; + }, + + toJSON(message: GetBasicResponse): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + return obj; + }, + + create, I>>(base?: I): GetBasicResponse { + return GetBasicResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): GetBasicResponse { + const message = createBaseGetBasicResponse(); + message.name = object.name ?? ""; + return message; + }, +}; + +export interface BasicService { + GetBasic(request: GetBasicRequest): Promise; +} + +export const BasicServiceServiceName = "basic.BasicService"; +export class BasicServiceClientImpl implements BasicService { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || BasicServiceServiceName; + this.rpc = rpc; + this.GetBasic = this.GetBasic.bind(this); + } + GetBasic(request: GetBasicRequest): Promise { + const data = GetBasicRequest.encode(request).finish(); + const promise = this.rpc.request(this.service, "GetBasic", data); + return promise.then((data) => { + try { + const response = GetBasicResponse.decode(_m0.Reader.create(data)); + if (this.rpc.afterResponse) { + this.rpc.afterResponse(this.service, "GetBasic", response); + } + return response; + } catch (error) { + return Promise.reject(error); + } + }).catch((error) => { + if (error instanceof Error && this.rpc.handleError) { + return Promise.reject(this.rpc.handleError(this.service, "GetBasic", error)); + } + return Promise.reject(error); + }); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + afterResponse?(service: string, method: string, response: T): void; + handleError?(service: string, method: string, error: Error): Error; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/package.json b/package.json index a07abf736..9183d3dbb 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,7 @@ "proto2bin-node": "docker compose run --rm node update-bins.sh", "proto2pbjs": "docker compose run --rm protoc pbjs.sh", "bin2ts": "docker compose run --rm protoc codegen.sh", - "test": "yarn jest -c jest.config.js --maxWorkers=2", + "test": "yarn jest -c jest.config.js", "tsc:check": "./tsc-check.sh tsconfig.json tests/tsconfig.json integration/tsconfig.json integration/tsconfig.proto.json protos/tsconfig.json", "format": "prettier --write {src,tests}/**/*.ts integration/*.ts", "format:check": "prettier --list-different {src,tests}/**/*.ts", diff --git a/src/generate-services.ts b/src/generate-services.ts index ec7665314..d81594337 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -16,6 +16,7 @@ import { maybeAddComment, maybePrefixPackage, singular, + tryCatchBlock, } from "./utils"; import SourceInfo, { Fields } from "./sourceInfo"; import { contextTypeVar } from "./main"; @@ -119,23 +120,33 @@ function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProt const maybeCtx = options.context ? "ctx," : ""; const maybeAbortSignal = options.useAbortSignal ? "abortSignal || undefined," : ""; + let errorHandler; + if (options.rpcErrorHandler) { + errorHandler = code` + if (error instanceof Error && this.rpc.handleError) { + return Promise.reject(this.rpc.handleError(this.service, "${methodDesc.name}", error)); + } + return Promise.reject(error); + `; + } + let encode = code`${rawInputType}.encode(request).finish()`; let beforeRequest; - if (options.outputBeforeRequest) { + if (options.rpcBeforeRequest) { beforeRequest = code` if (this.rpc.beforeRequest) { - this.rpc.beforeRequest(request); + this.rpc.beforeRequest(this.service, "${methodDesc.name}", request); }`; } - let decode = code`data => ${rawOutputType}.decode(${Reader}.create(data))`; - if (options.outputAfterResponse) { - decode = code`data => { + let decode = code`${rawOutputType}.decode(${Reader}.create(data))`; + if (options.rpcAfterResponse) { + decode = code` const response = ${rawOutputType}.decode(${Reader}.create(data)); if (this.rpc.afterResponse) { - this.rpc.afterResponse(response); + this.rpc.afterResponse(this.service, "${methodDesc.name}", response); } return response; - }`; + `; } // if (options.useDate && rawOutputType.toString().includes("Timestamp")) { @@ -148,17 +159,14 @@ function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProt encode = code`request.pipe(${imp("map@rxjs/operators")}(request => ${encode}))`; } } + + const returnStatement = createDefaultServiceReturn(ctx, methodDesc, decode, errorHandler); + let returnVariable: string; if (options.returnObservable || methodDesc.serverStreaming) { returnVariable = "result"; - if (options.useAsyncIterable) { - decode = code`${rawOutputType}.decodeTransform(result)`; - } else { - decode = code`result.pipe(${imp("map@rxjs/operators")}(${decode}))`; - } } else { returnVariable = "promise"; - decode = code`promise.then(${decode})`; } let rpcMethod: string; @@ -184,11 +192,42 @@ function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProt data, ${maybeAbortSignal} ); - return ${decode}; + return ${returnStatement}; } `; } +function createDefaultServiceReturn( + ctx: Context, + methodDesc: MethodDescriptorProto, + decode: Code, + errorHandler?: Code, +): Code { + const { options } = ctx; + const rawOutputType = responseType(ctx, methodDesc, { keepValueType: true }); + if (options.returnObservable || methodDesc.serverStreaming) { + if (options.useAsyncIterable) { + return code`${rawOutputType}.decodeTransform(result)`; + } else { + return code`result.pipe(${imp("map@rxjs/operators")}(data => ${decode}))`; + } + } + + if (errorHandler) { + let tryBlock = decode; + if (!options.rpcAfterResponse) { + tryBlock = code`return ${decode}`; + } + return code`promise.then(data => { ${tryCatchBlock( + tryBlock, + code`return Promise.reject(error);`, + )}}).catch((error) => { ${errorHandler} })`; + } else if (options.rpcAfterResponse) { + return code`promise.then(data => { ${decode} } )`; + } + return code`promise.then(data => ${decode})`; +} + export function generateServiceClientImpl( ctx: Context, fileDesc: FileDescriptorProto, @@ -368,12 +407,20 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod const maybeAbortSignalParam = options.useAbortSignal ? "abortSignal?: AbortSignal," : ""; const methods = [[code`request`, code`Uint8Array`, code`Promise`]]; const additionalMethods = []; - if (options.outputBeforeRequest) { - additionalMethods.push(code`beforeRequest?(request: T): void;`); + if (options.rpcBeforeRequest) { + additionalMethods.push( + code`beforeRequest?(service: string, method: string, request: T): void;`, + ); } - if (options.outputAfterResponse) { - additionalMethods.push(code`afterResponse?(response: T): void;`); + if (options.rpcAfterResponse) { + additionalMethods.push( + code`afterResponse?(service: string, method: string, response: T): void;`, + ); + } + if (options.rpcErrorHandler) { + additionalMethods.push(code`handleError?(service: string, method: string, error: Error): Error;`); } + if (hasStreamingMethods) { const observable = observableType(ctx, true); methods.push([code`clientStreamingRequest`, code`${observable}`, code`Promise`]); diff --git a/src/options.ts b/src/options.ts index 72a77687f..0e5661435 100644 --- a/src/options.ts +++ b/src/options.ts @@ -93,8 +93,9 @@ export type Options = { outputExtensions: boolean; outputIndex: boolean; M: { [from: string]: string }; - outputBeforeRequest: boolean; - outputAfterResponse: boolean; + rpcBeforeRequest: boolean; + rpcAfterResponse: boolean; + rpcErrorHandler: boolean; }; export function defaultOptions(): Options { @@ -152,8 +153,9 @@ export function defaultOptions(): Options { outputExtensions: false, outputIndex: false, M: {}, - outputBeforeRequest: false, - outputAfterResponse: false, + rpcBeforeRequest: false, + rpcAfterResponse: false, + rpcErrorHandler: false, }; } @@ -251,8 +253,12 @@ export function optionsFromParameter(parameter: string | undefined): Options { options.exportCommonSymbols = false; } - if (options.outputBeforeRequest || options.outputAfterResponse) { + if (options.rpcBeforeRequest || options.rpcAfterResponse || options.rpcErrorHandler) { + const includesGeneric = options.outputServices.includes(ServiceOption.GENERIC); options.outputServices = [ServiceOption.DEFAULT]; + if (includesGeneric) { + options.outputServices.push(ServiceOption.GENERIC); + } } if (options.unrecognizedEnumValue) { diff --git a/src/utils.ts b/src/utils.ts index 55b16fee0..eeb19a803 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -285,3 +285,11 @@ export function impProto(options: Options, module: string, type: string): Import } return imp(`${prefix}${type}@./${module}${options.fileSuffix}${options.importSuffix}`); } + +export function tryCatchBlock(tryBlock: Code | string, handleErrorBlock: Code | string): Code { + return code`try { + ${tryBlock} + } catch (error) { + ${handleErrorBlock} + }`; +} diff --git a/tests/options-test.ts b/tests/options-test.ts index d6783dbd2..fdaf2b272 100644 --- a/tests/options-test.ts +++ b/tests/options-test.ts @@ -26,8 +26,6 @@ describe("options", () => { "nestJs": true, "oneof": "properties", "onlyTypes": false, - "outputAfterResponse": false, - "outputBeforeRequest": false, "outputClientImpl": false, "outputEncodeMethods": false, "outputExtensions": false, @@ -42,6 +40,9 @@ describe("options", () => { "outputTypeRegistry": false, "removeEnumPrefix": false, "returnObservable": false, + "rpcAfterResponse": false, + "rpcBeforeRequest": false, + "rpcErrorHandler": false, "snakeToCamel": [ "json", "keys", @@ -168,19 +169,29 @@ describe("options", () => { }); }); - it("outputAfterResponse implies default service", () => { - const options = optionsFromParameter("outputAfterResponse=true"); + it("rpcAfterResponse implies default service", () => { + const options = optionsFromParameter("rpcAfterResponse=true"); expect(options).toMatchObject({ - outputAfterResponse: true, + rpcAfterResponse: true, outputServices: [ServiceOption.DEFAULT], }); }); - it("outputBeforeRequest implies default service", () => { - const options = optionsFromParameter("outputBeforeRequest=true"); + it("rpcBeforeRequest implies default service", () => { + const options = optionsFromParameter("rpcBeforeRequest=true"); expect(options).toMatchObject({ - outputBeforeRequest: true, + rpcBeforeRequest: true, outputServices: [ServiceOption.DEFAULT], }); }); + + it("rpcAfterResponse implies default service but allows generics too", () => { + const options = optionsFromParameter( + "rpcBeforeRequest=true,outputServices=generic-definitions,outputServices=default", + ); + expect(options).toMatchObject({ + rpcBeforeRequest: true, + outputServices: [ServiceOption.DEFAULT, ServiceOption.GENERIC], + }); + }); });