From ef498db875f8d21353cd193487709d5caacd8e98 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 7 Oct 2024 15:04:04 +0200 Subject: [PATCH] Update gitignore --- src/ts/.gitignore | 2 + src/ts/src/lib/address.js | 38 ++ src/ts/src/lib/address.test.js | 49 +++ src/ts/src/lib/address.test.ts | 54 +++ src/ts/src/lib/address.ts | 48 +++ src/ts/src/lib/client.js | 111 ++++++ src/ts/src/lib/client.test.js | 151 +++++++ src/ts/src/lib/client.test.ts | 181 +++++++++ src/ts/src/lib/client.ts | 138 +++++++ src/ts/src/lib/client_app.js | 137 +++++++ src/ts/src/lib/client_app.ts | 157 ++++++++ src/ts/src/lib/client_interceptor.js | 58 +++ src/ts/src/lib/client_interceptor.test.js | 138 +++++++ src/ts/src/lib/client_interceptor.test.ts | 152 +++++++ src/ts/src/lib/client_interceptor.ts | 88 ++++ src/ts/src/lib/config.js | 363 +++++++++++++++++ src/ts/src/lib/config.test.js | 166 ++++++++ src/ts/src/lib/config.test.ts | 180 +++++++++ src/ts/src/lib/config.ts | 395 ++++++++++++++++++ src/ts/src/lib/connection.js | 176 ++++++++ src/ts/src/lib/connection.ts | 235 +++++++++++ src/ts/src/lib/constants.js | 11 + src/ts/src/lib/constants.ts | 8 + src/ts/src/lib/crypto_helpers.js | 48 +++ src/ts/src/lib/crypto_helpers.ts | 24 ++ src/ts/src/lib/grpc.js | 57 +++ src/ts/src/lib/grpc.ts | 45 +++ src/ts/src/lib/heartbeat.js | 70 ++++ src/ts/src/lib/heartbeat.ts | 81 ++++ src/ts/src/lib/index.js | 21 + src/ts/src/lib/index.ts | 20 + src/ts/src/lib/logger.js | 39 ++ src/ts/src/lib/logger.ts | 46 +++ src/ts/src/lib/message_handler.js | 91 +++++ src/ts/src/lib/message_handler.test.js | 161 ++++++++ src/ts/src/lib/message_handler.test.ts | 205 ++++++++++ src/ts/src/lib/message_handler.ts | 129 ++++++ src/ts/src/lib/node_state.js | 70 ++++ src/ts/src/lib/node_state.ts | 89 +++++ src/ts/src/lib/recordset.js | 45 +++ src/ts/src/lib/recordset.ts | 56 +++ src/ts/src/lib/recordset_compat.js | 197 +++++++++ src/ts/src/lib/recordset_compat.test.js | 73 ++++ src/ts/src/lib/recordset_compat.test.ts | 107 +++++ src/ts/src/lib/recordset_compat.ts | 312 +++++++++++++++ src/ts/src/lib/retry_invoker.js | 114 ++++++ src/ts/src/lib/retry_invoker.test.js | 83 ++++ src/ts/src/lib/retry_invoker.test.ts | 106 +++++ src/ts/src/lib/retry_invoker.ts | 149 +++++++ src/ts/src/lib/serde.js | 367 +++++++++++++++++ src/ts/src/lib/serde.test.js | 202 ++++++++++ src/ts/src/lib/serde.test.ts | 260 ++++++++++++ src/ts/src/lib/serde.ts | 464 ++++++++++++++++++++++ src/ts/src/lib/start.js | 174 ++++++++ src/ts/src/lib/start.ts | 222 +++++++++++ src/ts/src/lib/task_handler.js | 17 + src/ts/src/lib/task_handler.test.js | 57 +++ src/ts/src/lib/task_handler.test.ts | 61 +++ src/ts/src/lib/task_handler.ts | 18 + src/ts/src/lib/typing.js | 78 ++++ src/ts/src/lib/typing.ts | 206 ++++++++++ 61 files changed, 7600 insertions(+) create mode 100644 src/ts/src/lib/address.js create mode 100644 src/ts/src/lib/address.test.js create mode 100644 src/ts/src/lib/address.test.ts create mode 100644 src/ts/src/lib/address.ts create mode 100644 src/ts/src/lib/client.js create mode 100644 src/ts/src/lib/client.test.js create mode 100644 src/ts/src/lib/client.test.ts create mode 100644 src/ts/src/lib/client.ts create mode 100644 src/ts/src/lib/client_app.js create mode 100644 src/ts/src/lib/client_app.ts create mode 100644 src/ts/src/lib/client_interceptor.js create mode 100644 src/ts/src/lib/client_interceptor.test.js create mode 100644 src/ts/src/lib/client_interceptor.test.ts create mode 100644 src/ts/src/lib/client_interceptor.ts create mode 100644 src/ts/src/lib/config.js create mode 100644 src/ts/src/lib/config.test.js create mode 100644 src/ts/src/lib/config.test.ts create mode 100644 src/ts/src/lib/config.ts create mode 100644 src/ts/src/lib/connection.js create mode 100644 src/ts/src/lib/connection.ts create mode 100644 src/ts/src/lib/constants.js create mode 100644 src/ts/src/lib/constants.ts create mode 100644 src/ts/src/lib/crypto_helpers.js create mode 100644 src/ts/src/lib/crypto_helpers.ts create mode 100644 src/ts/src/lib/grpc.js create mode 100644 src/ts/src/lib/grpc.ts create mode 100644 src/ts/src/lib/heartbeat.js create mode 100644 src/ts/src/lib/heartbeat.ts create mode 100644 src/ts/src/lib/index.js create mode 100644 src/ts/src/lib/index.ts create mode 100644 src/ts/src/lib/logger.js create mode 100644 src/ts/src/lib/logger.ts create mode 100644 src/ts/src/lib/message_handler.js create mode 100644 src/ts/src/lib/message_handler.test.js create mode 100644 src/ts/src/lib/message_handler.test.ts create mode 100644 src/ts/src/lib/message_handler.ts create mode 100644 src/ts/src/lib/node_state.js create mode 100644 src/ts/src/lib/node_state.ts create mode 100644 src/ts/src/lib/recordset.js create mode 100644 src/ts/src/lib/recordset.ts create mode 100644 src/ts/src/lib/recordset_compat.js create mode 100644 src/ts/src/lib/recordset_compat.test.js create mode 100644 src/ts/src/lib/recordset_compat.test.ts create mode 100644 src/ts/src/lib/recordset_compat.ts create mode 100644 src/ts/src/lib/retry_invoker.js create mode 100644 src/ts/src/lib/retry_invoker.test.js create mode 100644 src/ts/src/lib/retry_invoker.test.ts create mode 100644 src/ts/src/lib/retry_invoker.ts create mode 100644 src/ts/src/lib/serde.js create mode 100644 src/ts/src/lib/serde.test.js create mode 100644 src/ts/src/lib/serde.test.ts create mode 100644 src/ts/src/lib/serde.ts create mode 100644 src/ts/src/lib/start.js create mode 100644 src/ts/src/lib/start.ts create mode 100644 src/ts/src/lib/task_handler.js create mode 100644 src/ts/src/lib/task_handler.test.js create mode 100644 src/ts/src/lib/task_handler.test.ts create mode 100644 src/ts/src/lib/task_handler.ts create mode 100644 src/ts/src/lib/typing.js create mode 100644 src/ts/src/lib/typing.ts diff --git a/src/ts/.gitignore b/src/ts/.gitignore index de9ee66e8545..23df5c69c53a 100644 --- a/src/ts/.gitignore +++ b/src/ts/.gitignore @@ -23,3 +23,5 @@ coverage *.njsproj *.sln *.sw? + +!src/lib diff --git a/src/ts/src/lib/address.js b/src/ts/src/lib/address.js new file mode 100644 index 000000000000..67615e6eff45 --- /dev/null +++ b/src/ts/src/lib/address.js @@ -0,0 +1,38 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.parseAddress = parseAddress; +const net_1 = require("net"); +const IPV6 = 6; +const IPV4 = 4; +function parseAddress(address) { + try { + const lastColonIndex = address.lastIndexOf(":"); + if (lastColonIndex === -1) { + throw new Error("No port was provided."); + } + // Split the address into host and port. + const rawHost = address.slice(0, lastColonIndex); + const rawPort = address.slice(lastColonIndex + 1); + const port = parseInt(rawPort, 10); + if (port > 65535 || port < 1) { + throw new Error("Port number is invalid."); + } + let host = rawHost.replace(/[\[\]]/g, ""); // Remove brackets for IPv6 + let version = null; + const ipVersion = (0, net_1.isIP)(host); + if (ipVersion === IPV6) { + version = true; + } + else if (ipVersion === IPV4) { + version = false; + } + return { + host, + port, + version, + }; + } + catch (err) { + return null; + } +} diff --git a/src/ts/src/lib/address.test.js b/src/ts/src/lib/address.test.js new file mode 100644 index 000000000000..4d8138d11c96 --- /dev/null +++ b/src/ts/src/lib/address.test.js @@ -0,0 +1,49 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const address_1 = require("./address"); +describe("parseAddress", () => { + test("parses a valid IPv4 address", () => { + const result = (0, address_1.parseAddress)("127.0.0.1:8080"); + expect(result).toEqual({ + host: "127.0.0.1", + port: 8080, + version: false, // IPv4 address + }); + }); + test("parses a valid IPv6 address", () => { + const result = (0, address_1.parseAddress)("[::1]:8080"); + expect(result).toEqual({ + host: "::1", + port: 8080, + version: true, // IPv6 address + }); + }); + test("returns null for an invalid port number", () => { + const result = (0, address_1.parseAddress)("127.0.0.1:70000"); // Invalid port + expect(result).toBeNull(); + }); + test("returns null for missing port", () => { + const result = (0, address_1.parseAddress)("127.0.0.1"); // No port provided + expect(result).toBeNull(); + }); + test("returns null for an invalid address format", () => { + const result = (0, address_1.parseAddress)("notAnAddress"); + expect(result).toBeNull(); + }); + test("parses domain names correctly", () => { + const result = (0, address_1.parseAddress)("example.com:8080"); + expect(result).toEqual({ + host: "example.com", + port: 8080, + version: null, // Domain names do not have IP versions + }); + }); + test("parses IPv6 with brackets and returns proper version", () => { + const result = (0, address_1.parseAddress)("[2001:db8::ff00:42:8329]:9090"); + expect(result).toEqual({ + host: "2001:db8::ff00:42:8329", + port: 9090, + version: true, // IPv6 address + }); + }); +}); diff --git a/src/ts/src/lib/address.test.ts b/src/ts/src/lib/address.test.ts new file mode 100644 index 000000000000..04759b583c10 --- /dev/null +++ b/src/ts/src/lib/address.test.ts @@ -0,0 +1,54 @@ +import { parseAddress } from './address'; + +describe("parseAddress", () => { + test("parses a valid IPv4 address", () => { + const result = parseAddress("127.0.0.1:8080"); + expect(result).toEqual({ + host: "127.0.0.1", + port: 8080, + version: false, // IPv4 address + }); + }); + + test("parses a valid IPv6 address", () => { + const result = parseAddress("[::1]:8080"); + expect(result).toEqual({ + host: "::1", + port: 8080, + version: true, // IPv6 address + }); + }); + + test("returns null for an invalid port number", () => { + const result = parseAddress("127.0.0.1:70000"); // Invalid port + expect(result).toBeNull(); + }); + + test("returns null for missing port", () => { + const result = parseAddress("127.0.0.1"); // No port provided + expect(result).toBeNull(); + }); + + test("returns null for an invalid address format", () => { + const result = parseAddress("notAnAddress"); + expect(result).toBeNull(); + }); + + test("parses domain names correctly", () => { + const result = parseAddress("example.com:8080"); + expect(result).toEqual({ + host: "example.com", + port: 8080, + version: null, // Domain names do not have IP versions + }); + }); + + test("parses IPv6 with brackets and returns proper version", () => { + const result = parseAddress("[2001:db8::ff00:42:8329]:9090"); + expect(result).toEqual({ + host: "2001:db8::ff00:42:8329", + port: 9090, + version: true, // IPv6 address + }); + }); +}); diff --git a/src/ts/src/lib/address.ts b/src/ts/src/lib/address.ts new file mode 100644 index 000000000000..b3413ab7025a --- /dev/null +++ b/src/ts/src/lib/address.ts @@ -0,0 +1,48 @@ +import { isIP } from "net"; + +const IPV6 = 6; +const IPV4 = 4; + +interface ParsedAddress { + host: string; + port: number; + version: boolean | null; +} + +export function parseAddress(address: string): ParsedAddress | null { + try { + const lastColonIndex = address.lastIndexOf(":"); + + if (lastColonIndex === -1) { + throw new Error("No port was provided."); + } + + // Split the address into host and port. + const rawHost = address.slice(0, lastColonIndex); + const rawPort = address.slice(lastColonIndex + 1); + + const port = parseInt(rawPort, 10); + + if (port > 65535 || port < 1) { + throw new Error("Port number is invalid."); + } + + let host = rawHost.replace(/[\[\]]/g, ""); // Remove brackets for IPv6 + let version: boolean | null = null; + + const ipVersion = isIP(host); + if (ipVersion === IPV6) { + version = true; + } else if (ipVersion === IPV4) { + version = false; + } + + return { + host, + port, + version, + }; + } catch (err) { + return null; + } +} diff --git a/src/ts/src/lib/client.js b/src/ts/src/lib/client.js new file mode 100644 index 000000000000..e7e0aa37589c --- /dev/null +++ b/src/ts/src/lib/client.js @@ -0,0 +1,111 @@ +"use strict"; +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== +Object.defineProperty(exports, "__esModule", { value: true }); +exports.Client = void 0; +exports.maybeCallGetProperties = maybeCallGetProperties; +exports.maybeCallGetParameters = maybeCallGetParameters; +exports.maybeCallFit = maybeCallFit; +exports.maybeCallEvaluate = maybeCallEvaluate; +const typing_1 = require("./typing"); +class BaseClient { + context; + constructor(context) { + this.context = context; + } + setContext(context) { + this.context = context; + } + getContext() { + return this.context; + } +} +class Client extends BaseClient { + getProperties(_ins) { + return { + status: { + code: typing_1.Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }, + properties: {}, + }; + } +} +exports.Client = Client; +function hasGetProperties(client) { + return client.getProperties !== undefined; +} +function hasGetParameters(client) { + return client.getParameters !== undefined; +} +function hasFit(client) { + return client.fit !== undefined; +} +function hasEvaluate(client) { + return client.evaluate !== undefined; +} +function maybeCallGetProperties(client, getPropertiesIns) { + if (!hasGetProperties(client)) { + const status = { + code: typing_1.Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }; + return { status, properties: {} }; + } + return client.getProperties(getPropertiesIns); +} +function maybeCallGetParameters(client, getParametersIns) { + if (!hasGetParameters(client)) { + const status = { + code: typing_1.Code.GET_PARAMETERS_NOT_IMPLEMENTED, + message: "Client does not implement `get_parameters`", + }; + return { + status, + parameters: { tensorType: "", tensors: [] }, + }; + } + return client.getParameters(getParametersIns); +} +function maybeCallFit(client, fitIns) { + if (!hasFit(client)) { + const status = { + code: typing_1.Code.FIT_NOT_IMPLEMENTED, + message: "Client does not implement `fit`", + }; + return { + status, + parameters: { tensorType: "", tensors: [] }, + numExamples: 0, + metrics: {}, + }; + } + return client.fit(fitIns); +} +function maybeCallEvaluate(client, evaluateIns) { + if (!hasEvaluate(client)) { + const status = { + code: typing_1.Code.EVALUATE_NOT_IMPLEMENTED, + message: "Client does not implement `evaluate`", + }; + return { + status, + loss: 0.0, + numExamples: 0, + metrics: {}, + }; + } + return client.evaluate(evaluateIns); +} diff --git a/src/ts/src/lib/client.test.js b/src/ts/src/lib/client.test.js new file mode 100644 index 000000000000..51fd2413c89d --- /dev/null +++ b/src/ts/src/lib/client.test.js @@ -0,0 +1,151 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const client_1 = require("./client"); +const typing_1 = require("./typing"); +// Mock classes for testing +class OverridingClient extends client_1.Client { + getParameters(_ins) { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + fit(_ins) { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 1, + metrics: {}, + }; + } + evaluate(_ins) { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + loss: 1.0, + numExamples: 1, + metrics: {}, + }; + } + getProperties(_ins) { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + properties: {}, + }; + } +} +class NotOverridingClient extends client_1.Client { + getParameters(_ins) { + return { + status: { code: typing_1.Code.GET_PARAMETERS_NOT_IMPLEMENTED, message: "Not Implemented" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + fit(_ins) { + return { + status: { code: typing_1.Code.FIT_NOT_IMPLEMENTED, message: "Not Implemented" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 0, + metrics: {}, + }; + } + evaluate(_ins) { + return { + status: { code: typing_1.Code.EVALUATE_NOT_IMPLEMENTED, message: "Not Implemented" }, + loss: 0.0, + numExamples: 0, + metrics: {}, + }; + } +} +// Test Suite for maybeCallGetProperties +describe("maybeCallGetProperties", () => { + it("should return OK when client implements getProperties", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallGetProperties)(client, {}); + expect(result.status.code).toBe(typing_1.Code.OK); + }); + it("should return GET_PROPERTIES_NOT_IMPLEMENTED when client does not implement getProperties", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallGetProperties)(client, {}); + expect(result.status.code).toBe(typing_1.Code.GET_PROPERTIES_NOT_IMPLEMENTED); + }); +}); +// Test Suite for maybeCallGetParameters +describe("maybeCallGetParameters", () => { + it("should return OK when client implements getParameters", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallGetParameters)(client, {}); + expect(result.status.code).toBe(typing_1.Code.OK); + }); + it("should return GET_PARAMETERS_NOT_IMPLEMENTED when client does not implement getParameters", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallGetParameters)(client, {}); + expect(result.status.code).toBe(typing_1.Code.GET_PARAMETERS_NOT_IMPLEMENTED); + }); +}); +// Test Suite for maybeCallFit +describe("maybeCallFit", () => { + it("should return OK when client implements fit", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallFit)(client, {}); + expect(result.status.code).toBe(typing_1.Code.OK); + }); + it("should return FIT_NOT_IMPLEMENTED when client does not implement fit", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallFit)(client, {}); + expect(result.status.code).toBe(typing_1.Code.FIT_NOT_IMPLEMENTED); + }); +}); +// Test Suite for maybeCallEvaluate +describe("maybeCallEvaluate", () => { + it("should return OK when client implements evaluate", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallEvaluate)(client, {}); + expect(result.status.code).toBe(typing_1.Code.OK); + }); + it("should return EVALUATE_NOT_IMPLEMENTED when client does not implement evaluate", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {}, + runConfig: {}, + }); + const result = (0, client_1.maybeCallEvaluate)(client, {}); + expect(result.status.code).toBe(typing_1.Code.EVALUATE_NOT_IMPLEMENTED); + }); +}); diff --git a/src/ts/src/lib/client.test.ts b/src/ts/src/lib/client.test.ts new file mode 100644 index 000000000000..562532290103 --- /dev/null +++ b/src/ts/src/lib/client.test.ts @@ -0,0 +1,181 @@ +import { + Client, + maybeCallGetProperties, + maybeCallGetParameters, + maybeCallFit, + maybeCallEvaluate, +} from "./client"; +import { + GetParametersIns, + GetPropertiesIns, + FitIns, + EvaluateIns, + Code, + GetParametersRes, + GetPropertiesRes, + FitRes, + EvaluateRes, +} from "./typing"; +import { RecordSet } from "./recordset"; + +// Mock classes for testing +class OverridingClient extends Client { + getParameters(_ins: GetParametersIns): GetParametersRes { + return { + status: { code: Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + + fit(_ins: FitIns): FitRes { + return { + status: { code: Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 1, + metrics: {}, + }; + } + + evaluate(_ins: EvaluateIns): EvaluateRes { + return { + status: { code: Code.OK, message: "Success" }, + loss: 1.0, + numExamples: 1, + metrics: {}, + }; + } + + getProperties(_ins: GetPropertiesIns): GetPropertiesRes { + return { + status: { code: Code.OK, message: "Success" }, + properties: {}, + }; + } +} + +class NotOverridingClient extends Client { + getParameters(_ins: GetParametersIns): GetParametersRes { + return { + status: { code: Code.GET_PARAMETERS_NOT_IMPLEMENTED, message: "Not Implemented" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + + fit(_ins: FitIns): FitRes { + return { + status: { code: Code.FIT_NOT_IMPLEMENTED, message: "Not Implemented" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 0, + metrics: {}, + }; + } + + evaluate(_ins: EvaluateIns): EvaluateRes { + return { + status: { code: Code.EVALUATE_NOT_IMPLEMENTED, message: "Not Implemented" }, + loss: 0.0, + numExamples: 0, + metrics: {}, + }; + } +} + +// Test Suite for maybeCallGetProperties +describe("maybeCallGetProperties", () => { + it("should return OK when client implements getProperties", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallGetProperties(client, {} as GetPropertiesIns); + expect(result.status.code).toBe(Code.OK); + }); + + it("should return GET_PROPERTIES_NOT_IMPLEMENTED when client does not implement getProperties", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallGetProperties(client, {} as GetPropertiesIns); + expect(result.status.code).toBe(Code.GET_PROPERTIES_NOT_IMPLEMENTED); + }); +}); + +// Test Suite for maybeCallGetParameters +describe("maybeCallGetParameters", () => { + it("should return OK when client implements getParameters", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallGetParameters(client, {} as GetParametersIns); + expect(result.status.code).toBe(Code.OK); + }); + + it("should return GET_PARAMETERS_NOT_IMPLEMENTED when client does not implement getParameters", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallGetParameters(client, {} as GetParametersIns); + expect(result.status.code).toBe(Code.GET_PARAMETERS_NOT_IMPLEMENTED); + }); +}); + +// Test Suite for maybeCallFit +describe("maybeCallFit", () => { + it("should return OK when client implements fit", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallFit(client, {} as FitIns); + expect(result.status.code).toBe(Code.OK); + }); + + it("should return FIT_NOT_IMPLEMENTED when client does not implement fit", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallFit(client, {} as FitIns); + expect(result.status.code).toBe(Code.FIT_NOT_IMPLEMENTED); + }); +}); + +// Test Suite for maybeCallEvaluate +describe("maybeCallEvaluate", () => { + it("should return OK when client implements evaluate", () => { + const client = new OverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallEvaluate(client, {} as EvaluateIns); + expect(result.status.code).toBe(Code.OK); + }); + + it("should return EVALUATE_NOT_IMPLEMENTED when client does not implement evaluate", () => { + const client = new NotOverridingClient({ + nodeId: BigInt(1), + nodeConfig: {}, + state: {} as RecordSet, + runConfig: {}, + }); + const result = maybeCallEvaluate(client, {} as EvaluateIns); + expect(result.status.code).toBe(Code.EVALUATE_NOT_IMPLEMENTED); + }); +}); diff --git a/src/ts/src/lib/client.ts b/src/ts/src/lib/client.ts new file mode 100644 index 000000000000..bd6230e52998 --- /dev/null +++ b/src/ts/src/lib/client.ts @@ -0,0 +1,138 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +import { + GetPropertiesIns, + GetPropertiesRes, + GetParametersRes, + GetParametersIns, + FitIns, + FitRes, + EvaluateRes, + EvaluateIns, + Context, + Code, + Status, +} from "./typing"; + +abstract class BaseClient { + protected context: Context; + + constructor(context: Context) { + this.context = context; + } + + setContext(context: Context) { + this.context = context; + } + + getContext(): Context { + return this.context; + } +} + +export abstract class Client extends BaseClient { + abstract getParameters(ins: GetParametersIns): GetParametersRes; + abstract fit(ins: FitIns): FitRes; + abstract evaluate(ins: EvaluateIns): EvaluateRes; + getProperties(_ins: GetPropertiesIns): GetPropertiesRes { + return { + status: { + code: Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }, + properties: {}, + }; + } +} + +function hasGetProperties(client: Client): boolean { + return client.getProperties !== undefined; +} + +function hasGetParameters(client: Client): boolean { + return client.getParameters !== undefined; +} + +function hasFit(client: Client): boolean { + return client.fit !== undefined; +} + +function hasEvaluate(client: Client): boolean { + return client.evaluate !== undefined; +} + +export function maybeCallGetProperties( + client: Client, + getPropertiesIns: GetPropertiesIns, +): GetPropertiesRes { + if (!hasGetProperties(client)) { + const status: Status = { + code: Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }; + return { status, properties: {} }; + } + return client.getProperties!(getPropertiesIns); +} + +export function maybeCallGetParameters( + client: Client, + getParametersIns: GetParametersIns, +): GetParametersRes { + if (!hasGetParameters(client)) { + const status: Status = { + code: Code.GET_PARAMETERS_NOT_IMPLEMENTED, + message: "Client does not implement `get_parameters`", + }; + return { + status, + parameters: { tensorType: "", tensors: [] }, + }; + } + return client.getParameters!(getParametersIns); +} + +export function maybeCallFit(client: Client, fitIns: FitIns): FitRes { + if (!hasFit(client)) { + const status: Status = { + code: Code.FIT_NOT_IMPLEMENTED, + message: "Client does not implement `fit`", + }; + return { + status, + parameters: { tensorType: "", tensors: [] }, + numExamples: 0, + metrics: {}, + }; + } + return client.fit!(fitIns); +} + +export function maybeCallEvaluate(client: Client, evaluateIns: EvaluateIns): EvaluateRes { + if (!hasEvaluate(client)) { + const status: Status = { + code: Code.EVALUATE_NOT_IMPLEMENTED, + message: "Client does not implement `evaluate`", + }; + return { + status, + loss: 0.0, + numExamples: 0, + metrics: {}, + }; + } + return client.evaluate!(evaluateIns); +} diff --git a/src/ts/src/lib/client_app.js b/src/ts/src/lib/client_app.js new file mode 100644 index 000000000000..d52ca095bdb1 --- /dev/null +++ b/src/ts/src/lib/client_app.js @@ -0,0 +1,137 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.LoadClientAppError = exports.ClientApp = exports.ClientAppException = void 0; +exports.makeFFN = makeFFN; +const typing_1 = require("./typing"); +const message_handler_1 = require("./message_handler"); +const logger_1 = require("./logger"); // Mock for warnings +function makeFFN(ffn, mods) { + function wrapFFN(_ffn, _mod) { + return function newFFN(message, context) { + return _mod(message, context, _ffn); // Call the mod with the message, context, and original ffn + }; + } + // Apply each mod to ffn, in reversed order + for (const mod of mods.reverse()) { + ffn = wrapFFN(ffn, mod); + } + return ffn; // Return the modified ffn +} +function alertErroneousClientFn() { + throw new Error("A `ClientApp` cannot make use of a `client_fn` that does not have a signature in the form: `function client_fn(context: Context)`. You can import the `Context` like this: `import { Context } from './common'`"); +} +function inspectMaybeAdaptClientFnSignature(clientFn) { + if (clientFn.length !== 1) { + alertErroneousClientFn(); + } + // const firstArg = clientFn.arguments[0]; + // if (typeof firstArg === "string") { + // warnDeprecatedFeature( + // "`clientFn` now expects a signature `function clientFn(context: Context)`. The provided `clientFn` has a signature `function clientFn(cid: string)`" + // ); + // return (context: Context): Client => { + // const cid = context.nodeConfig["partition-id"] || context.nodeId; + // return clientFn(cid as any); + // }; + // } + return clientFn; +} +class ClientAppException extends Error { + constructor(message) { + const exName = "ClientAppException"; + super(`\nException ${exName} occurred. Message: ${message}`); + this.name = exName; + } +} +exports.ClientAppException = ClientAppException; +class ClientApp { + _mods; + _call = null; + _train = null; + _evaluate = null; + _query = null; + constructor(clientFn, mods) { + this._mods = mods || []; + if (clientFn) { + clientFn = inspectMaybeAdaptClientFnSignature(clientFn); + const ffn = (message, context) => { + return (0, message_handler_1.handleLegacyMessageFromMsgType)(clientFn, message, context); + }; + this._call = makeFFN(ffn, this._mods); + } + } + call(message, context) { + if (this._call) { + return this._call(message, context); + } + switch (message.metadata.messageType) { + case typing_1.MessageType.TRAIN: + if (this._train) + return this._train(message, context); + throw new Error("No `train` function registered"); + case typing_1.MessageType.EVALUATE: + if (this._evaluate) + return this._evaluate(message, context); + throw new Error("No `evaluate` function registered"); + case typing_1.MessageType.QUERY: + if (this._query) + return this._query(message, context); + throw new Error("No `query` function registered"); + default: + throw new Error(`Unknown message_type: ${message.metadata.messageType}`); + } + } + train() { + return (trainFn) => { + if (this._call) { + throw registrationError("train"); + } + (0, logger_1.warnPreviewFeature)("ClientApp-register-train-function"); + this._train = makeFFN(trainFn, this._mods); + return trainFn; + }; + } + evaluate() { + return (evaluateFn) => { + if (this._call) { + throw registrationError("evaluate"); + } + (0, logger_1.warnPreviewFeature)("ClientApp-register-evaluate-function"); + this._evaluate = makeFFN(evaluateFn, this._mods); + return evaluateFn; + }; + } + query() { + return (queryFn) => { + if (this._call) { + throw registrationError("query"); + } + (0, logger_1.warnPreviewFeature)("ClientApp-register-query-function"); + this._query = makeFFN(queryFn, this._mods); + return queryFn; + }; + } +} +exports.ClientApp = ClientApp; +class LoadClientAppError extends Error { + constructor(message) { + super(message); + this.name = "LoadClientAppError"; + } +} +exports.LoadClientAppError = LoadClientAppError; +function registrationError(fnName) { + return new Error(`Use either \`@app.${fnName}()\` or \`clientFn\`, but not both.\n\n` + + `Use the \`ClientApp\` with an existing \`clientFn\`:\n\n` + + `\`\`\`\nclass FlowerClient extends NumPyClient {}\n\n` + + `function clientFn(context: Context) {\n` + + ` return new FlowerClient().toClient();\n` + + `}\n\n` + + `const app = new ClientApp({ clientFn });\n\`\`\`\n\n` + + `Use the \`ClientApp\` with a custom ${fnName} function:\n\n` + + `\`\`\`\nconst app = new ClientApp();\n\n` + + `app.${fnName}((message, context) => {\n` + + ` console.log("ClientApp ${fnName} running");\n` + + ` return message.createReply({ content: message.content });\n` + + `});\n\`\`\`\n`); +} diff --git a/src/ts/src/lib/client_app.ts b/src/ts/src/lib/client_app.ts new file mode 100644 index 000000000000..04e1891bfd19 --- /dev/null +++ b/src/ts/src/lib/client_app.ts @@ -0,0 +1,157 @@ +import { Context, Message, MessageType, ClientFnExt, Mod, ClientAppCallable } from "./typing"; +import { Client } from "./client"; +import { handleLegacyMessageFromMsgType } from "./message_handler"; +import { warnDeprecatedFeature, warnPreviewFeature } from "./logger"; // Mock for warnings + +export function makeFFN(ffn: ClientAppCallable, mods: Mod[]): ClientAppCallable { + function wrapFFN(_ffn: ClientAppCallable, _mod: Mod): ClientAppCallable { + return function newFFN(message: Message, context: Context): Message { + return _mod(message, context, _ffn); // Call the mod with the message, context, and original ffn + }; + } + + // Apply each mod to ffn, in reversed order + for (const mod of mods.reverse()) { + ffn = wrapFFN(ffn, mod); + } + + return ffn; // Return the modified ffn +} + +function alertErroneousClientFn(): void { + throw new Error( + "A `ClientApp` cannot make use of a `client_fn` that does not have a signature in the form: `function client_fn(context: Context)`. You can import the `Context` like this: `import { Context } from './common'`" + ); +} + +function inspectMaybeAdaptClientFnSignature(clientFn: ClientFnExt): ClientFnExt { + if (clientFn.length !== 1) { + alertErroneousClientFn(); + } + + // const firstArg = clientFn.arguments[0]; + + // if (typeof firstArg === "string") { + // warnDeprecatedFeature( + // "`clientFn` now expects a signature `function clientFn(context: Context)`. The provided `clientFn` has a signature `function clientFn(cid: string)`" + // ); + + // return (context: Context): Client => { + // const cid = context.nodeConfig["partition-id"] || context.nodeId; + // return clientFn(cid as any); + // }; + // } + + return clientFn; +} + +export class ClientAppException extends Error { + constructor(message: string) { + const exName = "ClientAppException"; + super(`\nException ${exName} occurred. Message: ${message}`); + this.name = exName; + } +} + +export class ClientApp { + private _mods: Mod[]; + private _call: ClientAppCallable | null = null; + private _train: ClientAppCallable | null = null; + private _evaluate: ClientAppCallable | null = null; + private _query: ClientAppCallable | null = null; + + constructor(clientFn?: ClientFnExt, mods?: Mod[]) { + this._mods = mods || []; + + if (clientFn) { + clientFn = inspectMaybeAdaptClientFnSignature(clientFn); + + const ffn: ClientAppCallable = (message, context) => { + return handleLegacyMessageFromMsgType(clientFn!, message, context); + }; + + this._call = makeFFN(ffn, this._mods); + } + } + + call(message: Message, context: Context): Message { + if (this._call) { + return this._call(message, context); + } + + switch (message.metadata.messageType) { + case MessageType.TRAIN: + if (this._train) return this._train(message, context); + throw new Error("No `train` function registered"); + case MessageType.EVALUATE: + if (this._evaluate) return this._evaluate(message, context); + throw new Error("No `evaluate` function registered"); + case MessageType.QUERY: + if (this._query) return this._query(message, context); + throw new Error("No `query` function registered"); + default: + throw new Error(`Unknown message_type: ${message.metadata.messageType}`); + } + } + + train(): (trainFn: ClientAppCallable) => ClientAppCallable { + return (trainFn: ClientAppCallable) => { + if (this._call) { + throw registrationError("train"); + } + + warnPreviewFeature("ClientApp-register-train-function"); + this._train = makeFFN(trainFn, this._mods); + return trainFn; + }; + } + + evaluate(): (evaluateFn: ClientAppCallable) => ClientAppCallable { + return (evaluateFn: ClientAppCallable) => { + if (this._call) { + throw registrationError("evaluate"); + } + + warnPreviewFeature("ClientApp-register-evaluate-function"); + this._evaluate = makeFFN(evaluateFn, this._mods); + return evaluateFn; + }; + } + + query(): (queryFn: ClientAppCallable) => ClientAppCallable { + return (queryFn: ClientAppCallable) => { + if (this._call) { + throw registrationError("query"); + } + + warnPreviewFeature("ClientApp-register-query-function"); + this._query = makeFFN(queryFn, this._mods); + return queryFn; + }; + } +} + +export class LoadClientAppError extends Error { + constructor(message: string) { + super(message); + this.name = "LoadClientAppError"; + } +} + +function registrationError(fnName: string): Error { + return new Error( + `Use either \`@app.${fnName}()\` or \`clientFn\`, but not both.\n\n` + + `Use the \`ClientApp\` with an existing \`clientFn\`:\n\n` + + `\`\`\`\nclass FlowerClient extends NumPyClient {}\n\n` + + `function clientFn(context: Context) {\n` + + ` return new FlowerClient().toClient();\n` + + `}\n\n` + + `const app = new ClientApp({ clientFn });\n\`\`\`\n\n` + + `Use the \`ClientApp\` with a custom ${fnName} function:\n\n` + + `\`\`\`\nconst app = new ClientApp();\n\n` + + `app.${fnName}((message, context) => {\n` + + ` console.log("ClientApp ${fnName} running");\n` + + ` return message.createReply({ content: message.content });\n` + + `});\n\`\`\`\n` + ); +} diff --git a/src/ts/src/lib/client_interceptor.js b/src/ts/src/lib/client_interceptor.js new file mode 100644 index 000000000000..3850ffdebcac --- /dev/null +++ b/src/ts/src/lib/client_interceptor.js @@ -0,0 +1,58 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.AUTH_TOKEN_HEADER = exports.PUBLIC_KEY_HEADER = void 0; +exports.AuthenticateClientInterceptor = AuthenticateClientInterceptor; +const crypto_helpers_1 = require("./crypto_helpers"); +exports.PUBLIC_KEY_HEADER = "public-key"; +exports.AUTH_TOKEN_HEADER = "auth-token"; +// Helper function to extract values from metadata +function getValueFromMetadata(key, metadata) { + const values = metadata[key]; + return values.length > 0 && typeof values[0] === "string" ? values[0] : ""; +} +function base64UrlEncode(buffer) { + return buffer + .toString("base64") // Standard Base64 encoding + .replace(/\+/g, "-") // Replace + with - + .replace(/\//g, "_") // Replace / with _ + .replace(/=+$/, ""); // Remove padding (trailing = characters) +} +function AuthenticateClientInterceptor(privateKey, publicKey) { + let sharedSecret = null; + let serverPublicKey = null; + // Convert the public key to bytes and encode it + const encodedPublicKey = base64UrlEncode((0, crypto_helpers_1.publicKeyToBytes)(publicKey)); + return { + interceptUnary(next, method, input, options) { + // Manipulate metadata before sending the request + const metadata = options.meta || {}; + // Always add the public key to the metadata + metadata[exports.PUBLIC_KEY_HEADER] = encodedPublicKey; + const postprocess = "pingInterval" in input; + // Add HMAC to metadata if a shared secret exists + if (sharedSecret !== null) { + // Assuming the message is already serialized and available at this point + const serializedMessage = method.I.toBinary(input); + const hmac = (0, crypto_helpers_1.computeHMAC)(sharedSecret, Buffer.from(serializedMessage)); + metadata[exports.AUTH_TOKEN_HEADER] = base64UrlEncode(hmac); + } + const continuation = next(method, input, { ...options, meta: metadata }); + if (postprocess) { + handlePostprocess(metadata); + } + return continuation; + }, + }; + function handlePostprocess(metadata) { + const serverPublicKeyBytes = getValueFromMetadata(exports.PUBLIC_KEY_HEADER, metadata); + if (serverPublicKeyBytes.length > 0) { + serverPublicKey = (0, crypto_helpers_1.bytesToPublicKey)(Buffer.from(serverPublicKeyBytes)); + } + else { + console.warn("Couldn't get server public key, server may be offline"); + } + if (serverPublicKey) { + sharedSecret = (0, crypto_helpers_1.generateSharedKey)(privateKey, serverPublicKey); + } + } +} diff --git a/src/ts/src/lib/client_interceptor.test.js b/src/ts/src/lib/client_interceptor.test.js new file mode 100644 index 000000000000..b6f31e34e4ff --- /dev/null +++ b/src/ts/src/lib/client_interceptor.test.js @@ -0,0 +1,138 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +const grpc = __importStar(require("@grpc/grpc-js")); +const crypto_helpers_1 = require("./crypto_helpers"); +const crypto_1 = require("crypto"); +const client_interceptor_1 = require("./client_interceptor"); +const elliptic_1 = require("elliptic"); +const ec = new elliptic_1.ec("p256"); +// Mock Servicer for testing +class MockServicer { + receivedMetadata = null; + messageBytes = null; + serverPrivateKey = null; + serverPublicKey = null; + constructor() { + // Asynchronous key generation using elliptic curve + (0, crypto_1.generateKeyPair)("ec", { namedCurve: "secp256k1" }, (err, publicKey, privateKey) => { + if (err) + throw err; + this.serverPrivateKey = privateKey.export({ format: "pem", type: "pkcs8" }).toString(); + this.serverPublicKey = publicKey.export({ format: "pem", type: "spki" }).toString(); + }); + } + handleUnaryCall(call, callback) { + this.receivedMetadata = call.metadata; + this.messageBytes = call.request.serializeBinary(); + const publicKeyBytes = (0, crypto_helpers_1.publicKeyToBytes)(ec.keyFromPublic(this.serverPublicKey)); // ECC key handling + if ("pingInterval" in call.request) { + const responseMetadata = new grpc.Metadata(); + responseMetadata.add(client_interceptor_1.PUBLIC_KEY_HEADER, Buffer.from(publicKeyBytes).toString("base64")); + callback(null, { nodeId: 123 }, responseMetadata); + } + else if ("node" in call.request) { + callback(null, {}, undefined); + } + else { + callback({ code: grpc.status.INVALID_ARGUMENT, message: "Unknown request" }); + } + } + getReceivedMetadata() { + return this.receivedMetadata; + } + getMessageBytes() { + return this.messageBytes; + } + getServerPublicKey() { + return this.serverPublicKey; + } + getServerPrivateKey() { + return this.serverPrivateKey; + } +} +// Setup and teardown for tests +let mockServer; +beforeAll(() => { + mockServer = new MockServicer(); +}); +afterAll(() => { + // Stop server if necessary +}); +// Test: Authenticate Client with Create Node +// test('should authenticate client with create node', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairSync('rsa'); // Assume key generation logic +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); +// const transport = new GrpcTransport({ +// host: "localhost:50051", +// channelCredentials: grpc.credentials.createInsecure(), +// interceptors: [interceptor], +// }); +// const client = new FleetClient(transport); +// // const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// // interceptors: [interceptor], +// // }); +// const request = CreateNodeRequest.create(); +// const response = await client.unaryUnary(request); +// client.makeUnaryRequest(request) +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata.get(PUBLIC_KEY_HEADER)).toBeTruthy(); +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); +// expect(receivedMetadata.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); +// // Test: Authenticate Client with Delete Node +// test('should authenticate client with delete node', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairs(); +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); +// const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// interceptors: [interceptor], +// }); +// const request = DeleteNodeRequest.create(); +// const response = await client.unaryUnary(request); +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata!.get(PUBLIC_KEY_HEADER)).toBeTruthy(); +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); +// expect(receivedMetadata!.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); +// // Test: Authenticate Client with Get Run +// test('should authenticate client with get run', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairs(); +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); +// const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// interceptors: [interceptor], +// }); +// const request = GetRunRequest.create(); +// const response = await client.unaryUnary(request); +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata.get(PUBLIC_KEY_HEADER)).toBeTruthy(); +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); +// expect(receivedMetadata.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); diff --git a/src/ts/src/lib/client_interceptor.test.ts b/src/ts/src/lib/client_interceptor.test.ts new file mode 100644 index 000000000000..7990ce2aed27 --- /dev/null +++ b/src/ts/src/lib/client_interceptor.test.ts @@ -0,0 +1,152 @@ +import * as grpc from "@grpc/grpc-js"; +import { computeHMAC, generateSharedKey, publicKeyToBytes } from "./crypto_helpers"; +import { generateKeyPair, generateKeyPairSync } from "crypto"; +import { + AUTH_TOKEN_HEADER, + PUBLIC_KEY_HEADER, + AuthenticateClientInterceptor, +} from "./client_interceptor"; +import { GrpcTransport } from "@protobuf-ts/grpc-transport"; +import { CreateNodeRequest, DeleteNodeRequest } from "../protos/flwr/proto/fleet"; +import { FleetClient } from "../protos/flwr/proto/fleet.client"; +import { GetRunRequest } from "../protos/flwr/proto/run"; +import { ec as EC } from "elliptic"; + +const ec = new EC("p256"); + +// Mock Servicer for testing +class MockServicer { + private receivedMetadata: grpc.Metadata | null = null; + private messageBytes: Buffer | null = null; + private serverPrivateKey: string | null = null; + private serverPublicKey: string | null = null; + + constructor() { + // Asynchronous key generation using elliptic curve + generateKeyPair("ec", { namedCurve: "secp256k1" }, (err, publicKey, privateKey) => { + if (err) throw err; + this.serverPrivateKey = privateKey.export({ format: "pem", type: "pkcs8" }).toString(); + this.serverPublicKey = publicKey.export({ format: "pem", type: "spki" }).toString(); + }); + } + + handleUnaryCall(call: grpc.ServerUnaryCall, callback: grpc.sendUnaryData) { + this.receivedMetadata = call.metadata; + this.messageBytes = call.request.serializeBinary(); + + const publicKeyBytes = publicKeyToBytes(ec.keyFromPublic(this.serverPublicKey!)); // ECC key handling + + if ("pingInterval" in call.request) { + const responseMetadata = new grpc.Metadata(); + responseMetadata.add(PUBLIC_KEY_HEADER, Buffer.from(publicKeyBytes).toString("base64")); + callback(null, { nodeId: 123 }, responseMetadata); + } else if ("node" in call.request) { + callback(null, {}, undefined); + } else { + callback({ code: grpc.status.INVALID_ARGUMENT, message: "Unknown request" }); + } + } + + getReceivedMetadata() { + return this.receivedMetadata; + } + + getMessageBytes() { + return this.messageBytes; + } + + getServerPublicKey() { + return this.serverPublicKey; + } + + getServerPrivateKey() { + return this.serverPrivateKey; + } +} + +// Setup and teardown for tests +let mockServer: MockServicer; + +beforeAll(() => { + mockServer = new MockServicer(); +}); + +afterAll(() => { + // Stop server if necessary +}); + +// Test: Authenticate Client with Create Node +// test('should authenticate client with create node', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairSync('rsa'); // Assume key generation logic + +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); +// const transport = new GrpcTransport({ +// host: "localhost:50051", +// channelCredentials: grpc.credentials.createInsecure(), +// interceptors: [interceptor], +// }); +// const client = new FleetClient(transport); + +// // const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// // interceptors: [interceptor], +// // }); + +// const request = CreateNodeRequest.create(); +// const response = await client.unaryUnary(request); +// client.makeUnaryRequest(request) + +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata.get(PUBLIC_KEY_HEADER)).toBeTruthy(); + +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); + +// expect(receivedMetadata.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); + +// // Test: Authenticate Client with Delete Node +// test('should authenticate client with delete node', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairs(); + +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); + +// const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// interceptors: [interceptor], +// }); + +// const request = DeleteNodeRequest.create(); +// const response = await client.unaryUnary(request); + +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata!.get(PUBLIC_KEY_HEADER)).toBeTruthy(); + +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); + +// expect(receivedMetadata!.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); + +// // Test: Authenticate Client with Get Run +// test('should authenticate client with get run', async () => { +// const retryInvoker = {}; // Mock retry invoker +// const { privateKey, publicKey } = generateKeyPairs(); + +// const interceptor = AuthenticateClientInterceptor(privateKey, publicKey); + +// const client = new grpc.Client('localhost:50051', grpc.credentials.createInsecure(), { +// interceptors: [interceptor], +// }); + +// const request = GetRunRequest.create(); +// const response = await client.unaryUnary(request); + +// const receivedMetadata = mockServer.getReceivedMetadata(); +// expect(receivedMetadata.get(PUBLIC_KEY_HEADER)).toBeTruthy(); + +// const sharedSecret = generateSharedKey(mockServer.getServerPrivateKey(), publicKey); +// const hmac = computeHMAC(sharedSecret, mockServer.getMessageBytes()); + +// expect(receivedMetadata.get(AUTH_TOKEN_HEADER)).toEqual(hmac); +// }); diff --git a/src/ts/src/lib/client_interceptor.ts b/src/ts/src/lib/client_interceptor.ts new file mode 100644 index 000000000000..b1e3f9e0e095 --- /dev/null +++ b/src/ts/src/lib/client_interceptor.ts @@ -0,0 +1,88 @@ +import { ec as EC } from "elliptic"; +import { + MethodInfo, + RpcMetadata, + UnaryCall, + NextUnaryFn, + RpcOptions, + RpcInterceptor, +} from "@protobuf-ts/runtime-rpc"; +import { + computeHMAC, + bytesToPublicKey, + publicKeyToBytes, + generateSharedKey, +} from "./crypto_helpers"; + +export const PUBLIC_KEY_HEADER = "public-key"; +export const AUTH_TOKEN_HEADER = "auth-token"; + +// Helper function to extract values from metadata +function getValueFromMetadata(key: string, metadata: RpcMetadata): string { + const values = metadata[key]; + return values.length > 0 && typeof values[0] === "string" ? values[0] : ""; +} + +function base64UrlEncode(buffer: Buffer): string { + return buffer + .toString("base64") // Standard Base64 encoding + .replace(/\+/g, "-") // Replace + with - + .replace(/\//g, "_") // Replace / with _ + .replace(/=+$/, ""); // Remove padding (trailing = characters) +} + +export function AuthenticateClientInterceptor( + privateKey: EC.KeyPair, + publicKey: EC.KeyPair, +): RpcInterceptor { + let sharedSecret: Buffer | null = null; + let serverPublicKey: EC.KeyPair | null = null; + + // Convert the public key to bytes and encode it + const encodedPublicKey = base64UrlEncode(publicKeyToBytes(publicKey)); + + return { + interceptUnary( + next: NextUnaryFn, + method: MethodInfo, + input: object, + options: RpcOptions, + ): UnaryCall { + // Manipulate metadata before sending the request + const metadata: RpcMetadata = options.meta || {}; + + // Always add the public key to the metadata + metadata[PUBLIC_KEY_HEADER] = encodedPublicKey; + + const postprocess = "pingInterval" in input; + + // Add HMAC to metadata if a shared secret exists + if (sharedSecret !== null) { + // Assuming the message is already serialized and available at this point + const serializedMessage = method.I.toBinary(input); + const hmac = computeHMAC(sharedSecret, Buffer.from(serializedMessage)); + metadata[AUTH_TOKEN_HEADER] = base64UrlEncode(hmac); + } + + const continuation = next(method, input, { ...options, meta: metadata }); + if (postprocess) { + handlePostprocess(metadata); + } + return continuation; + }, + }; + + function handlePostprocess(metadata: RpcMetadata): void { + const serverPublicKeyBytes = getValueFromMetadata(PUBLIC_KEY_HEADER, metadata); + + if (serverPublicKeyBytes.length > 0) { + serverPublicKey = bytesToPublicKey(Buffer.from(serverPublicKeyBytes)); + } else { + console.warn("Couldn't get server public key, server may be offline"); + } + + if (serverPublicKey) { + sharedSecret = generateSharedKey(privateKey, serverPublicKey); + } + } +} diff --git a/src/ts/src/lib/config.js b/src/ts/src/lib/config.js new file mode 100644 index 000000000000..5b322f7868dc --- /dev/null +++ b/src/ts/src/lib/config.js @@ -0,0 +1,363 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.getFabConfig = getFabConfig; +exports.getFabMetadata = getFabMetadata; +exports.loadAndValidate = loadAndValidate; +exports.load = load; +exports.validateFields = validateFields; +exports.validate = validate; +exports.loadFromString = loadFromString; +exports.getFlwrDir = getFlwrDir; +exports.getProjectDir = getProjectDir; +exports.getProjectConfig = getProjectConfig; +exports.fuseDicts = fuseDicts; +exports.getFusedConfigFromDir = getFusedConfigFromDir; +exports.getFusedConfigFromFab = getFusedConfigFromFab; +exports.getFusedConfig = getFusedConfig; +exports.flattenDict = flattenDict; +exports.unflattenDict = unflattenDict; +exports.parseConfigArgs = parseConfigArgs; +exports.getMetadataFromConfig = getMetadataFromConfig; +const fs = __importStar(require("fs")); +const os = __importStar(require("os")); +const path = __importStar(require("path")); +const toml = __importStar(require("@iarna/toml")); +const constants_1 = require("./constants"); +function getFabConfig(fabFile) { + /** + * Extract the config from a FAB file or path. + * @param fabFile The Flower App Bundle file to validate and extract the metadata from. + * It can either be a path to the file or the file itself as bytes. + */ + // let fabFileArchive: PathLike | Buffer; + // if (Buffer.isBuffer(fabFile)) { + // fabFileArchive = fabFile; + // } else if (typeof fabFile === 'string') { + // fabFileArchive = path.resolve(fabFile); + // } else { + // throw new Error('fabFile must be either a Path or Buffer'); + // } + // // Unzip the FAB file to read pyproject.toml + // const zip = open(fabFileArchive.toString()); + // const zipFile = new ZipFile(zip); + // let tomlContent = ''; + // // Read pyproject.toml from the archive + // zipFile.on('entry', (entry) => { + // if (entry.fileName === 'pyproject.toml') { + // const readStream = zipFile.openReadStream(entry); + // if (readStream) { + // readStream.pipe(zlib.createGunzip()).on('data', (data) => { + // tomlContent += data.toString('utf-8'); + // }); + // } + // } + // }); + // // Load TOML content + // const config = loadFromString(tomlContent); + // if (!config) { + // throw new Error('Invalid TOML content in pyproject.toml'); + // } + // const [isValid, errors] = validate(config, false); + // if (!isValid) { + // throw new Error(errors.join('\n')); + // } + // return config; + return {}; +} +function getFabMetadata(fabFile) { + /** + * Extract the fab_id and fab_version from a FAB file or path. + * @param fabFile The Flower App Bundle file to validate and extract the metadata from. + * It can either be a path to the file or the file itself as bytes. + */ + const config = getFabConfig(fabFile); + return [ + `${config['tool']['flwr']['app']['publisher']}/${config['project']['name']}`, + config['project']['version'], + ]; +} +function loadAndValidate(providedPath = null, checkModule = true) { + /** + * Load and validate pyproject.toml as a dictionary. + * @param providedPath Optional path to pyproject.toml. + * @param checkModule Whether to check module validity. + */ + const configPath = providedPath ? path.resolve(providedPath.toString()) : path.join(process.cwd(), 'pyproject.toml'); + const config = load(configPath); + if (!config) { + return [ + null, + ['Project configuration could not be loaded. `pyproject.toml` does not exist.'], + [], + ]; + } + const [isValid, errors, warnings] = validate(config, checkModule, path.dirname(configPath)); + if (!isValid) { + return [null, errors, warnings]; + } + return [config, errors, warnings]; +} +function load(tomlPath) { + /** + * Load pyproject.toml and return as a dictionary. + */ + if (!fs.existsSync(tomlPath)) { + return null; + } + const tomlContent = fs.readFileSync(tomlPath, { encoding: 'utf-8' }); + return loadFromString(tomlContent); +} +function _validateRunConfig(configDict, errors) { + for (const [key, value] of Object.entries(configDict)) { + if (typeof value === 'object' && !Array.isArray(value)) { + _validateRunConfig(value, errors); + } + // else if (!getArgs(UserConfigValue).includes(typeof value)) { + // errors.push( + // `The value for key ${key} needs to be of type int, float, bool, string, or a dict of those.` + // ); + // } + } +} +function validateFields(config) { + /** + * Validate pyproject.toml fields. + */ + const errors = []; + const warnings = []; + if (!config['project']) { + errors.push('Missing [project] section'); + } + else { + if (!config['project']['name']) { + errors.push('Property "name" missing in [project]'); + } + if (!config['project']['version']) { + errors.push('Property "version" missing in [project]'); + } + if (!config['project']['description']) { + warnings.push('Recommended property "description" missing in [project]'); + } + if (!config['project']['license']) { + warnings.push('Recommended property "license" missing in [project]'); + } + if (!config['project']['authors']) { + warnings.push('Recommended property "authors" missing in [project]'); + } + } + if (!config['tool'] || + !config['tool']['flwr'] || + !config['tool']['flwr']['app']) { + errors.push('Missing [tool.flwr.app] section'); + } + else { + if (!config['tool']['flwr']['app']['publisher']) { + errors.push('Property "publisher" missing in [tool.flwr.app]'); + } + if (config['tool']['flwr']['app']['config']) { + _validateRunConfig(config['tool']['flwr']['app']['config'], errors); + } + if (!config['tool']['flwr']['app']['components']) { + errors.push('Missing [tool.flwr.app.components] section'); + } + else { + if (!config['tool']['flwr']['app']['components']['serverapp']) { + errors.push('Property "serverapp" missing in [tool.flwr.app.components]'); + } + if (!config['tool']['flwr']['app']['components']['clientapp']) { + errors.push('Property "clientapp" missing in [tool.flwr.app.components]'); + } + } + } + return [errors.length === 0, errors, warnings]; +} +function validate(config, checkModule = true, projectDir = null) { + /** + * Validate pyproject.toml. + */ + const [isValid, errors, warnings] = validateFields(config); + if (!isValid) { + return [false, errors, warnings]; + } + // Validate serverapp + const serverappRef = config['tool']['flwr']['app']['components']['serverapp']; + // const [serverIsValid, serverReason] = objectRef.validate(serverappRef, checkModule, projectDir); + // if (!serverIsValid && typeof serverReason === 'string') { + // return [false, [serverReason], []]; + // } + // Validate clientapp + const clientappRef = config['tool']['flwr']['app']['components']['clientapp']; + // const [clientIsValid, clientReason] = objectRef.validate(clientappRef, checkModule, projectDir); + // if (!clientIsValid && typeof clientReason === 'string') { + // return [false, [clientReason], []]; + // } + return [true, [], []]; +} +function loadFromString(tomlContent) { + /** + * Load TOML content from a string and return as a dictionary. + */ + try { + return toml.parse(tomlContent); + } + catch (error) { + return null; + } +} +// Get Flower home directory based on environment variables +function getFlwrDir(providedPath) { + if (!providedPath || !fs.existsSync(providedPath)) { + return path.join(process.env[constants_1.FLWR_HOME] || path.join(process.env['XDG_DATA_HOME'] || os.homedir(), '.flwr')); + } + return path.resolve(providedPath); +} +// Return the project directory based on fab_id and fab_version +function getProjectDir(fabId, fabVersion, flwrDir) { + if ((fabId.match(/\//g) || []).length !== 1) { + throw new Error(`Invalid FAB ID: ${fabId}`); + } + const [publisher, projectName] = fabId.split('/'); + flwrDir = flwrDir || getFlwrDir(); + return path.join(flwrDir, constants_1.APP_DIR, publisher, projectName, fabVersion); +} +// Return pyproject.toml configuration from the project directory +function getProjectConfig(projectDir) { + const tomlPath = path.join(projectDir, constants_1.FAB_CONFIG_FILE); + if (!fs.existsSync(tomlPath)) { + throw new Error(`Cannot find ${constants_1.FAB_CONFIG_FILE} in ${projectDir}`); + } + const fileContents = fs.readFileSync(tomlPath, 'utf8'); + const config = toml.parse(fileContents); + const [isValid, _warnings, errors] = validateFields(config); + if (!isValid) { + const errorMsg = errors.map((error) => ` - ${error}`).join('\n'); + throw new Error(`Invalid ${constants_1.FAB_CONFIG_FILE}:\n${errorMsg}`); + } + return config; +} +// Merge a config with the overrides +function fuseDicts(mainDict, overrideDict) { + const fusedDict = { ...mainDict }; + Object.entries(overrideDict).forEach(([key, value]) => { + if (mainDict.hasOwnProperty(key)) { + fusedDict[key] = value; + } + }); + return fusedDict; +} +// Merge overrides from a given dict with the config from a Flower App +function getFusedConfigFromDir(projectDir, overrideConfig) { + const defaultConfig = getProjectConfig(projectDir)['tool']['flwr']['app']?.config || {}; + const flatDefaultConfig = flattenDict(defaultConfig); + return fuseDicts(flatDefaultConfig, overrideConfig); +} +// Merge default config from a FAB with overrides in a Run +function getFusedConfigFromFab(fabFile, run) { + const defaultConfig = getFabConfig(fabFile)['tool']['flwr']['app']?.config || {}; + const flatConfig = flattenDict(defaultConfig); + return fuseDicts(flatConfig, run.overrideConfig); +} +// Merge overrides from a Run with the config from a FAB +function getFusedConfig(run, flwrDir) { + if (!run.fabId || !run.fabVersion) { + return {}; + } + const projectDir = getProjectDir(run.fabId, run.fabVersion, flwrDir); + if (!fs.existsSync(projectDir)) { + return {}; + } + return getFusedConfigFromDir(projectDir, run.overrideConfig); +} +// Flatten a nested dictionary by joining nested keys with a separator +function flattenDict(rawDict, parentKey = '') { + if (!rawDict) { + return {}; + } + const items = []; + const separator = '.'; + Object.entries(rawDict).forEach(([key, value]) => { + const newKey = parentKey ? `${parentKey}${separator}${key}` : key; + if (typeof value === 'object' && !Array.isArray(value)) { + items.push(...Object.entries(flattenDict(value, newKey))); + } + else { + items.push([newKey, value]); + } + }); + return Object.fromEntries(items); +} +// Unflatten a dictionary with keys containing separators into a nested dictionary +function unflattenDict(flatDict) { + const unflattenedDict = {}; + const separator = '.'; + Object.entries(flatDict).forEach(([key, value]) => { + const parts = key.split(separator); + let current = unflattenedDict; + parts.forEach((part, idx) => { + if (idx === parts.length - 1) { + current[part] = value; + } + else { + if (!current[part]) { + current[part] = {}; + } + current = current[part]; + } + }); + }); + return unflattenedDict; +} +// Parse a list of key-value pairs separated by '=' or load a TOML file +function parseConfigArgs(config) { + let overrides = {}; + if (!config) { + return overrides; + } + if (config.length === 1 && config[0].endsWith('.toml')) { + const fileContents = fs.readFileSync(config[0], 'utf8'); + overrides = flattenDict(toml.parse(fileContents)); + return overrides; + } + const pattern = /(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)/g; + config.forEach((configLine) => { + if (configLine) { + if (configLine.endsWith('.toml')) { + throw new Error('TOML files cannot be passed alongside key-value pairs.'); + } + const matches = Array.from(configLine.matchAll(pattern)); + const tomlStr = matches.map(([_, k, v]) => `${k} = ${v}`).join('\n'); + Object.assign(overrides, toml.parse(tomlStr)); + } + }); + return overrides; +} +// Extract `fab_version` and `fab_id` from a project config +function getMetadataFromConfig(config) { + return [ + config['project']['version'], + `${config['tool']['flwr']['app']['publisher']}/${config['project']['name']}`, + ]; +} diff --git a/src/ts/src/lib/config.test.js b/src/ts/src/lib/config.test.js new file mode 100644 index 000000000000..4a7eacd1073d --- /dev/null +++ b/src/ts/src/lib/config.test.js @@ -0,0 +1,166 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +const path = __importStar(require("path")); +const config_1 = require("./config"); +// Mock constants +const FAB_CONFIG_FILE = 'pyproject.toml'; +describe('Configuration Utilities', () => { + beforeEach(() => { + jest.resetModules(); + }); + it('test_get_flwr_dir_with_provided_path', () => { + const providedPath = '.'; + expect((0, config_1.getFlwrDir)(providedPath)).toBe(path.resolve(providedPath)); + }); + it('test_get_flwr_dir_without_provided_path', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ HOME: '/home/user' }); + expect((0, config_1.getFlwrDir)()).toBe(path.join('/home/user', '.flwr')); + }); + it('test_get_flwr_dir_with_flwr_home', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ FLWR_HOME: '/custom/flwr/home' }); + expect((0, config_1.getFlwrDir)()).toBe(path.join('/custom/flwr/home')); + }); + it('test_get_flwr_dir_with_xdg_data_home', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ XDG_DATA_HOME: '/custom/data/home' }); + expect((0, config_1.getFlwrDir)()).toBe(path.join('/custom/data/home', '.flwr')); + }); + it('test_get_project_dir_invalid_fab_id', () => { + expect(() => { + (0, config_1.getProjectDir)('invalid_fab_id', '1.0.0'); + }).toThrow(Error); + }); + it('test_get_project_dir_valid', () => { + const appPath = (0, config_1.getProjectDir)('app_name/user', '1.0.0', '.'); + expect(appPath).toBe(path.join('.', 'apps', 'app_name', 'user', '1.0.0')); + }); + // it('test_get_project_config_file_not_found', () => { + // expect(() => { + // getProjectConfig('/invalid/dir'); + // }).toThrow(Error); + // }); + // it('test_get_fused_config_valid', () => { + // const pyprojectTomlContent = ` + // [build-system] + // requires = ["hatchling"] + // build-backend = "hatchling.build" + // [project] + // name = "fedgpt" + // version = "1.0.0" + // [tool.flwr.app] + // publisher = "flwrlabs" + // [tool.flwr.app.config] + // num_server_rounds = 10 + // momentum = 0.1 + // lr = 0.01 + // progress_bar = true + // `; + // const overrides: UserConfig = { + // num_server_rounds: 5, + // lr: 0.2, + // "serverapp.test": "overriden", + // }; + // const expectedConfig = { + // num_server_rounds: 5, + // momentum: 0.1, + // lr: 0.2, + // progress_bar: true, + // "serverapp.test": "overriden", + // "clientapp.test": "key", + // }; + // const tmpPath = path.join(tmpdir(), 'project_dir'); + // fs.mkdirSync(tmpPath); + // const tomlPath = path.join(tmpPath, FAB_CONFIG_FILE); + // fs.writeFileSync(tomlPath, pyprojectTomlContent); + // try { + // const defaultConfig = getProjectConfig(tmpPath)['tool']['flwr']['app'].config || {}; + // const config = fuseDicts(flattenDict(defaultConfig), overrides); + // expect(config).toEqual(expectedConfig); + // } finally { + // fs.rmdirSync(tmpPath, { recursive: true }); + // } + // }); + // it('test_flatten_dict', () => { + // const rawDict = { a: { b: { c: 'd' } }, e: 'f' }; + // const expected = { 'a.b.c': 'd', e: 'f' }; + // expect(flattenDict(rawDict)).toEqual(expected); + // }); + // it('test_unflatten_dict', () => { + // const rawDict = { 'a.b.c': 'd', e: 'f' }; + // const expected = { a: { b: { c: 'd' } }, e: 'f' }; + // expect(unflattenDict(rawDict)).toEqual(expected); + // }); + // it('test_parse_config_args_none', () => { + // expect(parseConfigArgs(undefined)).toEqual({}); + // }); + // it('test_parse_config_args_overrides', () => { + // const config = parseConfigArgs([ + // "key1='value1' key2='value2'", + // 'key3=1', + // "key4=2.0 key5=true key6='value6'", + // ]); + // const expected = { + // key1: 'value1', + // key2: 'value2', + // key3: 1, + // key4: 2.0, + // key5: true, + // key6: 'value6', + // }; + // expect(config).toEqual(expected); + // }); + // it('test_parse_config_args_from_toml_file', () => { + // const tomlConfig = ` + // num_server_rounds = 10 + // momentum = 0.1 + // verbose = true + // `; + // const initialRunConfig: UserConfig = { + // "num_server_rounds": 5, + // "momentum": 0.2, + // "dataset": "my-fancy-dataset", + // "verbose": false, + // }; + // const expectedConfig = { + // "num_server_rounds": 10, + // "momentum": 0.1, + // "dataset": "my-fancy-dataset", + // "verbose": true, + // }; + // const tmpPath = tmpdir(); + // const tomlConfigFile = path.join(tmpPath, 'extra_config.toml'); + // fs.writeFileSync(tomlConfigFile, tomlConfig); + // const configFromToml = parseConfigArgs([tomlConfigFile]); + // const config = fuseDicts(initialRunConfig, configFromToml); + // expect(config).toEqual(expectedConfig); + // fs.unlinkSync(tomlConfigFile); + // }); + // it('test_parse_config_args_passing_toml_and_key_value', () => { + // const config = ['my-other-config.toml', 'lr=0.1', 'epochs=99']; + // expect(() => { + // parseConfigArgs(config); + // }).toThrow(Error); + // }); +}); diff --git a/src/ts/src/lib/config.test.ts b/src/ts/src/lib/config.test.ts new file mode 100644 index 000000000000..54c3ed2116f1 --- /dev/null +++ b/src/ts/src/lib/config.test.ts @@ -0,0 +1,180 @@ +import * as fs from 'fs'; +import * as path from 'path'; +import { tmpdir } from 'os'; +import { + flattenDict, + fuseDicts, + getFlwrDir, + getProjectConfig, + getProjectDir, + parseConfigArgs, + unflattenDict, +} from './config'; +import { UserConfig } from './typing'; + +// Mock constants +const FAB_CONFIG_FILE = 'pyproject.toml'; + +describe('Configuration Utilities', () => { + + beforeEach(() => { + jest.resetModules(); + }); + + it('test_get_flwr_dir_with_provided_path', () => { + const providedPath = '.'; + expect(getFlwrDir(providedPath)).toBe(path.resolve(providedPath)); + }); + + it('test_get_flwr_dir_without_provided_path', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ HOME: '/home/user' }); + expect(getFlwrDir()).toBe(path.join('/home/user', '.flwr')); + }); + + it('test_get_flwr_dir_with_flwr_home', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ FLWR_HOME: '/custom/flwr/home' }); + expect(getFlwrDir()).toBe(path.join('/custom/flwr/home')); + }); + + it('test_get_flwr_dir_with_xdg_data_home', () => { + jest.spyOn(process, 'env', 'get').mockReturnValue({ XDG_DATA_HOME: '/custom/data/home' }); + expect(getFlwrDir()).toBe(path.join('/custom/data/home', '.flwr')); + }); + + it('test_get_project_dir_invalid_fab_id', () => { + expect(() => { + getProjectDir('invalid_fab_id', '1.0.0'); + }).toThrow(Error); + }); + + it('test_get_project_dir_valid', () => { + const appPath = getProjectDir('app_name/user', '1.0.0', '.'); + expect(appPath).toBe(path.join('.', 'apps', 'app_name', 'user', '1.0.0')); + }); + + // it('test_get_project_config_file_not_found', () => { + // expect(() => { + // getProjectConfig('/invalid/dir'); + // }).toThrow(Error); + // }); + + // it('test_get_fused_config_valid', () => { + // const pyprojectTomlContent = ` + // [build-system] + // requires = ["hatchling"] + // build-backend = "hatchling.build" + + // [project] + // name = "fedgpt" + // version = "1.0.0" + + // [tool.flwr.app] + // publisher = "flwrlabs" + + // [tool.flwr.app.config] + // num_server_rounds = 10 + // momentum = 0.1 + // lr = 0.01 + // progress_bar = true + // `; + // const overrides: UserConfig = { + // num_server_rounds: 5, + // lr: 0.2, + // "serverapp.test": "overriden", + // }; + // const expectedConfig = { + // num_server_rounds: 5, + // momentum: 0.1, + // lr: 0.2, + // progress_bar: true, + // "serverapp.test": "overriden", + // "clientapp.test": "key", + // }; + + // const tmpPath = path.join(tmpdir(), 'project_dir'); + // fs.mkdirSync(tmpPath); + // const tomlPath = path.join(tmpPath, FAB_CONFIG_FILE); + // fs.writeFileSync(tomlPath, pyprojectTomlContent); + + // try { + // const defaultConfig = getProjectConfig(tmpPath)['tool']['flwr']['app'].config || {}; + // const config = fuseDicts(flattenDict(defaultConfig), overrides); + + // expect(config).toEqual(expectedConfig); + // } finally { + // fs.rmdirSync(tmpPath, { recursive: true }); + // } + // }); + + // it('test_flatten_dict', () => { + // const rawDict = { a: { b: { c: 'd' } }, e: 'f' }; + // const expected = { 'a.b.c': 'd', e: 'f' }; + // expect(flattenDict(rawDict)).toEqual(expected); + // }); + + // it('test_unflatten_dict', () => { + // const rawDict = { 'a.b.c': 'd', e: 'f' }; + // const expected = { a: { b: { c: 'd' } }, e: 'f' }; + // expect(unflattenDict(rawDict)).toEqual(expected); + // }); + + // it('test_parse_config_args_none', () => { + // expect(parseConfigArgs(undefined)).toEqual({}); + // }); + + // it('test_parse_config_args_overrides', () => { + // const config = parseConfigArgs([ + // "key1='value1' key2='value2'", + // 'key3=1', + // "key4=2.0 key5=true key6='value6'", + // ]); + // const expected = { + // key1: 'value1', + // key2: 'value2', + // key3: 1, + // key4: 2.0, + // key5: true, + // key6: 'value6', + // }; + // expect(config).toEqual(expected); + // }); + + // it('test_parse_config_args_from_toml_file', () => { + // const tomlConfig = ` + // num_server_rounds = 10 + // momentum = 0.1 + // verbose = true + // `; + + // const initialRunConfig: UserConfig = { + // "num_server_rounds": 5, + // "momentum": 0.2, + // "dataset": "my-fancy-dataset", + // "verbose": false, + // }; + // const expectedConfig = { + // "num_server_rounds": 10, + // "momentum": 0.1, + // "dataset": "my-fancy-dataset", + // "verbose": true, + // }; + + // const tmpPath = tmpdir(); + // const tomlConfigFile = path.join(tmpPath, 'extra_config.toml'); + // fs.writeFileSync(tomlConfigFile, tomlConfig); + + // const configFromToml = parseConfigArgs([tomlConfigFile]); + // const config = fuseDicts(initialRunConfig, configFromToml); + + // expect(config).toEqual(expectedConfig); + + // fs.unlinkSync(tomlConfigFile); + // }); + + // it('test_parse_config_args_passing_toml_and_key_value', () => { + // const config = ['my-other-config.toml', 'lr=0.1', 'epochs=99']; + // expect(() => { + // parseConfigArgs(config); + // }).toThrow(Error); + // }); +}); diff --git a/src/ts/src/lib/config.ts b/src/ts/src/lib/config.ts new file mode 100644 index 000000000000..171983f9e406 --- /dev/null +++ b/src/ts/src/lib/config.ts @@ -0,0 +1,395 @@ +import * as fs from 'fs'; +import * as os from 'os'; +import * as path from 'path'; +import * as toml from '@iarna/toml'; +import { APP_DIR, FAB_CONFIG_FILE, FLWR_HOME } from './constants'; +import { Run, UserConfig, UserConfigValue } from './typing'; + +import { PathLike } from 'fs'; +import { ZipFile, open } from 'yauzl'; +import * as zlib from 'zlib'; + +type AnyDict = Record; + +export function getFabConfig(fabFile: PathLike | Buffer): AnyDict { + /** + * Extract the config from a FAB file or path. + * @param fabFile The Flower App Bundle file to validate and extract the metadata from. + * It can either be a path to the file or the file itself as bytes. + */ + // let fabFileArchive: PathLike | Buffer; + // if (Buffer.isBuffer(fabFile)) { + // fabFileArchive = fabFile; + // } else if (typeof fabFile === 'string') { + // fabFileArchive = path.resolve(fabFile); + // } else { + // throw new Error('fabFile must be either a Path or Buffer'); + // } + + // // Unzip the FAB file to read pyproject.toml + // const zip = open(fabFileArchive.toString()); + // const zipFile = new ZipFile(zip); + // let tomlContent = ''; + + // // Read pyproject.toml from the archive + // zipFile.on('entry', (entry) => { + // if (entry.fileName === 'pyproject.toml') { + // const readStream = zipFile.openReadStream(entry); + // if (readStream) { + // readStream.pipe(zlib.createGunzip()).on('data', (data) => { + // tomlContent += data.toString('utf-8'); + // }); + // } + // } + // }); + + // // Load TOML content + // const config = loadFromString(tomlContent); + // if (!config) { + // throw new Error('Invalid TOML content in pyproject.toml'); + // } + + // const [isValid, errors] = validate(config, false); + // if (!isValid) { + // throw new Error(errors.join('\n')); + // } + + // return config; + return {}; +} + +export function getFabMetadata(fabFile: PathLike | Buffer): [string, string] { + /** + * Extract the fab_id and fab_version from a FAB file or path. + * @param fabFile The Flower App Bundle file to validate and extract the metadata from. + * It can either be a path to the file or the file itself as bytes. + */ + const config = getFabConfig(fabFile); + + return [ + `${config['tool']['flwr']['app']['publisher']}/${config['project']['name']}`, + config['project']['version'], + ]; +} + +export function loadAndValidate( + providedPath: PathLike | null = null, + checkModule: boolean = true +): [AnyDict | null, string[], string[]] { + /** + * Load and validate pyproject.toml as a dictionary. + * @param providedPath Optional path to pyproject.toml. + * @param checkModule Whether to check module validity. + */ + const configPath = providedPath ? path.resolve(providedPath.toString()) : path.join(process.cwd(), 'pyproject.toml'); + + const config = load(configPath); + + if (!config) { + return [ + null, + ['Project configuration could not be loaded. `pyproject.toml` does not exist.'], + [], + ]; + } + + const [isValid, errors, warnings] = validate(config, checkModule, path.dirname(configPath)); + if (!isValid) { + return [null, errors, warnings]; + } + + return [config, errors, warnings]; +} + +export function load(tomlPath: PathLike): AnyDict | null { + /** + * Load pyproject.toml and return as a dictionary. + */ + if (!fs.existsSync(tomlPath)) { + return null; + } + + const tomlContent = fs.readFileSync(tomlPath, { encoding: 'utf-8' }); + return loadFromString(tomlContent); +} + +function _validateRunConfig(configDict: AnyDict, errors: string[]): void { + for (const [key, value] of Object.entries(configDict)) { + if (typeof value === 'object' && !Array.isArray(value)) { + _validateRunConfig(value, errors); + } + // else if (!getArgs(UserConfigValue).includes(typeof value)) { + // errors.push( + // `The value for key ${key} needs to be of type int, float, bool, string, or a dict of those.` + // ); + // } + } +} + +export function validateFields(config: AnyDict): [boolean, string[], string[]] { + /** + * Validate pyproject.toml fields. + */ + const errors: string[] = []; + const warnings: string[] = []; + + if (!config['project']) { + errors.push('Missing [project] section'); + } else { + if (!config['project']['name']) { + errors.push('Property "name" missing in [project]'); + } + if (!config['project']['version']) { + errors.push('Property "version" missing in [project]'); + } + if (!config['project']['description']) { + warnings.push('Recommended property "description" missing in [project]'); + } + if (!config['project']['license']) { + warnings.push('Recommended property "license" missing in [project]'); + } + if (!config['project']['authors']) { + warnings.push('Recommended property "authors" missing in [project]'); + } + } + + if ( + !config['tool'] || + !config['tool']['flwr'] || + !config['tool']['flwr']['app'] + ) { + errors.push('Missing [tool.flwr.app] section'); + } else { + if (!config['tool']['flwr']['app']['publisher']) { + errors.push('Property "publisher" missing in [tool.flwr.app]'); + } + if (config['tool']['flwr']['app']['config']) { + _validateRunConfig(config['tool']['flwr']['app']['config'], errors); + } + if (!config['tool']['flwr']['app']['components']) { + errors.push('Missing [tool.flwr.app.components] section'); + } else { + if (!config['tool']['flwr']['app']['components']['serverapp']) { + errors.push('Property "serverapp" missing in [tool.flwr.app.components]'); + } + if (!config['tool']['flwr']['app']['components']['clientapp']) { + errors.push('Property "clientapp" missing in [tool.flwr.app.components]'); + } + } + } + + return [errors.length === 0, errors, warnings]; +} + +export function validate( + config: AnyDict, + checkModule: boolean = true, + projectDir: PathLike | null = null +): [boolean, string[], string[]] { + /** + * Validate pyproject.toml. + */ + const [isValid, errors, warnings] = validateFields(config); + + if (!isValid) { + return [false, errors, warnings]; + } + + // Validate serverapp + const serverappRef = config['tool']['flwr']['app']['components']['serverapp']; + // const [serverIsValid, serverReason] = objectRef.validate(serverappRef, checkModule, projectDir); + + // if (!serverIsValid && typeof serverReason === 'string') { + // return [false, [serverReason], []]; + // } + + // Validate clientapp + const clientappRef = config['tool']['flwr']['app']['components']['clientapp']; + // const [clientIsValid, clientReason] = objectRef.validate(clientappRef, checkModule, projectDir); + + // if (!clientIsValid && typeof clientReason === 'string') { + // return [false, [clientReason], []]; + // } + + return [true, [], []]; +} + +export function loadFromString(tomlContent: string): AnyDict | null { + /** + * Load TOML content from a string and return as a dictionary. + */ + try { + return toml.parse(tomlContent); + } catch (error) { + return null; + } +} + +// Get Flower home directory based on environment variables +export function getFlwrDir(providedPath?: string): string { + if (!providedPath || !fs.existsSync(providedPath)) { + return path.join( + process.env[FLWR_HOME] || path.join(process.env['XDG_DATA_HOME'] || os.homedir(), '.flwr') + ); + } + return path.resolve(providedPath); +} + +// Return the project directory based on fab_id and fab_version +export function getProjectDir(fabId: string, fabVersion: string, flwrDir?: string): string { + if ((fabId.match(/\//g) || []).length !== 1) { + throw new Error(`Invalid FAB ID: ${fabId}`); + } + + const [publisher, projectName] = fabId.split('/'); + flwrDir = flwrDir || getFlwrDir(); + return path.join(flwrDir, APP_DIR, publisher, projectName, fabVersion); +} + +// Return pyproject.toml configuration from the project directory +export function getProjectConfig(projectDir: string): { [key: string]: any } { + const tomlPath = path.join(projectDir, FAB_CONFIG_FILE); + if (!fs.existsSync(tomlPath)) { + throw new Error(`Cannot find ${FAB_CONFIG_FILE} in ${projectDir}`); + } + + const fileContents = fs.readFileSync(tomlPath, 'utf8'); + const config = toml.parse(fileContents); + + const [isValid, _warnings, errors] = validateFields(config); + if (!isValid) { + const errorMsg = errors.map((error: string) => ` - ${error}`).join('\n'); + throw new Error(`Invalid ${FAB_CONFIG_FILE}:\n${errorMsg}`); + } + + return config; +} + +// Merge a config with the overrides +export function fuseDicts(mainDict: UserConfig, overrideDict: UserConfig): UserConfig { + const fusedDict = { ...mainDict }; + + Object.entries(overrideDict).forEach(([key, value]) => { + if (mainDict.hasOwnProperty(key)) { + fusedDict[key] = value; + } + }); + + return fusedDict; +} + +// Merge overrides from a given dict with the config from a Flower App +export function getFusedConfigFromDir(projectDir: string, overrideConfig: UserConfig): UserConfig { + const defaultConfig = getProjectConfig(projectDir)['tool']['flwr']['app']?.config || {}; + const flatDefaultConfig = flattenDict(defaultConfig); + + return fuseDicts(flatDefaultConfig, overrideConfig); +} + +// Merge default config from a FAB with overrides in a Run +export function getFusedConfigFromFab(fabFile: string | Buffer, run: Run): UserConfig { + const defaultConfig = getFabConfig(fabFile)['tool']['flwr']['app']?.config || {}; + const flatConfig = flattenDict(defaultConfig); + + return fuseDicts(flatConfig, run.overrideConfig); +} + +// Merge overrides from a Run with the config from a FAB +export function getFusedConfig(run: Run, flwrDir?: string): UserConfig { + if (!run.fabId || !run.fabVersion) { + return {}; + } + + const projectDir = getProjectDir(run.fabId, run.fabVersion, flwrDir); + + if (!fs.existsSync(projectDir)) { + return {}; + } + + return getFusedConfigFromDir(projectDir, run.overrideConfig); +} + +// Flatten a nested dictionary by joining nested keys with a separator +export function flattenDict(rawDict: { [key: string]: any } | undefined, parentKey = ''): UserConfig { + if (!rawDict) { + return {}; + } + + const items: [string, UserConfigValue][] = []; + const separator = '.'; + + Object.entries(rawDict).forEach(([key, value]) => { + const newKey = parentKey ? `${parentKey}${separator}${key}` : key; + + if (typeof value === 'object' && !Array.isArray(value)) { + items.push(...Object.entries(flattenDict(value, newKey))); + } else { + items.push([newKey, value as UserConfigValue]); + } + }); + + return Object.fromEntries(items); +} + +// Unflatten a dictionary with keys containing separators into a nested dictionary +export function unflattenDict(flatDict: { [key: string]: any }): { [key: string]: any } { + const unflattenedDict: { [key: string]: any } = {}; + const separator = '.'; + + Object.entries(flatDict).forEach(([key, value]) => { + const parts = key.split(separator); + let current = unflattenedDict; + + parts.forEach((part, idx) => { + if (idx === parts.length - 1) { + current[part] = value; + } else { + if (!current[part]) { + current[part] = {}; + } + current = current[part]; + } + }); + }); + + return unflattenedDict; +} + +// Parse a list of key-value pairs separated by '=' or load a TOML file +export function parseConfigArgs(config?: string[]): UserConfig { + let overrides: UserConfig = {}; + + if (!config) { + return overrides; + } + + if (config.length === 1 && config[0].endsWith('.toml')) { + const fileContents = fs.readFileSync(config[0], 'utf8'); + overrides = flattenDict(toml.parse(fileContents)); + return overrides; + } + + const pattern = /(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)/g; + + config.forEach((configLine) => { + if (configLine) { + if (configLine.endsWith('.toml')) { + throw new Error('TOML files cannot be passed alongside key-value pairs.'); + } + + const matches = Array.from(configLine.matchAll(pattern)); + const tomlStr = matches.map(([_, k, v]) => `${k} = ${v}`).join('\n'); + Object.assign(overrides, toml.parse(tomlStr)); + } + }); + + return overrides; +} + +// Extract `fab_version` and `fab_id` from a project config +export function getMetadataFromConfig(config: { [key: string]: any }): [string, string] { + return [ + config['project']['version'], + `${config['tool']['flwr']['app']['publisher']}/${config['project']['name']}`, + ]; +} diff --git a/src/ts/src/lib/connection.js b/src/ts/src/lib/connection.js new file mode 100644 index 000000000000..6586714b5ce5 --- /dev/null +++ b/src/ts/src/lib/connection.js @@ -0,0 +1,176 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.grpcRequestResponse = grpcRequestResponse; +const fs = __importStar(require("fs")); +const grpc_1 = require("./grpc"); +const fleet_1 = require("../protos/flwr/proto/fleet"); +const heartbeat_1 = require("./heartbeat"); +const constants_1 = require("./constants"); +const message_handler_1 = require("./message_handler"); +const task_handler_1 = require("./task_handler"); +const serde_1 = require("./serde"); +const client_interceptor_1 = require("./client_interceptor"); +const fleet_client_1 = require("../protos/flwr/proto/fleet.client"); +async function grpcRequestResponse(serverAddress, insecure, retryInvoker, maxMessageLength = grpc_1.GRPC_MAX_MESSAGE_LENGTH, rootCertificates, authenticationKeys, adapterCls) { + // If `rootCertificates` is a string, read the certificate file + if (typeof rootCertificates === "string") { + rootCertificates = await fs.promises.readFile(rootCertificates); + } + // Authentication interceptors + let interceptors = undefined; + if (authenticationKeys) { + interceptors = [(0, client_interceptor_1.AuthenticateClientInterceptor)(authenticationKeys[0], authenticationKeys[1])]; + } + const channel = (0, grpc_1.createChannel)(serverAddress, insecure, rootCertificates, maxMessageLength); + // channel.subscribe(onChannelStateChange); + let stub = new fleet_client_1.FleetClient(channel); + let metadata = null; + let node = null; + const pingStopEvent = (0, heartbeat_1.createStopEvent)(); + // Ping function + async function ping() { + if (!node) { + console.error("Node instance missing"); + return; + } + const req = {}; + req.node = node; + req.pingInterval = constants_1.PING_DEFAULT_INTERVAL; + // const res = (await retryInvoker.invoke(() => + // stub.ping(req, { timeout: PING_CALL_TIMEOUT }), + // )) as FinishedUnaryCall; + const res = await stub.ping(req, { timeout: constants_1.PING_CALL_TIMEOUT }); + if (!res.response.success) { + throw new Error("Ping failed unexpectedly."); + } + const randomFactor = Math.random() * (constants_1.PING_RANDOM_RANGE[1] - constants_1.PING_RANDOM_RANGE[0]) + constants_1.PING_RANDOM_RANGE[0]; + const nextInterval = constants_1.PING_DEFAULT_INTERVAL * (constants_1.PING_BASE_MULTIPLIER + randomFactor) - constants_1.PING_CALL_TIMEOUT; + // setTimeout(() => { + // if (!pingStopEvent.is_set) { + // ping(); + // } + // }, nextInterval * 1000); // Convert seconds to milliseconds + } + // Create node + async function createNode() { + const req = {}; + req.pingInterval = constants_1.PING_DEFAULT_INTERVAL; + // const res = (await retryInvoker.invoke(() => stub.createNode(req))) as FinishedUnaryCall< + // CreateNodeRequest, + // CreateNodeResponse + // >; + const res = await stub.createNode(req); + node = res.response.node; + // startPingLoop(ping, pingStopEvent); + return node?.nodeId || null; + } + // Delete node + async function deleteNode() { + if (!node) { + console.error("Node instance missing"); + return; + } + pingStopEvent.set(); + const req = {}; + req.node = node; + // await retryInvoker.invoke(() => stub.deleteNode(req)); + await stub.deleteNode(req); + node = null; + } + // Receive message + async function receive() { + if (!node) { + console.error("Node instance missing"); + return null; + } + const req = {}; + req.node = node; + req.taskIds = []; + // const res = (await retryInvoker.invoke(() => stub.pullTaskIns(req))) as FinishedUnaryCall< + // PullTaskInsRequest, + // PullTaskInsResponse + // >; + const res = await stub.pullTaskIns(req); + let taskIns = (0, task_handler_1.getTaskIns)(res.response); + if (taskIns && !(taskIns.task?.consumer?.nodeId === node.nodeId && (0, task_handler_1.validateTaskIns)(taskIns))) { + taskIns = null; + } + const inMessage = taskIns ? (0, serde_1.messageFromTaskIns)(taskIns) : null; + metadata = inMessage?.metadata || null; + return inMessage; + } + // Send message + async function send(message) { + if (!node) { + console.error("ERROR", "Node instance missing"); + return; + } + if (!metadata) { + console.error("ERROR", "No current message"); + return; + } + if (!(0, message_handler_1.validateOutMessage)(message, metadata)) { + console.error("Invalid out message"); + return; + } + const taskRes = (0, serde_1.messageToTaskRes)(message); + let req = fleet_1.PushTaskResRequest.create(); + req.taskResList.push(taskRes); + req.node = node; + // await retryInvoker.invoke(() => stub.pushTaskRes(req)); + await stub.pushTaskRes(req); + metadata = null; + } + // Get run + async function getRun(runId) { + const req = {}; + req.runId = runId; + // const res = (await retryInvoker.invoke(() => stub.getRun(req))) as FinishedUnaryCall< + // GetRunRequest, + // GetRunResponse + // >; + const res = await stub.getRun(req); + return { + runId, + fabId: res.response.run?.fabId, + fabVersion: res.response.run?.fabVersion, + fabHash: res.response.run?.fabHash, + overrideConfig: res.response.run?.overrideConfig ? (0, serde_1.userConfigFromProto)(res.response.run?.overrideConfig) : {}, + }; + } + // Get fab + async function getFab(fabHash) { + const req = {}; + req.hashStr = fabHash; + // const res = (await retryInvoker.invoke(() => stub.getFab(req))) as FinishedUnaryCall< + // GetFabRequest, + // GetFabResponse + // >; + const res = await stub.getFab(req); + return { hashStr: res.response.fab?.hashStr, content: res.response.fab?.content }; + } + return [receive, send, createNode, deleteNode, getRun, getFab]; +} diff --git a/src/ts/src/lib/connection.ts b/src/ts/src/lib/connection.ts new file mode 100644 index 000000000000..05c3e2890c95 --- /dev/null +++ b/src/ts/src/lib/connection.ts @@ -0,0 +1,235 @@ +import * as fs from "fs"; +import { FinishedUnaryCall, RpcInterceptor } from "@protobuf-ts/runtime-rpc"; + +import { createChannel, GRPC_MAX_MESSAGE_LENGTH } from "./grpc"; +import { + PingRequest, + PingResponse, + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, +} from "../protos/flwr/proto/fleet"; +import { GetRunRequest, GetRunResponse } from "../protos/flwr/proto/run"; +import { GetFabRequest, GetFabResponse } from "../protos/flwr/proto/fab"; +import { Node } from "../protos/flwr/proto/node"; +import { TaskIns } from "../protos/flwr/proto/task"; +import { Metadata, Message, Run, Fab } from "./typing"; +import { ec } from "elliptic"; +import { startPingLoop, createStopEvent } from "./heartbeat"; +import { + PING_CALL_TIMEOUT, + PING_RANDOM_RANGE, + PING_BASE_MULTIPLIER, + PING_DEFAULT_INTERVAL, +} from "./constants"; +import { validateOutMessage } from "./message_handler"; +import { getTaskIns, validateTaskIns } from "./task_handler"; +import { messageFromTaskIns, messageToTaskRes, userConfigFromProto } from "./serde"; +import { RetryInvoker } from "./retry_invoker"; +import { AuthenticateClientInterceptor } from "./client_interceptor"; +import { FleetClient } from "../protos/flwr/proto/fleet.client"; + +type GrpcRequestResponseReturnType = [ + () => Promise, + (message: Message) => Promise, + () => Promise, + () => Promise, + (run_id: bigint) => Promise, + (fabHash: string) => Promise, +]; + +export async function grpcRequestResponse( + serverAddress: string, + insecure: boolean, + retryInvoker: RetryInvoker, + maxMessageLength: number = GRPC_MAX_MESSAGE_LENGTH, + rootCertificates?: Buffer | string, + authenticationKeys?: [ec.KeyPair, ec.KeyPair] | null, + adapterCls?: any, +): Promise { + // If `rootCertificates` is a string, read the certificate file + if (typeof rootCertificates === "string") { + rootCertificates = await fs.promises.readFile(rootCertificates); + } + + // Authentication interceptors + let interceptors: RpcInterceptor[] | undefined = undefined; + if (authenticationKeys) { + interceptors = [AuthenticateClientInterceptor(authenticationKeys[0], authenticationKeys[1])]; + } + + const channel = createChannel( + serverAddress, + insecure, + rootCertificates, + maxMessageLength, + // interceptors, + ); + // channel.subscribe(onChannelStateChange); + + let stub = new FleetClient(channel); + let metadata: Metadata | null = null; + let node: Node | null = null; + const pingStopEvent = createStopEvent(); + + // Ping function + async function ping(): Promise { + if (!node) { + console.error("Node instance missing"); + return; + } + + const req = {} as PingRequest; + req.node = node; + req.pingInterval = PING_DEFAULT_INTERVAL; + + // const res = (await retryInvoker.invoke(() => + // stub.ping(req, { timeout: PING_CALL_TIMEOUT }), + // )) as FinishedUnaryCall; + const res = await stub.ping(req, { timeout: PING_CALL_TIMEOUT }); + if (!res.response.success) { + throw new Error("Ping failed unexpectedly."); + } + + const randomFactor = + Math.random() * (PING_RANDOM_RANGE[1] - PING_RANDOM_RANGE[0]) + PING_RANDOM_RANGE[0]; + const nextInterval = + PING_DEFAULT_INTERVAL * (PING_BASE_MULTIPLIER + randomFactor) - PING_CALL_TIMEOUT; + + // setTimeout(() => { + // if (!pingStopEvent.is_set) { + // ping(); + // } + // }, nextInterval * 1000); // Convert seconds to milliseconds + } + + // Create node + async function createNode(): Promise { + const req = {} as CreateNodeRequest; + req.pingInterval = PING_DEFAULT_INTERVAL; + + // const res = (await retryInvoker.invoke(() => stub.createNode(req))) as FinishedUnaryCall< + // CreateNodeRequest, + // CreateNodeResponse + // >; + const res = await stub.createNode(req); + + node = res.response.node!; + // startPingLoop(ping, pingStopEvent); + + return node?.nodeId || null; + } + + // Delete node + async function deleteNode(): Promise { + if (!node) { + console.error("Node instance missing"); + return; + } + + pingStopEvent.set(); + + const req = {} as DeleteNodeRequest; + req.node = node; + + // await retryInvoker.invoke(() => stub.deleteNode(req)); + await stub.deleteNode(req); + + node = null; + } + + // Receive message + async function receive(): Promise { + if (!node) { + console.error("Node instance missing"); + return null; + } + + const req = {} as PullTaskInsRequest; + req.node = node; + req.taskIds = []; + + // const res = (await retryInvoker.invoke(() => stub.pullTaskIns(req))) as FinishedUnaryCall< + // PullTaskInsRequest, + // PullTaskInsResponse + // >; + const res = await stub.pullTaskIns(req); + + let taskIns: TaskIns | null = getTaskIns(res.response); + + if (taskIns && !(taskIns.task?.consumer?.nodeId === node.nodeId && validateTaskIns(taskIns))) { + taskIns = null; + } + + const inMessage = taskIns ? messageFromTaskIns(taskIns) : null; + metadata = inMessage?.metadata || null; + return inMessage; + } + + // Send message + async function send(message: Message): Promise { + if (!node) { + console.error("ERROR", "Node instance missing"); + return; + } + + if (!metadata) { + console.error("ERROR", "No current message"); + return; + } + + if (!validateOutMessage(message, metadata)) { + console.error("Invalid out message"); + return; + } + + const taskRes = messageToTaskRes(message); + let req = PushTaskResRequest.create(); + req.taskResList.push(taskRes); + req.node = node; + + // await retryInvoker.invoke(() => stub.pushTaskRes(req)); + await stub.pushTaskRes(req); + + metadata = null; + } + + // Get run + async function getRun(runId: bigint): Promise { + const req = {} as GetRunRequest; + req.runId = runId; + + // const res = (await retryInvoker.invoke(() => stub.getRun(req))) as FinishedUnaryCall< + // GetRunRequest, + // GetRunResponse + // >; + const res = await stub.getRun(req); + + return { + runId, + fabId: res.response.run?.fabId, + fabVersion: res.response.run?.fabVersion, + fabHash: res.response.run?.fabHash, + overrideConfig: res.response.run?.overrideConfig ? userConfigFromProto(res.response.run?.overrideConfig) : {}, + } as Run; + } + + // Get fab + async function getFab(fabHash: string): Promise { + const req = {} as GetFabRequest; + req.hashStr = fabHash; + + // const res = (await retryInvoker.invoke(() => stub.getFab(req))) as FinishedUnaryCall< + // GetFabRequest, + // GetFabResponse + // >; + const res = await stub.getFab(req); + + return { hashStr: res.response.fab?.hashStr, content: res.response.fab?.content } as Fab; + } + + return [receive, send, createNode, deleteNode, getRun, getFab]; +} diff --git a/src/ts/src/lib/constants.js b/src/ts/src/lib/constants.js new file mode 100644 index 000000000000..4817d8fe7d13 --- /dev/null +++ b/src/ts/src/lib/constants.js @@ -0,0 +1,11 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.FLWR_HOME = exports.FAB_CONFIG_FILE = exports.APP_DIR = exports.PING_MAX_INTERVAL = exports.PING_RANDOM_RANGE = exports.PING_BASE_MULTIPLIER = exports.PING_CALL_TIMEOUT = exports.PING_DEFAULT_INTERVAL = void 0; +exports.PING_DEFAULT_INTERVAL = 30; +exports.PING_CALL_TIMEOUT = 5; +exports.PING_BASE_MULTIPLIER = 0.8; +exports.PING_RANDOM_RANGE = [-0.1, 0.1]; +exports.PING_MAX_INTERVAL = 1e300; +exports.APP_DIR = "apps"; +exports.FAB_CONFIG_FILE = "pyproject.toml"; +exports.FLWR_HOME = "FLWR_HOME"; diff --git a/src/ts/src/lib/constants.ts b/src/ts/src/lib/constants.ts new file mode 100644 index 000000000000..f42bc18d102c --- /dev/null +++ b/src/ts/src/lib/constants.ts @@ -0,0 +1,8 @@ +export const PING_DEFAULT_INTERVAL = 30; +export const PING_CALL_TIMEOUT = 5; +export const PING_BASE_MULTIPLIER = 0.8; +export const PING_RANDOM_RANGE = [-0.1, 0.1]; +export const PING_MAX_INTERVAL = 1e300; +export const APP_DIR = "apps"; +export const FAB_CONFIG_FILE = "pyproject.toml"; +export const FLWR_HOME = "FLWR_HOME"; diff --git a/src/ts/src/lib/crypto_helpers.js b/src/ts/src/lib/crypto_helpers.js new file mode 100644 index 000000000000..366cdf75c4ae --- /dev/null +++ b/src/ts/src/lib/crypto_helpers.js @@ -0,0 +1,48 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.publicKeyToBytes = publicKeyToBytes; +exports.bytesToPublicKey = bytesToPublicKey; +exports.generateSharedKey = generateSharedKey; +exports.computeHMAC = computeHMAC; +const elliptic_1 = require("elliptic"); +const crypto = __importStar(require("crypto")); +const ec = new elliptic_1.ec("p256"); +// Convert public key to bytes +function publicKeyToBytes(key) { + return Buffer.from(key.getPublic("array")); +} +// Convert bytes back to a public key +function bytesToPublicKey(bytes) { + return ec.keyFromPublic(bytes); +} +// Generate shared key between private and public keys +function generateSharedKey(privateKey, publicKey) { + return Buffer.from(privateKey.derive(publicKey.getPublic()).toArray()); +} +// Compute HMAC using shared key and data +function computeHMAC(key, message) { + return crypto.createHmac("sha256", key).update(message).digest(); +} diff --git a/src/ts/src/lib/crypto_helpers.ts b/src/ts/src/lib/crypto_helpers.ts new file mode 100644 index 000000000000..0b9151123d35 --- /dev/null +++ b/src/ts/src/lib/crypto_helpers.ts @@ -0,0 +1,24 @@ +import { ec as EC } from "elliptic"; +import * as crypto from "crypto"; + +const ec = new EC("p256"); + +// Convert public key to bytes +export function publicKeyToBytes(key: EC.KeyPair): Buffer { + return Buffer.from(key.getPublic("array")); +} + +// Convert bytes back to a public key +export function bytesToPublicKey(bytes: Buffer): EC.KeyPair { + return ec.keyFromPublic(bytes); +} + +// Generate shared key between private and public keys +export function generateSharedKey(privateKey: EC.KeyPair, publicKey: EC.KeyPair): Buffer { + return Buffer.from(privateKey.derive(publicKey.getPublic()).toArray()); +} + +// Compute HMAC using shared key and data +export function computeHMAC(key: Buffer, message: Buffer): Buffer { + return crypto.createHmac("sha256", key).update(message).digest(); +} diff --git a/src/ts/src/lib/grpc.js b/src/ts/src/lib/grpc.js new file mode 100644 index 000000000000..0f86cca7975a --- /dev/null +++ b/src/ts/src/lib/grpc.js @@ -0,0 +1,57 @@ +"use strict"; +var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + var desc = Object.getOwnPropertyDescriptor(m, k); + if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { + desc = { enumerable: true, get: function() { return m[k]; } }; + } + Object.defineProperty(o, k2, desc); +}) : (function(o, m, k, k2) { + if (k2 === undefined) k2 = k; + o[k2] = m[k]; +})); +var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { + Object.defineProperty(o, "default", { enumerable: true, value: v }); +}) : function(o, v) { + o["default"] = v; +}); +var __importStar = (this && this.__importStar) || function (mod) { + if (mod && mod.__esModule) return mod; + var result = {}; + if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k); + __setModuleDefault(result, mod); + return result; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.GRPC_MAX_MESSAGE_LENGTH = void 0; +exports.createChannel = createChannel; +const grpc = __importStar(require("@grpc/grpc-js")); +const grpc_transport_1 = require("@protobuf-ts/grpc-transport"); +exports.GRPC_MAX_MESSAGE_LENGTH = 536_870_912; // == 512 * 1024 * 1024 +function createChannel(serverAddress, insecure, rootCertificates = null, maxMessageLength = exports.GRPC_MAX_MESSAGE_LENGTH, interceptors = null) { + // Check for conflicting parameters + if (insecure && rootCertificates !== null) { + throw new Error("Invalid configuration: 'root_certificates' should not be provided " + + "when 'insecure' is set to true. For an insecure connection, omit " + + "'root_certificates', or set 'insecure' to false for a secure connection."); + } + let creds; + if (insecure === true) { + creds = grpc.credentials.createInsecure(); + console.debug("Opened insecure gRPC connection (no certificates were passed)"); + } + else { + creds = grpc.credentials.createSsl(rootCertificates); + console.debug("Opened secure gRPC connection using certificates"); + } + // gRPC channel options + const clientOptions = { + "grpc.max_send_message_length": maxMessageLength, + "grpc.max_receive_message_length": maxMessageLength, + }; + let rpcOptions = { host: serverAddress, channelCredentials: creds, clientOptions }; + if (interceptors !== null) { + rpcOptions.interceptors = interceptors; + } + return new grpc_transport_1.GrpcTransport(rpcOptions); +} diff --git a/src/ts/src/lib/grpc.ts b/src/ts/src/lib/grpc.ts new file mode 100644 index 000000000000..c44f6d7bd912 --- /dev/null +++ b/src/ts/src/lib/grpc.ts @@ -0,0 +1,45 @@ +import * as grpc from "@grpc/grpc-js"; +import { GrpcOptions, GrpcTransport } from "@protobuf-ts/grpc-transport"; +import { RpcInterceptor, RpcOptions } from "@protobuf-ts/runtime-rpc"; + +export const GRPC_MAX_MESSAGE_LENGTH = 536_870_912; // == 512 * 1024 * 1024 + +export function createChannel( + serverAddress: string, + insecure: boolean, + rootCertificates: Buffer | null = null, + maxMessageLength: number = GRPC_MAX_MESSAGE_LENGTH, + interceptors: RpcInterceptor[] | null = null, +): GrpcTransport { + // Check for conflicting parameters + if (insecure && rootCertificates !== null) { + throw new Error( + "Invalid configuration: 'root_certificates' should not be provided " + + "when 'insecure' is set to true. For an insecure connection, omit " + + "'root_certificates', or set 'insecure' to false for a secure connection.", + ); + } + + let creds: grpc.ChannelCredentials; + if (insecure === true) { + creds = grpc.credentials.createInsecure(); + console.debug("Opened insecure gRPC connection (no certificates were passed)"); + } else { + creds = grpc.credentials.createSsl(rootCertificates); + console.debug("Opened secure gRPC connection using certificates"); + } + + // gRPC channel options + const clientOptions: grpc.ClientOptions = { + "grpc.max_send_message_length": maxMessageLength, + "grpc.max_receive_message_length": maxMessageLength, + }; + + let rpcOptions: GrpcOptions = { host: serverAddress, channelCredentials: creds, clientOptions }; + + if (interceptors !== null) { + rpcOptions.interceptors = interceptors; + } + + return new GrpcTransport(rpcOptions); +} diff --git a/src/ts/src/lib/heartbeat.js b/src/ts/src/lib/heartbeat.js new file mode 100644 index 000000000000..7e52de57f6b4 --- /dev/null +++ b/src/ts/src/lib/heartbeat.js @@ -0,0 +1,70 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.createStopEvent = createStopEvent; +exports.startPingLoop = startPingLoop; +const events_1 = require("events"); +const grpc_js_1 = require("@grpc/grpc-js"); +const retry_invoker_1 = require("./retry_invoker"); +const constants_1 = require("./constants"); +class StopEvent extends events_1.EventEmitter { + is_set; + constructor() { + super(); + this.is_set = false; + this.on("set", () => { + this.is_set = true; + }); + } + set() { + this.emit("set"); + } +} +function createStopEvent() { + return new StopEvent(); +} +function pingLoop(pingFn, stopEvent) { + const waitFn = (waitTime) => new Promise((resolve) => { + if (!stopEvent.is_set) { + setTimeout(resolve, waitTime * 1000); + } + }); + const onBackoff = (state) => { + const err = state.exception; + if (!err) + return; + const statusCode = err.code; + if (statusCode === grpc_js_1.status.DEADLINE_EXCEEDED) { + if (state.actualWait !== undefined) { + state.actualWait = Math.max(state.actualWait - constants_1.PING_CALL_TIMEOUT, 0); + } + } + }; + const wrappedPing = () => { + if (!stopEvent.is_set) { + pingFn(); + } + }; + const retrier = new retry_invoker_1.RetryInvoker(retry_invoker_1.exponential, Error, null, null, { + onBackoff, + waitFunction: waitFn, + }); + return new Promise(async (resolve) => { + while (!stopEvent.is_set) { + await retrier.invoke(wrappedPing); + } + resolve(); // Resolve when stopEvent is triggered + }); +} +// TypeScript version of startPingLoop +function startPingLoop(pingFn, stopEvent) { + // Start the loop, but do not block + pingLoop(pingFn, stopEvent).then(() => { + console.log("Ping loop terminated."); + }); + const intervalId = setInterval(() => { + if (stopEvent.is_set) { + clearInterval(intervalId); // Clear the interval when stopEvent is set + } + }, 1000); // Interval to keep the loop alive + return intervalId; // Return the interval ID} +} diff --git a/src/ts/src/lib/heartbeat.ts b/src/ts/src/lib/heartbeat.ts new file mode 100644 index 000000000000..a4ba718b24d4 --- /dev/null +++ b/src/ts/src/lib/heartbeat.ts @@ -0,0 +1,81 @@ +import { EventEmitter } from "events"; +import { ServiceError, status } from "@grpc/grpc-js"; +import { RetryInvoker, RetryState, exponential } from "./retry_invoker"; +import { PING_CALL_TIMEOUT } from "./constants"; + + +class StopEvent extends EventEmitter { + public is_set: boolean; + + constructor() { + super(); + this.is_set = false; + + this.on("set", () => { + this.is_set = true; + }); + } + + set(): void { + this.emit("set"); + } +} + +export function createStopEvent(): StopEvent { + return new StopEvent(); +} + +function pingLoop(pingFn: () => void, stopEvent: StopEvent): Promise { + const waitFn = (waitTime: number): Promise => + new Promise((resolve) => { + if (!stopEvent.is_set) { + setTimeout(resolve, waitTime * 1000); + } + }); + + const onBackoff = (state: RetryState): void => { + const err = state.exception as ServiceError; + if (!err) return; + + const statusCode = err.code; + if (statusCode === status.DEADLINE_EXCEEDED) { + if (state.actualWait !== undefined) { + state.actualWait = Math.max(state.actualWait - PING_CALL_TIMEOUT, 0); + } + } + }; + + const wrappedPing = (): void => { + if (!stopEvent.is_set) { + pingFn(); + } + }; + + const retrier = new RetryInvoker(exponential, Error, null, null, { + onBackoff, + waitFunction: waitFn, + }); + + return new Promise(async (resolve) => { + while (!stopEvent.is_set) { + await retrier.invoke(wrappedPing); + } + resolve(); // Resolve when stopEvent is triggered + }); +} + +// TypeScript version of startPingLoop +export function startPingLoop(pingFn: () => void, stopEvent: StopEvent): NodeJS.Timeout { + // Start the loop, but do not block + pingLoop(pingFn, stopEvent).then(() => { + console.log("Ping loop terminated."); + }); + + const intervalId = setInterval(() => { + if (stopEvent.is_set) { + clearInterval(intervalId); // Clear the interval when stopEvent is set + } + }, 1000); // Interval to keep the loop alive + + return intervalId; // Return the interval ID} +} diff --git a/src/ts/src/lib/index.js b/src/ts/src/lib/index.js new file mode 100644 index 000000000000..4f098234a453 --- /dev/null +++ b/src/ts/src/lib/index.js @@ -0,0 +1,21 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.Client = exports.startClientInternal = void 0; +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== +const client_1 = require("./client"); +Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Client; } }); +const start_1 = require("./start"); +Object.defineProperty(exports, "startClientInternal", { enumerable: true, get: function () { return start_1.startClientInternal; } }); diff --git a/src/ts/src/lib/index.ts b/src/ts/src/lib/index.ts new file mode 100644 index 000000000000..0539287804cb --- /dev/null +++ b/src/ts/src/lib/index.ts @@ -0,0 +1,20 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== +import { Client } from "./client"; +import { startClientInternal } from "./start"; +import { GetParametersIns, GetParametersRes, GetPropertiesIns, GetPropertiesRes, FitIns, FitRes, EvaluateIns, EvaluateRes, Context } from "./typing"; + +export { startClientInternal, Client }; +export type { GetPropertiesRes, GetPropertiesIns, GetParametersRes, GetParametersIns, FitRes, FitIns, EvaluateRes, EvaluateIns, Context }; diff --git a/src/ts/src/lib/logger.js b/src/ts/src/lib/logger.js new file mode 100644 index 000000000000..68125b3d0869 --- /dev/null +++ b/src/ts/src/lib/logger.js @@ -0,0 +1,39 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.warnPreviewFeature = warnPreviewFeature; +exports.warnDeprecatedFeature = warnDeprecatedFeature; +exports.warnDeprecatedFeatureWithExample = warnDeprecatedFeatureWithExample; +exports.warnUnsupportedFeature = warnUnsupportedFeature; +function warnPreviewFeature(name) { + console.warn(`PREVIEW FEATURE: ${name} + + This is a preview feature.It could change significantly or be removed + entirely in future versions of Flower. + `); +} +function warnDeprecatedFeature(name) { + console.warn(`DEPRECATED FEATURE: ${name} + + This is a deprecated feature.It will be removed + entirely in future versions of Flower. + `); +} +function warnDeprecatedFeatureWithExample(deprecation_message, example_message, code_example) { + console.warn(`DEPRECATED FEATURE: ${deprecation_message} + + Check the following \`FEATURE UPDATE\` warning message for the preferred + new mechanism to use this feature in Flower. + `); + console.warn(`FEATURE UPDATE: ${example_message} + ------------------------------------------------------------ + ${code_example} + ------------------------------------------------------------ + `); +} +function warnUnsupportedFeature(name) { + console.warn(`UNSUPPORTED FEATURE: ${name} + + This is an unsupported feature.It will be removed + entirely in future versions of Flower. + `); +} diff --git a/src/ts/src/lib/logger.ts b/src/ts/src/lib/logger.ts new file mode 100644 index 000000000000..53a0289a3f03 --- /dev/null +++ b/src/ts/src/lib/logger.ts @@ -0,0 +1,46 @@ +export function warnPreviewFeature(name: string) { + console.warn( + `PREVIEW FEATURE: ${name} + + This is a preview feature.It could change significantly or be removed + entirely in future versions of Flower. + `, + ) +} + +export function warnDeprecatedFeature(name: string) { + console.warn( + `DEPRECATED FEATURE: ${name} + + This is a deprecated feature.It will be removed + entirely in future versions of Flower. + `, + ) +} + +export function warnDeprecatedFeatureWithExample(deprecation_message: string, example_message: string, code_example: string) { + console.warn( + `DEPRECATED FEATURE: ${deprecation_message} + + Check the following \`FEATURE UPDATE\` warning message for the preferred + new mechanism to use this feature in Flower. + `, + ) + console.warn( + `FEATURE UPDATE: ${example_message} + ------------------------------------------------------------ + ${code_example} + ------------------------------------------------------------ + `, + ) +} + +export function warnUnsupportedFeature(name: string) { + console.warn( + `UNSUPPORTED FEATURE: ${name} + + This is an unsupported feature.It will be removed + entirely in future versions of Flower. + `, + ) +} diff --git a/src/ts/src/lib/message_handler.js b/src/ts/src/lib/message_handler.js new file mode 100644 index 000000000000..2d2ef153ef68 --- /dev/null +++ b/src/ts/src/lib/message_handler.js @@ -0,0 +1,91 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.validateOutMessage = exports.handleLegacyMessageFromMsgType = exports.handleControlMessage = void 0; +const client_1 = require("./client"); +const transport_1 = require("../protos/flwr/proto/transport"); +const recordset_1 = require("./recordset"); +const recordset_compat_1 = require("./recordset_compat"); +const reconnect = (reconnectMsg) => { + let reason = transport_1.Reason.ACK; + let sleepDuration = BigInt(0); + if (reconnectMsg.seconds !== BigInt(0)) { + reason = transport_1.Reason.RECONNECT; + sleepDuration = reconnectMsg.seconds; + } + const disconnectRes = { + reason, + }; + return [ + { + msg: { oneofKind: "disconnectRes", disconnectRes }, + }, + sleepDuration, + ]; +}; +const handleControlMessage = (message) => { + if (message.metadata.messageType === "reconnet") { + let recordset = message.content; + let seconds = recordset?.configsRecords["config"]["seconds"]; + const reconnectMsg = reconnect({ seconds: seconds }); + const disconnectMsg = reconnectMsg[0]; + const sleepDuration = reconnectMsg[1]; + if (disconnectMsg.msg.oneofKind === "disconnectRes") { + let reason = disconnectMsg.msg.disconnectRes.reason; + let recordset = new recordset_1.RecordSet(); + recordset.configsRecords["config"] = new recordset_1.ConfigsRecord({ reason: reason }); + let outMessage = message.createReply(recordset); + return [outMessage, Number(sleepDuration)]; + } + } + return [null, 0]; +}; +exports.handleControlMessage = handleControlMessage; +const handleLegacyMessageFromMsgType = (client_fn, message, context) => { + let client = client_fn(context); + client.setContext(context); + let messageType = message.metadata.messageType; + let outRecordset; + switch (messageType) { + case "get_properties": { + const getPropertiesRes = (0, client_1.maybeCallGetProperties)(client, (0, recordset_compat_1.recordSetToGetPropertiesIns)(message.content)); + outRecordset = (0, recordset_compat_1.getPropertiesResToRecordSet)(getPropertiesRes); + break; + } + case "get_parameters": { + const getParametersRes = (0, client_1.maybeCallGetParameters)(client, (0, recordset_compat_1.recordSetToGetParametersIns)(message.content)); + outRecordset = (0, recordset_compat_1.getParametersResToRecordSet)(getParametersRes, false); + break; + } + case "train": { + const fitRes = (0, client_1.maybeCallFit)(client, (0, recordset_compat_1.recordSetToFitIns)(message.content, true)); + outRecordset = (0, recordset_compat_1.fitResToRecordSet)(fitRes, false); + break; + } + case "evaluate": { + const evaluateRes = (0, client_1.maybeCallEvaluate)(client, (0, recordset_compat_1.recordSetToEvaluateIns)(message.content, true)); + outRecordset = (0, recordset_compat_1.evaluateResToRecordSet)(evaluateRes); + break; + } + default: { + throw `Invalid message type: ${messageType}`; + } + } + return message.createReply(outRecordset); +}; +exports.handleLegacyMessageFromMsgType = handleLegacyMessageFromMsgType; +const validateOutMessage = (outMessage, inMessageMetadata) => { + let outMeta = outMessage.metadata; + let inMeta = inMessageMetadata; + if (outMeta.runId === inMeta.runId && + outMeta.messageId === "" && + outMeta.srcNodeId === inMeta.dstNodeId && + outMeta.dstNodeId === inMeta.srcNodeId && + outMeta.replyToMessage === inMeta.messageId && + outMeta.groupId === inMeta.groupId && + outMeta.messageType === inMeta.messageType && + outMeta.createdAt > inMeta.createdAt) { + return true; + } + return false; +}; +exports.validateOutMessage = validateOutMessage; diff --git a/src/ts/src/lib/message_handler.test.js b/src/ts/src/lib/message_handler.test.js new file mode 100644 index 000000000000..8065763e5ecb --- /dev/null +++ b/src/ts/src/lib/message_handler.test.js @@ -0,0 +1,161 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const message_handler_1 = require("./message_handler"); +const typing_1 = require("./typing"); +const client_1 = require("./client"); +const recordset_1 = require("./recordset"); +const recordset_compat_1 = require("./recordset_compat"); +function removeCreatedAtField(metadata) { + const { createdAt, ...rest } = metadata; + return rest; +} +// Mock ClientWithoutProps and ClientWithProps +class ClientWithoutProps extends client_1.Client { + getParameters() { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + fit() { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 1, + metrics: {}, + }; + } + evaluate() { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + loss: 1.0, + numExamples: 1, + metrics: {}, + }; + } +} +class ClientWithProps extends ClientWithoutProps { + getProperties() { + return { + status: { code: typing_1.Code.OK, message: "Success" }, + properties: { str_prop: "val", int_prop: 1 }, + }; + } +} +// Helper function to create the client_fn +const getClientFn = (client) => (context) => client; +describe("Message Handler Tests", () => { + const createMessage = (messageType, content) => { + return new typing_1.Message({ + runId: BigInt(123), + messageId: "abc123", + groupId: "some-group-id", + srcNodeId: BigInt(0), + dstNodeId: BigInt(1123), + replyToMessage: "", + ttl: 10, + messageType, + createdAt: 0, + }, content, {}); + }; + const context = { + nodeId: BigInt(1123), + nodeConfig: {}, + state: new recordset_1.RecordSet(), + runConfig: {}, + }; + test("Client without get_properties", () => { + const client = new ClientWithoutProps({}); + const recordset = (0, recordset_compat_1.getPropertiesInsToRecordSet)({}); + const message = createMessage("get_properties", recordset); + const actualMessage = (0, message_handler_1.handleLegacyMessageFromMsgType)(getClientFn(client), message, context); + const expectedGetPropertiesRes = { + status: { + code: typing_1.Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }, + properties: {}, + }; + const expectedRs = (0, recordset_compat_1.getPropertiesResToRecordSet)(expectedGetPropertiesRes); + const expectedMessage = new typing_1.Message({ + ...message.metadata, + messageId: "", + srcNodeId: BigInt(1123), + dstNodeId: BigInt(0), + replyToMessage: message.metadata.messageId, + ttl: actualMessage.metadata.ttl, + }, expectedRs, {}); + expect(actualMessage.content).toEqual(expectedMessage.content); + expect(removeCreatedAtField(actualMessage.metadata)).toMatchObject(removeCreatedAtField(expectedMessage.metadata)); + // expect(actualMessage.metadata.createdAt).toBeGreaterThan(message.metadata.createdAt); + }); + test("Client with get_properties", () => { + const client = new ClientWithProps({}); + const recordset = (0, recordset_compat_1.getPropertiesInsToRecordSet)({}); + const message = createMessage("get_properties", recordset); + const actualMessage = (0, message_handler_1.handleLegacyMessageFromMsgType)(getClientFn(client), message, context); + const expectedGetPropertiesRes = { + status: { code: typing_1.Code.OK, message: "Success" }, + properties: { str_prop: "val", int_prop: 1 }, + }; + const expectedRs = (0, recordset_compat_1.getPropertiesResToRecordSet)(expectedGetPropertiesRes); + const expectedMessage = new typing_1.Message({ + ...message.metadata, + messageId: "", + srcNodeId: BigInt(1123), + dstNodeId: BigInt(0), + replyToMessage: message.metadata.messageId, + ttl: actualMessage.metadata.ttl, + }, expectedRs, {}); + expect(actualMessage.content).toEqual(expectedMessage.content); + expect(removeCreatedAtField(actualMessage.metadata)).toMatchObject(removeCreatedAtField(expectedMessage.metadata)); + // expect(actualMessage.metadata.createdAt).toBeGreaterThan(message.metadata.createdAt); + }); +}); +describe("Message Validation", () => { + let inMetadata; + let validOutMetadata; + beforeEach(() => { + inMetadata = { + runId: BigInt(123), + messageId: "qwerty", + srcNodeId: BigInt(10), + dstNodeId: BigInt(20), + replyToMessage: "", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now() - 10, + }; + validOutMetadata = { + runId: BigInt(123), + messageId: "", + srcNodeId: BigInt(20), + dstNodeId: BigInt(10), + replyToMessage: "qwerty", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now(), + }; + }); + test("Valid message", () => { + const validMessage = new typing_1.Message(validOutMetadata, new recordset_1.RecordSet(), {}); + expect((0, message_handler_1.validateOutMessage)(validMessage, inMetadata)).toBe(true); + }); + test("Invalid message run_id", () => { + const msg = new typing_1.Message(validOutMetadata, new recordset_1.RecordSet(), {}); + const invalidMetadata = { + runId: BigInt(12), // Different runId + messageId: "qwerty", + srcNodeId: BigInt(10), + dstNodeId: BigInt(20), + replyToMessage: "", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now() - 10, + }; + expect((0, message_handler_1.validateOutMessage)(msg, invalidMetadata)).toBe(false); + }); +}); diff --git a/src/ts/src/lib/message_handler.test.ts b/src/ts/src/lib/message_handler.test.ts new file mode 100644 index 000000000000..1c9415445474 --- /dev/null +++ b/src/ts/src/lib/message_handler.test.ts @@ -0,0 +1,205 @@ +import { handleLegacyMessageFromMsgType, validateOutMessage } from "./message_handler"; +import { + GetPropertiesRes, + Message, + Metadata, + Context, + GetPropertiesIns, + Code, + Error as LocalError, +} from "./typing"; +import { Client } from "./client"; +import { RecordSet } from "./recordset"; +import { getPropertiesInsToRecordSet, getPropertiesResToRecordSet } from "./recordset_compat"; + +function removeCreatedAtField(metadata: Metadata): Partial { + const { createdAt, ...rest } = metadata; + return rest; +} + +// Mock ClientWithoutProps and ClientWithProps +class ClientWithoutProps extends Client { + getParameters() { + return { + status: { code: Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + }; + } + + fit() { + return { + status: { code: Code.OK, message: "Success" }, + parameters: { tensors: [], tensorType: "" }, + numExamples: 1, + metrics: {}, + }; + } + + evaluate() { + return { + status: { code: Code.OK, message: "Success" }, + loss: 1.0, + numExamples: 1, + metrics: {}, + }; + } +} + +class ClientWithProps extends ClientWithoutProps { + getProperties() { + return { + status: { code: Code.OK, message: "Success" }, + properties: { str_prop: "val", int_prop: 1 }, + }; + } +} + +// Helper function to create the client_fn +const getClientFn = (client: any) => (context: Context) => client; + +describe("Message Handler Tests", () => { + const createMessage = (messageType: string, content: RecordSet) => { + return new Message( + { + runId: BigInt(123), + messageId: "abc123", + groupId: "some-group-id", + srcNodeId: BigInt(0), + dstNodeId: BigInt(1123), + replyToMessage: "", + ttl: 10, + messageType, + createdAt: 0, + }, + content, + {} as LocalError, + ); + }; + + const context: Context = { + nodeId: BigInt(1123), + nodeConfig: {}, + state: new RecordSet(), + runConfig: {}, + }; + + test("Client without get_properties", () => { + const client = new ClientWithoutProps({} as Context); + const recordset = getPropertiesInsToRecordSet({} as GetPropertiesIns); + const message = createMessage("get_properties", recordset); + + const actualMessage = handleLegacyMessageFromMsgType(getClientFn(client), message, context); + + const expectedGetPropertiesRes: GetPropertiesRes = { + status: { + code: Code.GET_PROPERTIES_NOT_IMPLEMENTED, + message: "Client does not implement `get_properties`", + }, + properties: {}, + }; + const expectedRs = getPropertiesResToRecordSet(expectedGetPropertiesRes); + const expectedMessage = new Message( + { + ...message.metadata, + messageId: "", + srcNodeId: BigInt(1123), + dstNodeId: BigInt(0), + replyToMessage: message.metadata.messageId, + ttl: actualMessage.metadata.ttl, + }, + expectedRs, + {} as LocalError, + ); + + expect(actualMessage.content).toEqual(expectedMessage.content); + expect(removeCreatedAtField(actualMessage.metadata)).toMatchObject( + removeCreatedAtField(expectedMessage.metadata), + ); + // expect(actualMessage.metadata.createdAt).toBeGreaterThan(message.metadata.createdAt); + }); + + test("Client with get_properties", () => { + const client = new ClientWithProps({} as Context); + const recordset = getPropertiesInsToRecordSet({} as GetPropertiesIns); + const message = createMessage("get_properties", recordset); + + const actualMessage = handleLegacyMessageFromMsgType(getClientFn(client), message, context); + + const expectedGetPropertiesRes: GetPropertiesRes = { + status: { code: Code.OK, message: "Success" }, + properties: { str_prop: "val", int_prop: 1 }, + }; + const expectedRs = getPropertiesResToRecordSet(expectedGetPropertiesRes); + const expectedMessage = new Message( + { + ...message.metadata, + messageId: "", + srcNodeId: BigInt(1123), + dstNodeId: BigInt(0), + replyToMessage: message.metadata.messageId, + ttl: actualMessage.metadata.ttl, + }, + expectedRs, + {} as LocalError, + ); + + expect(actualMessage.content).toEqual(expectedMessage.content); + expect(removeCreatedAtField(actualMessage.metadata)).toMatchObject( + removeCreatedAtField(expectedMessage.metadata), + ); + // expect(actualMessage.metadata.createdAt).toBeGreaterThan(message.metadata.createdAt); + }); +}); + +describe("Message Validation", () => { + let inMetadata: Metadata; + let validOutMetadata: Metadata; + + beforeEach(() => { + inMetadata = { + runId: BigInt(123), + messageId: "qwerty", + srcNodeId: BigInt(10), + dstNodeId: BigInt(20), + replyToMessage: "", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now() - 10, + }; + + validOutMetadata = { + runId: BigInt(123), + messageId: "", + srcNodeId: BigInt(20), + dstNodeId: BigInt(10), + replyToMessage: "qwerty", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now(), + }; + }); + + test("Valid message", () => { + const validMessage: Message = new Message(validOutMetadata, new RecordSet(), {} as LocalError); + expect(validateOutMessage(validMessage, inMetadata)).toBe(true); + }); + + test("Invalid message run_id", () => { + const msg: Message = new Message(validOutMetadata, new RecordSet(), {} as LocalError); + + const invalidMetadata = { + runId: BigInt(12), // Different runId + messageId: "qwerty", + srcNodeId: BigInt(10), + dstNodeId: BigInt(20), + replyToMessage: "", + groupId: "group1", + ttl: 100, + messageType: "train", + createdAt: Date.now() - 10, + }; + expect(validateOutMessage(msg, invalidMetadata)).toBe(false); + }); +}); diff --git a/src/ts/src/lib/message_handler.ts b/src/ts/src/lib/message_handler.ts new file mode 100644 index 000000000000..6a7286d32ca4 --- /dev/null +++ b/src/ts/src/lib/message_handler.ts @@ -0,0 +1,129 @@ +import { + Client, + maybeCallEvaluate, + maybeCallFit, + maybeCallGetParameters, + maybeCallGetProperties, +} from "./client"; +import { + ClientMessage as ProtoClientMessage, + Reason as ProtoReason, + ServerMessage_ReconnectIns as ProtoServerMessage_ReconnectIns, + ClientMessage_DisconnectRes as ProtoClientMessage_DisconnectRes, + ServerMessage_ReconnectIns, +} from "../protos/flwr/proto/transport"; +import { Message, Context, Metadata } from "./typing"; +import { RecordSet, ConfigsRecord } from "./recordset"; +import { + getParametersResToRecordSet, + getPropertiesResToRecordSet, + recordSetToFitIns, + recordSetToGetParametersIns, + recordSetToGetPropertiesIns, + fitResToRecordSet, + recordSetToEvaluateIns, + evaluateResToRecordSet, +} from "./recordset_compat"; + +const reconnect = (reconnectMsg: ProtoServerMessage_ReconnectIns): [ProtoClientMessage, bigint] => { + let reason = ProtoReason.ACK; + let sleepDuration = BigInt(0); + if (reconnectMsg.seconds !== BigInt(0)) { + reason = ProtoReason.RECONNECT; + sleepDuration = reconnectMsg.seconds; + } + + const disconnectRes: ProtoClientMessage_DisconnectRes = { + reason, + }; + + return [ + { + msg: { oneofKind: "disconnectRes", disconnectRes }, + } as ProtoClientMessage, + sleepDuration, + ]; +}; + +export const handleControlMessage = (message: Message): [Message | null, number] => { + if (message.metadata.messageType === "reconnet") { + let recordset = message.content; + let seconds = recordset?.configsRecords["config"]["seconds"]!; + const reconnectMsg = reconnect({ seconds: seconds as bigint } as ServerMessage_ReconnectIns); + const disconnectMsg = reconnectMsg[0]; + const sleepDuration = reconnectMsg[1]; + if (disconnectMsg.msg.oneofKind === "disconnectRes") { + let reason = disconnectMsg.msg.disconnectRes.reason as number; + let recordset = new RecordSet(); + recordset.configsRecords["config"] = new ConfigsRecord({ reason: reason }); + let outMessage = message.createReply(recordset); + return [outMessage, Number(sleepDuration)]; + } + } + + return [null, 0]; +}; + +export const handleLegacyMessageFromMsgType = ( + client_fn: (context: Context) => Client, + message: Message, + context: Context, +): Message => { + let client = client_fn(context); + client.setContext(context); + + let messageType = message.metadata.messageType; + let outRecordset: RecordSet; + + switch (messageType) { + case "get_properties": { + const getPropertiesRes = maybeCallGetProperties( + client, + recordSetToGetPropertiesIns(message.content!), + ); + outRecordset = getPropertiesResToRecordSet(getPropertiesRes); + break; + } + case "get_parameters": { + const getParametersRes = maybeCallGetParameters( + client, + recordSetToGetParametersIns(message.content!), + ); + outRecordset = getParametersResToRecordSet(getParametersRes, false); + break; + } + case "train": { + const fitRes = maybeCallFit(client, recordSetToFitIns(message.content!, true)); + outRecordset = fitResToRecordSet(fitRes, false); + break; + } + case "evaluate": { + const evaluateRes = maybeCallEvaluate(client, recordSetToEvaluateIns(message.content!, true)); + outRecordset = evaluateResToRecordSet(evaluateRes); + break; + } + default: { + throw `Invalid message type: ${messageType}`; + } + } + + return message.createReply(outRecordset); +}; + +export const validateOutMessage = (outMessage: Message, inMessageMetadata: Metadata) => { + let outMeta = outMessage.metadata; + let inMeta = inMessageMetadata; + if ( + outMeta.runId === inMeta.runId && + outMeta.messageId === "" && + outMeta.srcNodeId === inMeta.dstNodeId && + outMeta.dstNodeId === inMeta.srcNodeId && + outMeta.replyToMessage === inMeta.messageId && + outMeta.groupId === inMeta.groupId && + outMeta.messageType === inMeta.messageType && + outMeta.createdAt > inMeta.createdAt + ) { + return true; + } + return false; +}; diff --git a/src/ts/src/lib/node_state.js b/src/ts/src/lib/node_state.js new file mode 100644 index 000000000000..2b7fab029aa1 --- /dev/null +++ b/src/ts/src/lib/node_state.js @@ -0,0 +1,70 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.NodeState = void 0; +const lodash_1 = require("lodash"); +const recordset_1 = require("./recordset"); +class NodeState { + nodeId; + nodeConfig; + runInfos; + constructor(nodeId, nodeConfig) { + this.nodeId = nodeId; + this.nodeConfig = nodeConfig; + this.runInfos = {}; + } + // Register a new run context for this node + registerContext(runId, run = null, flwrPath = null, appDir = null, fab = null) { + if (!(runId in this.runInfos)) { + let initialRunConfig = {}; + if (appDir) { + const appPath = appDir; // appPath is a string instead of Pathlike in this case + if ( /* Check if appPath is a directory - this needs specific NodeJS code */true) { + const overrideConfig = run?.overrideConfig || {}; + // initialRunConfig = getFusedConfigFromDir(appPath, overrideConfig); + } + else { + throw new Error('The specified `appDir` must be a directory.'); + } + } + else { + if (run) { + if (fab) { + // Load config from FAB and fuse + // initialRunConfig = getFusedConfigFromFab(fab.content, run); + } + else { + // Load config from installed FAB and fuse + // initialRunConfig = getFusedConfig(run, flwrPath); + } + } + else { + initialRunConfig = {}; + } + } + this.runInfos[runId] = { + initialRunConfig, + context: { + nodeId: this.nodeId, + nodeConfig: this.nodeConfig, + state: new recordset_1.RecordSet(), + runConfig: { ...initialRunConfig }, + }, + }; + } + } + // Retrieve the context given a runId + retrieveContext(runId) { + if (runId in this.runInfos) { + return this.runInfos[runId].context; + } + throw new Error(`Context for runId=${runId} doesn't exist. A run context must be registered before it can be retrieved or updated by a client.`); + } + // Update run context + updateContext(runId, context) { + if (!(0, lodash_1.isEqual)(context.runConfig, this.runInfos[runId].initialRunConfig)) { + throw new Error(`The run_config field of the Context object cannot be modified (runId: ${runId}).`); + } + this.runInfos[runId].context = context; + } +} +exports.NodeState = NodeState; diff --git a/src/ts/src/lib/node_state.ts b/src/ts/src/lib/node_state.ts new file mode 100644 index 000000000000..e0a244cbe53c --- /dev/null +++ b/src/ts/src/lib/node_state.ts @@ -0,0 +1,89 @@ +import { isEqual } from 'lodash'; +import { PathLike } from 'fs'; +import { RecordSet } from './recordset'; +import { Fab, Run, UserConfig, Context } from './typing'; +import { getFusedConfig, getFusedConfigFromDir, getFusedConfigFromFab } from './config'; + +interface RunInfo { + context: Context; + initialRunConfig: UserConfig; +} + +export class NodeState { + nodeId: bigint; + nodeConfig: UserConfig; + runInfos: { [key: number]: RunInfo }; + + constructor(nodeId: bigint, nodeConfig: UserConfig) { + this.nodeId = nodeId; + this.nodeConfig = nodeConfig; + this.runInfos = {}; + } + + // Register a new run context for this node + registerContext( + runId: number, + run: Run | null = null, + flwrPath: PathLike | null = null, + appDir: string | null = null, + fab: Fab | null = null, + ): void { + if (!(runId in this.runInfos)) { + let initialRunConfig: UserConfig = {}; + + if (appDir) { + const appPath = appDir; // appPath is a string instead of Pathlike in this case + + if (/* Check if appPath is a directory - this needs specific NodeJS code */ true) { + const overrideConfig = run?.overrideConfig || {}; + // initialRunConfig = getFusedConfigFromDir(appPath, overrideConfig); + } else { + throw new Error('The specified `appDir` must be a directory.'); + } + } else { + if (run) { + if (fab) { + // Load config from FAB and fuse + // initialRunConfig = getFusedConfigFromFab(fab.content, run); + } else { + // Load config from installed FAB and fuse + // initialRunConfig = getFusedConfig(run, flwrPath); + } + } else { + initialRunConfig = {}; + } + } + + this.runInfos[runId] = { + initialRunConfig, + context: { + nodeId: this.nodeId, + nodeConfig: this.nodeConfig, + state: new RecordSet(), + runConfig: { ...initialRunConfig }, + } as Context, + }; + } + } + + // Retrieve the context given a runId + retrieveContext(runId: number): Context { + if (runId in this.runInfos) { + return this.runInfos[runId].context; + } + + throw new Error( + `Context for runId=${runId} doesn't exist. A run context must be registered before it can be retrieved or updated by a client.`, + ); + } + + // Update run context + updateContext(runId: number, context: Context): void { + if (!isEqual(context.runConfig, this.runInfos[runId].initialRunConfig)) { + throw new Error( + `The run_config field of the Context object cannot be modified (runId: ${runId}).`, + ); + } + this.runInfos[runId].context = context; + } +} diff --git a/src/ts/src/lib/recordset.js b/src/ts/src/lib/recordset.js new file mode 100644 index 000000000000..fe4c07be867b --- /dev/null +++ b/src/ts/src/lib/recordset.js @@ -0,0 +1,45 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.RecordSet = exports.ConfigsRecord = exports.MetricsRecord = exports.ParametersRecord = exports.ArrayData = void 0; +class ArrayData { + dtype; + shape; + stype; + data; + constructor(dtype, shape, stype, data) { + this.dtype = dtype; + this.shape = shape; + this.stype = stype; + this.data = data; + } +} +exports.ArrayData = ArrayData; +class ParametersRecord { + constructor(data = {}) { + Object.assign(this, data); + } +} +exports.ParametersRecord = ParametersRecord; +class MetricsRecord { + constructor(data = {}) { + Object.assign(this, data); + } +} +exports.MetricsRecord = MetricsRecord; +class ConfigsRecord { + constructor(data = {}) { + Object.assign(this, data); + } +} +exports.ConfigsRecord = ConfigsRecord; +class RecordSet { + parametersRecords = {}; + metricsRecords = {}; + configsRecords = {}; + constructor(parametersRecords = {}, metricsRecords = {}, configsRecords = {}) { + this.parametersRecords = parametersRecords; + this.metricsRecords = metricsRecords; + this.configsRecords = configsRecords; + } +} +exports.RecordSet = RecordSet; diff --git a/src/ts/src/lib/recordset.ts b/src/ts/src/lib/recordset.ts new file mode 100644 index 000000000000..ed13757a9eb8 --- /dev/null +++ b/src/ts/src/lib/recordset.ts @@ -0,0 +1,56 @@ +export type ConfigsRecordValue = + | string + | bigint + | number + | boolean + | (string | bigint | number | boolean)[]; +export type MetricsRecordValue = number | bigint | (number | bigint)[]; + +export class ArrayData { + constructor( + public dtype: string, + public shape: number[], + public stype: string, + public data: Uint8Array, + ) {} +} + +export class ParametersRecord { + [key: string]: ArrayData; + + constructor(data: { [key: string]: ArrayData } = {}) { + Object.assign(this, data); + } +} + +export class MetricsRecord { + [key: string]: MetricsRecordValue; + + constructor(data: { [key: string]: MetricsRecordValue } = {}) { + Object.assign(this, data); + } +} + +export class ConfigsRecord { + [key: string]: ConfigsRecordValue; + + constructor(data: { [key: string]: ConfigsRecordValue } = {}) { + Object.assign(this, data); + } +} + +export class RecordSet { + public parametersRecords: { [key: string]: ParametersRecord } = {}; + public metricsRecords: { [key: string]: MetricsRecord } = {}; + public configsRecords: { [key: string]: ConfigsRecord } = {}; + + constructor( + parametersRecords: { [key: string]: ParametersRecord } = {}, + metricsRecords: { [key: string]: MetricsRecord } = {}, + configsRecords: { [key: string]: ConfigsRecord } = {}, + ) { + this.parametersRecords = parametersRecords; + this.metricsRecords = metricsRecords; + this.configsRecords = configsRecords; + } +} diff --git a/src/ts/src/lib/recordset_compat.js b/src/ts/src/lib/recordset_compat.js new file mode 100644 index 000000000000..cd7d203a8d23 --- /dev/null +++ b/src/ts/src/lib/recordset_compat.js @@ -0,0 +1,197 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.recordSetToFitIns = recordSetToFitIns; +exports.fitInsToRecordSet = fitInsToRecordSet; +exports.recordSetToFitRes = recordSetToFitRes; +exports.fitResToRecordSet = fitResToRecordSet; +exports.recordSetToEvaluateIns = recordSetToEvaluateIns; +exports.evaluateInsToRecordSet = evaluateInsToRecordSet; +exports.recordSetToEvaluateRes = recordSetToEvaluateRes; +exports.evaluateResToRecordSet = evaluateResToRecordSet; +exports.recordSetToGetParametersIns = recordSetToGetParametersIns; +exports.recordSetToGetPropertiesIns = recordSetToGetPropertiesIns; +exports.getParametersInsToRecordSet = getParametersInsToRecordSet; +exports.getPropertiesInsToRecordSet = getPropertiesInsToRecordSet; +exports.getParametersResToRecordSet = getParametersResToRecordSet; +exports.getPropertiesResToRecordSet = getPropertiesResToRecordSet; +exports.recordSetToGetParametersRes = recordSetToGetParametersRes; +exports.recordSetToGetPropertiesRes = recordSetToGetPropertiesRes; +const recordset_1 = require("./recordset"); +function parametersrecordToParameters(record, keepInput) { + const parameters = { tensors: [], tensorType: "" }; + Object.keys(record).forEach((key) => { + const arrayData = record[key]; + if (key !== "EMPTY_TENSOR_KEY") { + parameters.tensors.push(arrayData.data); + } + if (!parameters.tensorType) { + parameters.tensorType = arrayData.stype; + } + if (!keepInput) { + delete record[key]; + } + }); + return parameters; +} +function parametersToParametersRecord(parameters, keepInput) { + const tensorType = parameters.tensorType; + const orderedDict = new recordset_1.ParametersRecord({}); + parameters.tensors.forEach((tensor, idx) => { + const array = new recordset_1.ArrayData("", [], tensorType, tensor); + orderedDict[String(idx)] = array; + if (!keepInput) { + parameters.tensors.shift(); + } + }); + if (parameters.tensors.length === 0) { + orderedDict["EMPTY_TENSOR_KEY"] = new recordset_1.ArrayData("", [], tensorType, new Uint8Array()); + } + return orderedDict; +} +function checkMappingFromRecordScalarTypeToScalar(recordData) { + if (!recordData) { + throw new TypeError("Invalid input: recordData is undefined or null"); + } + Object.values(recordData).forEach((value) => { + if (typeof value !== "number" && + typeof value !== "string" && + typeof value !== "boolean" && + !(value instanceof Uint8Array)) { + throw new TypeError(`Invalid scalar type found: ${typeof value}`); + } + }); + return recordData; +} +function recordSetToFitOrEvaluateInsComponents(recordset, insStr, keepInput) { + const parametersRecord = recordset.parametersRecords[`${insStr}.parameters`]; + const parameters = parametersrecordToParameters(parametersRecord, keepInput); + const configRecord = recordset.configsRecords[`${insStr}.config`]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + return { parameters, config: configDict }; +} +function fitOrEvaluateInsToRecordSet(ins, keepInput, insStr) { + const recordset = new recordset_1.RecordSet(); + const parametersRecord = parametersToParametersRecord(ins.parameters, keepInput); + recordset.parametersRecords[`${insStr}.parameters`] = parametersRecord; + recordset.configsRecords[`${insStr}.config`] = new recordset_1.ConfigsRecord(ins.config); + return recordset; +} +function embedStatusIntoRecordSet(resStr, status, recordset) { + const statusDict = { + code: status.code, + message: status.message, + }; + recordset.configsRecords[`${resStr}.status`] = new recordset_1.ConfigsRecord(statusDict); + return recordset; +} +function extractStatusFromRecordSet(resStr, recordset) { + const status = recordset.configsRecords[`${resStr}.status`]; + const code = status["code"]; + return { code, message: status["message"] }; +} +function recordSetToFitIns(recordset, keepInput) { + const { parameters, config } = recordSetToFitOrEvaluateInsComponents(recordset, "fitins", keepInput); + return { parameters, config }; +} +function fitInsToRecordSet(fitins, keepInput) { + return fitOrEvaluateInsToRecordSet(fitins, keepInput, "fitins"); +} +function recordSetToFitRes(recordset, keepInput) { + const insStr = "fitres"; + const parameters = parametersrecordToParameters(recordset.parametersRecords[`${insStr}.parameters`], keepInput); + const numExamples = recordset.metricsRecords[`${insStr}.num_examples`]["num_examples"]; + const configRecord = recordset.configsRecords[`${insStr}.metrics`]; + const metrics = checkMappingFromRecordScalarTypeToScalar(configRecord); + const status = extractStatusFromRecordSet(insStr, recordset); + return { status, parameters, numExamples, metrics }; +} +function fitResToRecordSet(fitres, keepInput) { + const recordset = new recordset_1.RecordSet(); + const resStr = "fitres"; + recordset.configsRecords[`${resStr}.metrics`] = new recordset_1.ConfigsRecord(fitres.metrics); + recordset.metricsRecords[`${resStr}.num_examples`] = new recordset_1.MetricsRecord({ + num_examples: fitres.numExamples, + }); + recordset.parametersRecords[`${resStr}.parameters`] = parametersToParametersRecord(fitres.parameters, keepInput); + return embedStatusIntoRecordSet(resStr, fitres.status, recordset); +} +function recordSetToEvaluateIns(recordset, keepInput) { + const { parameters, config } = recordSetToFitOrEvaluateInsComponents(recordset, "evaluateins", keepInput); + return { parameters, config }; +} +function evaluateInsToRecordSet(evaluateIns, keepInput) { + return fitOrEvaluateInsToRecordSet(evaluateIns, keepInput, "evaluateins"); +} +function recordSetToEvaluateRes(recordset) { + const insStr = "evaluateres"; + const loss = recordset.metricsRecords[`${insStr}.loss`]["loss"]; + const numExamples = recordset.metricsRecords[`${insStr}.num_examples`]["numExamples"]; + const configsRecord = recordset.configsRecords[`${insStr}.metrics`]; + const metrics = Object.fromEntries(Object.entries(configsRecord).map(([key, value]) => [key, value])); + const status = extractStatusFromRecordSet(insStr, recordset); + return { status, loss, numExamples, metrics }; +} +function evaluateResToRecordSet(evaluateRes) { + const recordset = new recordset_1.RecordSet(); + const resStr = "evaluateres"; + recordset.metricsRecords[`${resStr}.loss`] = new recordset_1.MetricsRecord({ loss: evaluateRes.loss }); + recordset.metricsRecords[`${resStr}.num_examples`] = new recordset_1.MetricsRecord({ + numExamples: evaluateRes.numExamples, + }); + recordset.configsRecords[`${resStr}.metrics`] = new recordset_1.ConfigsRecord(evaluateRes.metrics); + return embedStatusIntoRecordSet(resStr, evaluateRes.status, recordset); +} +function recordSetToGetParametersIns(recordset) { + const configRecord = recordset.configsRecords["getparametersins.config"]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + return { config: configDict }; +} +function recordSetToGetPropertiesIns(recordset) { + const configRecord = recordset.configsRecords["getpropertiesins.config"]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + return { config: configDict }; +} +function getParametersInsToRecordSet(getParametersIns) { + const recordset = new recordset_1.RecordSet(); + recordset.configsRecords["getparametersins.config"] = new recordset_1.ConfigsRecord(getParametersIns.config); + return recordset; +} +function getPropertiesInsToRecordSet(getPropertiesIns) { + try { + const recordset = new recordset_1.RecordSet(); + let config; + if (getPropertiesIns && "config" in getPropertiesIns) + config = getPropertiesIns.config || {}; + else + config = {}; + recordset.configsRecords["getpropertiesins.config"] = new recordset_1.ConfigsRecord(config); + return recordset; + } + catch (error) { + console.error("Error in getPropertiesInsToRecordSet:", error); + throw error; // You can throw or return a default value based on your requirement + } +} +function getParametersResToRecordSet(getParametersRes, keepInput) { + const recordset = new recordset_1.RecordSet(); + const parametersRecord = parametersToParametersRecord(getParametersRes.parameters, keepInput); + recordset.parametersRecords["getparametersres.parameters"] = parametersRecord; + return embedStatusIntoRecordSet("getparametersres", getParametersRes.status, recordset); +} +function getPropertiesResToRecordSet(getPropertiesRes) { + const recordset = new recordset_1.RecordSet(); + recordset.configsRecords["getpropertiesres.properties"] = new recordset_1.ConfigsRecord(getPropertiesRes.properties); + return embedStatusIntoRecordSet("getpropertiesres", getPropertiesRes.status, recordset); +} +function recordSetToGetParametersRes(recordset, keepInput) { + const resStr = "getparametersres"; + const parameters = parametersrecordToParameters(recordset.parametersRecords[`${resStr}.parameters`], keepInput); + const status = extractStatusFromRecordSet(resStr, recordset); + return { status, parameters }; +} +function recordSetToGetPropertiesRes(recordset) { + const resStr = "getpropertiesres"; + const properties = checkMappingFromRecordScalarTypeToScalar(recordset.configsRecords[`${resStr}.properties`]); + const status = extractStatusFromRecordSet(resStr, recordset); + return { status, properties }; +} diff --git a/src/ts/src/lib/recordset_compat.test.js b/src/ts/src/lib/recordset_compat.test.js new file mode 100644 index 000000000000..20320221d8fc --- /dev/null +++ b/src/ts/src/lib/recordset_compat.test.js @@ -0,0 +1,73 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const recordset_compat_1 = require("./recordset_compat"); +// Mock data +const mockScalar = "test_scalar"; +const mockParameters = { + tensors: [new Uint8Array([1, 2, 3])], + tensorType: "float32", +}; +const mockFitIns = { parameters: mockParameters, config: { key1: mockScalar } }; +const mockFitRes = { + parameters: mockParameters, + numExamples: 100, + metrics: { key2: mockScalar }, + status: { code: 0, message: "OK" }, +}; +const mockEvaluateIns = { parameters: mockParameters, config: { key3: mockScalar } }; +const mockEvaluateRes = { + loss: 1.5, + numExamples: 50, + metrics: { key4: mockScalar }, + status: { code: 0, message: "OK" }, +}; +const mockGetParametersRes = { + parameters: mockParameters, + status: { code: 0, message: "OK" }, +}; +const mockGetPropertiesRes = { + properties: { key5: mockScalar }, + status: { code: 0, message: "OK" }, +}; +describe("RecordSet Compatibility Functions", () => { + it("should convert recordset to FitIns", () => { + const recordset = (0, recordset_compat_1.fitInsToRecordSet)(mockFitIns, true); + const fitIns = (0, recordset_compat_1.recordSetToFitIns)(recordset, true); + expect(fitIns).toEqual(mockFitIns); + }); + it("should convert recordset to FitRes", () => { + const recordset = (0, recordset_compat_1.fitResToRecordSet)(mockFitRes, true); + const fitRes = (0, recordset_compat_1.recordSetToFitRes)(recordset, true); + expect(fitRes).toEqual(mockFitRes); + }); + it("should convert recordset to EvaluateIns", () => { + const recordset = (0, recordset_compat_1.evaluateInsToRecordSet)(mockEvaluateIns, true); + const evaluateIns = (0, recordset_compat_1.recordSetToEvaluateIns)(recordset, true); + expect(evaluateIns).toEqual(mockEvaluateIns); + }); + it("should convert recordset to EvaluateRes", () => { + const recordset = (0, recordset_compat_1.evaluateResToRecordSet)(mockEvaluateRes); + const evaluateRes = (0, recordset_compat_1.recordSetToEvaluateRes)(recordset); + expect(evaluateRes).toEqual(mockEvaluateRes); + }); + it("should convert recordset to GetParametersIns", () => { + const recordset = (0, recordset_compat_1.getParametersInsToRecordSet)({ config: { key6: mockScalar } }); + const getParametersIns = (0, recordset_compat_1.recordSetToGetParametersIns)(recordset); + expect(getParametersIns).toEqual({ config: { key6: mockScalar } }); + }); + it("should convert recordset to GetPropertiesIns", () => { + const recordset = (0, recordset_compat_1.getPropertiesInsToRecordSet)({ config: { key7: mockScalar } }); + const getPropertiesIns = (0, recordset_compat_1.recordSetToGetPropertiesIns)(recordset); + expect(getPropertiesIns).toEqual({ config: { key7: mockScalar } }); + }); + it("should convert GetParametersRes to RecordSet and back", () => { + const recordset = (0, recordset_compat_1.getParametersResToRecordSet)(mockGetParametersRes, true); + const getParametersRes = (0, recordset_compat_1.recordSetToGetParametersRes)(recordset, true); + expect(getParametersRes).toEqual(mockGetParametersRes); + }); + it("should convert GetPropertiesRes to RecordSet and back", () => { + const recordset = (0, recordset_compat_1.getPropertiesResToRecordSet)(mockGetPropertiesRes); + const getPropertiesRes = (0, recordset_compat_1.recordSetToGetPropertiesRes)(recordset); + expect(getPropertiesRes).toEqual(mockGetPropertiesRes); + }); +}); diff --git a/src/ts/src/lib/recordset_compat.test.ts b/src/ts/src/lib/recordset_compat.test.ts new file mode 100644 index 000000000000..818ad949ba18 --- /dev/null +++ b/src/ts/src/lib/recordset_compat.test.ts @@ -0,0 +1,107 @@ +import { + Parameters, + Scalar, + FitIns, + FitRes, + GetParametersRes, + GetPropertiesRes, + EvaluateIns, + EvaluateRes, +} from "./typing"; +import { + recordSetToFitIns, + fitInsToRecordSet, + recordSetToFitRes, + fitResToRecordSet, + recordSetToEvaluateIns, + evaluateInsToRecordSet, + recordSetToEvaluateRes, + evaluateResToRecordSet, + recordSetToGetParametersIns, + getParametersInsToRecordSet, + recordSetToGetPropertiesIns, + getPropertiesInsToRecordSet, + getParametersResToRecordSet, + getPropertiesResToRecordSet, + recordSetToGetParametersRes, + recordSetToGetPropertiesRes, +} from "./recordset_compat"; + +// Mock data +const mockScalar: Scalar = "test_scalar"; +const mockParameters: Parameters = { + tensors: [new Uint8Array([1, 2, 3])], + tensorType: "float32", +}; +const mockFitIns: FitIns = { parameters: mockParameters, config: { key1: mockScalar } }; +const mockFitRes: FitRes = { + parameters: mockParameters, + numExamples: 100, + metrics: { key2: mockScalar }, + status: { code: 0, message: "OK" }, +}; +const mockEvaluateIns: EvaluateIns = { parameters: mockParameters, config: { key3: mockScalar } }; +const mockEvaluateRes: EvaluateRes = { + loss: 1.5, + numExamples: 50, + metrics: { key4: mockScalar }, + status: { code: 0, message: "OK" }, +}; +const mockGetParametersRes: GetParametersRes = { + parameters: mockParameters, + status: { code: 0, message: "OK" }, +}; +const mockGetPropertiesRes: GetPropertiesRes = { + properties: { key5: mockScalar }, + status: { code: 0, message: "OK" }, +}; + +describe("RecordSet Compatibility Functions", () => { + it("should convert recordset to FitIns", () => { + const recordset = fitInsToRecordSet(mockFitIns, true); + const fitIns = recordSetToFitIns(recordset, true); + expect(fitIns).toEqual(mockFitIns); + }); + + it("should convert recordset to FitRes", () => { + const recordset = fitResToRecordSet(mockFitRes, true); + const fitRes = recordSetToFitRes(recordset, true); + expect(fitRes).toEqual(mockFitRes); + }); + + it("should convert recordset to EvaluateIns", () => { + const recordset = evaluateInsToRecordSet(mockEvaluateIns, true); + const evaluateIns = recordSetToEvaluateIns(recordset, true); + expect(evaluateIns).toEqual(mockEvaluateIns); + }); + + it("should convert recordset to EvaluateRes", () => { + const recordset = evaluateResToRecordSet(mockEvaluateRes); + const evaluateRes = recordSetToEvaluateRes(recordset); + expect(evaluateRes).toEqual(mockEvaluateRes); + }); + + it("should convert recordset to GetParametersIns", () => { + const recordset = getParametersInsToRecordSet({ config: { key6: mockScalar } }); + const getParametersIns = recordSetToGetParametersIns(recordset); + expect(getParametersIns).toEqual({ config: { key6: mockScalar } }); + }); + + it("should convert recordset to GetPropertiesIns", () => { + const recordset = getPropertiesInsToRecordSet({ config: { key7: mockScalar } }); + const getPropertiesIns = recordSetToGetPropertiesIns(recordset); + expect(getPropertiesIns).toEqual({ config: { key7: mockScalar } }); + }); + + it("should convert GetParametersRes to RecordSet and back", () => { + const recordset = getParametersResToRecordSet(mockGetParametersRes, true); + const getParametersRes = recordSetToGetParametersRes(recordset, true); + expect(getParametersRes).toEqual(mockGetParametersRes); + }); + + it("should convert GetPropertiesRes to RecordSet and back", () => { + const recordset = getPropertiesResToRecordSet(mockGetPropertiesRes); + const getPropertiesRes = recordSetToGetPropertiesRes(recordset); + expect(getPropertiesRes).toEqual(mockGetPropertiesRes); + }); +}); diff --git a/src/ts/src/lib/recordset_compat.ts b/src/ts/src/lib/recordset_compat.ts new file mode 100644 index 000000000000..83e62c38d67c --- /dev/null +++ b/src/ts/src/lib/recordset_compat.ts @@ -0,0 +1,312 @@ +import { + Parameters, + Scalar, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesIns, + GetPropertiesRes, + EvaluateIns, + EvaluateRes, + Metrics, +} from "./typing"; +import { + ParametersRecord, + RecordSet, + ArrayData, + ConfigsRecordValue, + ConfigsRecord, + MetricsRecord, + MetricsRecordValue, +} from "./recordset"; + +function parametersrecordToParameters(record: ParametersRecord, keepInput: boolean): Parameters { + const parameters: Parameters = { tensors: [], tensorType: "" }; + + Object.keys(record).forEach((key) => { + const arrayData = record[key]; + if (key !== "EMPTY_TENSOR_KEY") { + parameters.tensors.push(arrayData.data); + } + if (!parameters.tensorType) { + parameters.tensorType = arrayData.stype; + } + if (!keepInput) { + delete record[key]; + } + }); + + return parameters; +} + +function parametersToParametersRecord( + parameters: Parameters, + keepInput: boolean, +): ParametersRecord { + const tensorType = parameters.tensorType; + const orderedDict: ParametersRecord = new ParametersRecord({}); + + parameters.tensors.forEach((tensor, idx) => { + const array = new ArrayData("", [], tensorType, tensor); + orderedDict[String(idx)] = array; + + if (!keepInput) { + parameters.tensors.shift(); + } + }); + + if (parameters.tensors.length === 0) { + orderedDict["EMPTY_TENSOR_KEY"] = new ArrayData("", [], tensorType, new Uint8Array()); + } + + return orderedDict; +} + +function checkMappingFromRecordScalarTypeToScalar( + recordData: Record, +): Record { + if (!recordData) { + throw new TypeError("Invalid input: recordData is undefined or null"); + } + Object.values(recordData).forEach((value) => { + if ( + typeof value !== "number" && + typeof value !== "string" && + typeof value !== "boolean" && + !(value instanceof Uint8Array) + ) { + throw new TypeError(`Invalid scalar type found: ${typeof value}`); + } + }); + + return recordData as Record; +} + +function recordSetToFitOrEvaluateInsComponents( + recordset: RecordSet, + insStr: string, + keepInput: boolean, +): { parameters: Parameters; config: Record } { + const parametersRecord = recordset.parametersRecords[`${insStr}.parameters`]; + const parameters = parametersrecordToParameters(parametersRecord, keepInput); + + const configRecord = recordset.configsRecords[`${insStr}.config`]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + + return { parameters, config: configDict }; +} + +function fitOrEvaluateInsToRecordSet( + ins: { parameters: Parameters; config: Record }, + keepInput: boolean, + insStr: string, +): RecordSet { + const recordset = new RecordSet(); + + const parametersRecord = parametersToParametersRecord(ins.parameters, keepInput); + recordset.parametersRecords[`${insStr}.parameters`] = parametersRecord; + + recordset.configsRecords[`${insStr}.config`] = new ConfigsRecord( + ins.config as Record, + ); + + return recordset; +} + +function embedStatusIntoRecordSet( + resStr: string, + status: { code: number; message: string }, + recordset: RecordSet, +): RecordSet { + const statusDict: Record = { + code: status.code, + message: status.message, + }; + + recordset.configsRecords[`${resStr}.status`] = new ConfigsRecord( + statusDict as Record, + ); + + return recordset; +} + +function extractStatusFromRecordSet( + resStr: string, + recordset: RecordSet, +): { code: number; message: string } { + const status = recordset.configsRecords[`${resStr}.status`]; + const code = status["code"] as number; + return { code, message: status["message"] as string }; +} + +export function recordSetToFitIns(recordset: RecordSet, keepInput: boolean): FitIns { + const { parameters, config } = recordSetToFitOrEvaluateInsComponents( + recordset, + "fitins", + keepInput, + ); + return { parameters, config }; +} + +export function fitInsToRecordSet(fitins: FitIns, keepInput: boolean): RecordSet { + return fitOrEvaluateInsToRecordSet(fitins, keepInput, "fitins"); +} + +export function recordSetToFitRes(recordset: RecordSet, keepInput: boolean): FitRes { + const insStr = "fitres"; + const parameters = parametersrecordToParameters( + recordset.parametersRecords[`${insStr}.parameters`], + keepInput, + ); + + const numExamples = recordset.metricsRecords[`${insStr}.num_examples`]["num_examples"] as number; + + const configRecord = recordset.configsRecords[`${insStr}.metrics`]; + const metrics = checkMappingFromRecordScalarTypeToScalar(configRecord); + const status = extractStatusFromRecordSet(insStr, recordset); + + return { status, parameters, numExamples, metrics }; +} + +export function fitResToRecordSet(fitres: FitRes, keepInput: boolean): RecordSet { + const recordset = new RecordSet(); + const resStr = "fitres"; + + recordset.configsRecords[`${resStr}.metrics`] = new ConfigsRecord( + fitres.metrics as Record, + ); + recordset.metricsRecords[`${resStr}.num_examples`] = new MetricsRecord({ + num_examples: fitres.numExamples as MetricsRecordValue, + }); + + recordset.parametersRecords[`${resStr}.parameters`] = parametersToParametersRecord( + fitres.parameters, + keepInput, + ); + + return embedStatusIntoRecordSet(resStr, fitres.status, recordset); +} + +export function recordSetToEvaluateIns(recordset: RecordSet, keepInput: boolean): EvaluateIns { + const { parameters, config } = recordSetToFitOrEvaluateInsComponents( + recordset, + "evaluateins", + keepInput, + ); + return { parameters, config }; +} + +export function evaluateInsToRecordSet(evaluateIns: EvaluateIns, keepInput: boolean): RecordSet { + return fitOrEvaluateInsToRecordSet(evaluateIns, keepInput, "evaluateins"); +} + +export function recordSetToEvaluateRes(recordset: RecordSet): EvaluateRes { + const insStr = "evaluateres"; + + const loss = recordset.metricsRecords[`${insStr}.loss`]["loss"] as number; + const numExamples = recordset.metricsRecords[`${insStr}.num_examples`]["numExamples"] as number; + const configsRecord = recordset.configsRecords[`${insStr}.metrics`]; + const metrics = Object.fromEntries( + Object.entries(configsRecord).map(([key, value]) => [key, value]), + ) as Metrics; + const status = extractStatusFromRecordSet(insStr, recordset); + + return { status, loss, numExamples, metrics }; +} + +export function evaluateResToRecordSet(evaluateRes: EvaluateRes): RecordSet { + const recordset = new RecordSet(); + const resStr = "evaluateres"; + + recordset.metricsRecords[`${resStr}.loss`] = new MetricsRecord({ loss: evaluateRes.loss }); + recordset.metricsRecords[`${resStr}.num_examples`] = new MetricsRecord({ + numExamples: evaluateRes.numExamples, + }); + recordset.configsRecords[`${resStr}.metrics`] = new ConfigsRecord( + evaluateRes.metrics as Record, + ); + + return embedStatusIntoRecordSet(resStr, evaluateRes.status, recordset); +} + +export function recordSetToGetParametersIns(recordset: RecordSet): GetParametersIns { + const configRecord = recordset.configsRecords["getparametersins.config"]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + return { config: configDict }; +} + +export function recordSetToGetPropertiesIns(recordset: RecordSet): GetPropertiesIns { + const configRecord = recordset.configsRecords["getpropertiesins.config"]; + const configDict = checkMappingFromRecordScalarTypeToScalar(configRecord); + return { config: configDict }; +} + +export function getParametersInsToRecordSet(getParametersIns: GetParametersIns): RecordSet { + const recordset = new RecordSet(); + recordset.configsRecords["getparametersins.config"] = new ConfigsRecord( + getParametersIns.config as Record, + ); + return recordset; +} + +export function getPropertiesInsToRecordSet(getPropertiesIns: GetPropertiesIns | null): RecordSet { + try { + const recordset = new RecordSet(); + + let config: Record; + if (getPropertiesIns && "config" in getPropertiesIns) + config = (getPropertiesIns.config as Record) || {}; + else config = {}; + + recordset.configsRecords["getpropertiesins.config"] = new ConfigsRecord(config); + return recordset; + } catch (error) { + console.error("Error in getPropertiesInsToRecordSet:", error); + throw error; // You can throw or return a default value based on your requirement + } +} + +export function getParametersResToRecordSet( + getParametersRes: GetParametersRes, + keepInput: boolean, +): RecordSet { + const recordset = new RecordSet(); + const parametersRecord = parametersToParametersRecord(getParametersRes.parameters, keepInput); + recordset.parametersRecords["getparametersres.parameters"] = parametersRecord; + + return embedStatusIntoRecordSet("getparametersres", getParametersRes.status, recordset); +} + +export function getPropertiesResToRecordSet(getPropertiesRes: GetPropertiesRes): RecordSet { + const recordset = new RecordSet(); + recordset.configsRecords["getpropertiesres.properties"] = new ConfigsRecord( + getPropertiesRes.properties as Record, + ); + + return embedStatusIntoRecordSet("getpropertiesres", getPropertiesRes.status, recordset); +} + +export function recordSetToGetParametersRes( + recordset: RecordSet, + keepInput: boolean, +): GetParametersRes { + const resStr = "getparametersres"; + const parameters = parametersrecordToParameters( + recordset.parametersRecords[`${resStr}.parameters`], + keepInput, + ); + + const status = extractStatusFromRecordSet(resStr, recordset); + return { status, parameters }; +} + +export function recordSetToGetPropertiesRes(recordset: RecordSet): GetPropertiesRes { + const resStr = "getpropertiesres"; + const properties = checkMappingFromRecordScalarTypeToScalar( + recordset.configsRecords[`${resStr}.properties`], + ); + + const status = extractStatusFromRecordSet(resStr, recordset); + return { status, properties }; +} diff --git a/src/ts/src/lib/retry_invoker.js b/src/ts/src/lib/retry_invoker.js new file mode 100644 index 000000000000..da86a3fd5b14 --- /dev/null +++ b/src/ts/src/lib/retry_invoker.js @@ -0,0 +1,114 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.RetryInvoker = exports.sleep = void 0; +exports.exponential = exponential; +exports.constant = constant; +exports.fullJitter = fullJitter; +const sleep = (s) => new Promise((r) => setTimeout(r, s * 1000)); +exports.sleep = sleep; +// Generator function for exponential backoff strategy +function* exponential(baseDelay = 1, multiplier = 2, maxDelay) { + let delay = maxDelay === undefined ? baseDelay : Math.min(baseDelay, maxDelay); + while (true) { + yield delay; + delay *= multiplier; + if (maxDelay !== undefined) { + delay = Math.min(delay, maxDelay); + } + } +} +// Generator function for constant wait times +function* constant(interval = 1) { + if (typeof interval === "number") { + while (true) { + yield interval; + } + } + else { + yield* interval; + } +} +// Full jitter algorithm +function fullJitter(maxValue) { + return Math.random() * maxValue; +} +class RetryInvoker { + waitGenFactory; + recoverableExceptions; + maxTries; + maxTime; + onSuccess; + onBackoff; + onGiveup; + jitter; + shouldGiveup; + waitFunction; + constructor(waitGenFactory, recoverableExceptions, maxTries, maxTime, options = {}) { + this.waitGenFactory = waitGenFactory; + this.recoverableExceptions = recoverableExceptions; + this.maxTries = maxTries; + this.maxTime = maxTime; + this.onSuccess = options.onSuccess; + this.onBackoff = options.onBackoff; + this.onGiveup = options.onGiveup; + this.jitter = options.jitter ?? fullJitter; + this.shouldGiveup = options.shouldGiveup; + this.waitFunction = options.waitFunction ?? exports.sleep; + } + async invoke(target, ...args) { + const startTime = Date.now(); + let tryCount = 0; + const waitGenerator = this.waitGenFactory(); + while (true) { + tryCount++; + const elapsedTime = (Date.now() - startTime) / 1000; + const state = { + target, + args, + kwargs: {}, + tries: tryCount, + elapsedTime, + }; + try { + // Attempt the target function call + const result = await target(...args); + // On success, call onSuccess handler if defined + if (this.onSuccess) { + this.onSuccess(state); + } + return result; + } + catch (err) { + if (!(err instanceof this.recoverableExceptions)) { + throw err; // Not a recoverable exception, rethrow it + } + state.exception = err; + const giveup = this.shouldGiveup && this.shouldGiveup(err); + const maxTriesExceeded = this.maxTries !== null && tryCount >= this.maxTries; + const maxTimeExceeded = this.maxTime !== null && elapsedTime >= this.maxTime; + // Check if we should give up + if (giveup || maxTriesExceeded || maxTimeExceeded) { + if (this.onGiveup) { + this.onGiveup(state); + } + throw err; // Give up and rethrow the error + } + let waitTime = waitGenerator.next().value; + if (this.jitter) { + waitTime = this.jitter(waitTime); + } + if (this.maxTime !== null) { + waitTime = Math.min(waitTime, this.maxTime - elapsedTime); + } + state.actualWait = waitTime; + // Call onBackoff handler if defined + if (this.onBackoff) { + this.onBackoff(state); + } + // Wait for the specified time + await this.waitFunction(waitTime * 1000); + } + } + } +} +exports.RetryInvoker = RetryInvoker; diff --git a/src/ts/src/lib/retry_invoker.test.js b/src/ts/src/lib/retry_invoker.test.js new file mode 100644 index 000000000000..d697feca3a1f --- /dev/null +++ b/src/ts/src/lib/retry_invoker.test.js @@ -0,0 +1,83 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const retry_invoker_1 = require("./retry_invoker"); // Adjust import paths as necessary +describe("RetryInvoker", () => { + // Mocking utilities (similar to pytest's fixture) + let mockTime; + let mockSleep; + beforeEach(() => { + mockTime = jest.spyOn(Date, "now").mockImplementation(() => 0); + mockSleep = jest + .spyOn(global, "setTimeout") + .mockImplementation((fn, ms) => { + fn(); // Immediately call the function to avoid waiting + return setTimeout(fn, ms); // Return a proper Timeout object + }); + }); + afterEach(() => { + mockTime.mockRestore(); + mockSleep.mockRestore(); + }); + // Test successful invocation + it("should succeed when the function does not throw", async () => { + // Prepare + const successHandler = jest.fn(); + const backoffHandler = jest.fn(); + const giveupHandler = jest.fn(); + const invoker = new retry_invoker_1.RetryInvoker(() => (0, retry_invoker_1.constant)(0.1), Error, null, null, { + onSuccess: successHandler, + onBackoff: backoffHandler, + onGiveup: giveupHandler, + }); + const successfulFunction = () => "success"; + // Assert that the invoker returns the correct result + const result = await invoker.invoke(successfulFunction); + expect(result).toBe("success"); + }); + it("should retry and fail on the failing function", async () => { + const invoker = new retry_invoker_1.RetryInvoker(() => (0, retry_invoker_1.constant)(0.1), // Retry every 0.1 seconds + Error, // Retry on Error + 2, // Maximum 2 retries + 2.5); + const failingFunction = () => { + throw new Error("failed"); + }; + // Assert that the invoker throws an error + await expect(async () => { + await invoker.invoke(failingFunction); + }).rejects.toThrow("failed"); + }); + it("should call onSuccess handler when successful", async () => { + const successHandler = jest.fn(); + const invoker = new retry_invoker_1.RetryInvoker(() => (0, retry_invoker_1.constant)(0.1), Error, 2, 2.5, { + onSuccess: successHandler, + }); + const successfulFunction = () => "success"; + // Call the function and assert success + await invoker.invoke(successfulFunction); + // Ensure the onSuccess handler was called + expect(successHandler).toHaveBeenCalled(); + }); + it("should retry on failure and call onBackoff", async () => { + const backoffHandler = jest.fn(); + const invoker = new retry_invoker_1.RetryInvoker(() => (0, retry_invoker_1.constant)(0.1), Error, 2, 2.5, { + onBackoff: backoffHandler, + }); + const failingFunction = () => { + throw new Error("failed"); + }; + // Assert the invoker throws an error and triggers retries + await expect(invoker.invoke(failingFunction)).rejects.toThrow("failed"); + // Ensure the backoff handler was called + expect(backoffHandler).toHaveBeenCalled(); + }); + it("should stop after max retries", async () => { + const invoker = new retry_invoker_1.RetryInvoker(() => (0, retry_invoker_1.constant)(0.1), Error, 2, // Max retries is 2 + 2.5); + const failingFunction = () => { + throw new Error("failed"); + }; + // Assert the invoker gives up after max retries + await expect(invoker.invoke(failingFunction)).rejects.toThrow("failed"); + }); +}); diff --git a/src/ts/src/lib/retry_invoker.test.ts b/src/ts/src/lib/retry_invoker.test.ts new file mode 100644 index 000000000000..5682b00000de --- /dev/null +++ b/src/ts/src/lib/retry_invoker.test.ts @@ -0,0 +1,106 @@ +import { RetryInvoker, constant } from "./retry_invoker"; // Adjust import paths as necessary +describe("RetryInvoker", () => { + // Mocking utilities (similar to pytest's fixture) + let mockTime: jest.SpyInstance; + let mockSleep: jest.SpyInstance; + + beforeEach(() => { + mockTime = jest.spyOn(Date, "now").mockImplementation(() => 0); + + mockSleep = jest + .spyOn(global, "setTimeout") + .mockImplementation((fn: () => void, ms?: number) => { + fn(); // Immediately call the function to avoid waiting + return setTimeout(fn, ms) as unknown as NodeJS.Timeout; // Return a proper Timeout object + }); + }); + + afterEach(() => { + mockTime.mockRestore(); + mockSleep.mockRestore(); + }); + + // Test successful invocation + it("should succeed when the function does not throw", async () => { + // Prepare + const successHandler = jest.fn(); + const backoffHandler = jest.fn(); + const giveupHandler = jest.fn(); + const invoker = new RetryInvoker(() => constant(0.1), Error, null, null, { + onSuccess: successHandler, + onBackoff: backoffHandler, + onGiveup: giveupHandler, + }); + const successfulFunction = () => "success"; + + // Assert that the invoker returns the correct result + const result = await invoker.invoke(successfulFunction); + expect(result).toBe("success"); + }); + + it("should retry and fail on the failing function", async () => { + const invoker = new RetryInvoker( + () => constant(0.1), // Retry every 0.1 seconds + Error, // Retry on Error + 2, // Maximum 2 retries + 2.5, // Maximum time to retry (in seconds) + ); + + const failingFunction = () => { + throw new Error("failed"); + }; + + // Assert that the invoker throws an error + await expect(async () => { + await invoker.invoke(failingFunction); + }).rejects.toThrow("failed"); + }); + + it("should call onSuccess handler when successful", async () => { + const successHandler = jest.fn(); + const invoker = new RetryInvoker(() => constant(0.1), Error, 2, 2.5, { + onSuccess: successHandler, + }); + + const successfulFunction = () => "success"; + + // Call the function and assert success + await invoker.invoke(successfulFunction); + + // Ensure the onSuccess handler was called + expect(successHandler).toHaveBeenCalled(); + }); + + it("should retry on failure and call onBackoff", async () => { + const backoffHandler = jest.fn(); + const invoker = new RetryInvoker(() => constant(0.1), Error, 2, 2.5, { + onBackoff: backoffHandler, + }); + + const failingFunction = () => { + throw new Error("failed"); + }; + + // Assert the invoker throws an error and triggers retries + await expect(invoker.invoke(failingFunction)).rejects.toThrow("failed"); + + // Ensure the backoff handler was called + expect(backoffHandler).toHaveBeenCalled(); + }); + + it("should stop after max retries", async () => { + const invoker = new RetryInvoker( + () => constant(0.1), + Error, + 2, // Max retries is 2 + 2.5, + ); + + const failingFunction = () => { + throw new Error("failed"); + }; + + // Assert the invoker gives up after max retries + await expect(invoker.invoke(failingFunction)).rejects.toThrow("failed"); + }); +}); diff --git a/src/ts/src/lib/retry_invoker.ts b/src/ts/src/lib/retry_invoker.ts new file mode 100644 index 000000000000..a59209c5309a --- /dev/null +++ b/src/ts/src/lib/retry_invoker.ts @@ -0,0 +1,149 @@ +export const sleep = (s: number) => new Promise((r) => setTimeout(r, s * 1000)); + +export interface RetryState { + target: (...args: any[]) => any; + args: any[]; + kwargs: Record; + tries: number; + elapsedTime: number; + exception?: Error; + actualWait?: number; +} + +// Generator function for exponential backoff strategy +export function* exponential( + baseDelay: number = 1, + multiplier: number = 2, + maxDelay?: number, +): Generator { + let delay = maxDelay === undefined ? baseDelay : Math.min(baseDelay, maxDelay); + while (true) { + yield delay; + delay *= multiplier; + if (maxDelay !== undefined) { + delay = Math.min(delay, maxDelay); + } + } +} + +// Generator function for constant wait times +export function* constant( + interval: number | Iterable = 1, +): Generator { + if (typeof interval === "number") { + while (true) { + yield interval; + } + } else { + yield* interval; + } +} + +// Full jitter algorithm +export function fullJitter(maxValue: number): number { + return Math.random() * maxValue; +} + +export class RetryInvoker { + private waitGenFactory: () => Generator; + private recoverableExceptions: any; + private maxTries: number | null; + private maxTime: number | null; + private onSuccess?: (state: RetryState) => void; + private onBackoff?: (state: RetryState) => void; + private onGiveup?: (state: RetryState) => void; + private jitter?: (waitTime: number) => number; + private shouldGiveup?: (err: Error) => boolean; + private waitFunction: (waitTime: number) => Promise; + + constructor( + waitGenFactory: () => Generator, + recoverableExceptions: any, + maxTries: number | null, + maxTime: number | null, + options: { + onSuccess?: (state: RetryState) => void; + onBackoff?: (state: RetryState) => void; + onGiveup?: (state: RetryState) => void; + jitter?: (waitTime: number) => number; + shouldGiveup?: (err: Error) => boolean; + waitFunction?: (waitTime: number) => Promise; + } = {}, + ) { + this.waitGenFactory = waitGenFactory; + this.recoverableExceptions = recoverableExceptions; + this.maxTries = maxTries; + this.maxTime = maxTime; + this.onSuccess = options.onSuccess; + this.onBackoff = options.onBackoff; + this.onGiveup = options.onGiveup; + this.jitter = options.jitter ?? fullJitter; + this.shouldGiveup = options.shouldGiveup; + this.waitFunction = options.waitFunction ?? sleep; + } + + public async invoke(target: (...args: any[]) => any, ...args: any[]): Promise { + const startTime = Date.now(); + let tryCount = 0; + const waitGenerator = this.waitGenFactory(); + + while (true) { + tryCount++; + const elapsedTime = (Date.now() - startTime) / 1000; + const state: RetryState = { + target, + args, + kwargs: {}, + tries: tryCount, + elapsedTime, + }; + + try { + // Attempt the target function call + const result = await target(...args); + + // On success, call onSuccess handler if defined + if (this.onSuccess) { + this.onSuccess(state); + } + + return result; + } catch (err) { + if (!(err instanceof this.recoverableExceptions)) { + throw err; // Not a recoverable exception, rethrow it + } + + state.exception = err as Error; + + const giveup = this.shouldGiveup && this.shouldGiveup(err as Error); + const maxTriesExceeded = this.maxTries !== null && tryCount >= this.maxTries; + const maxTimeExceeded = this.maxTime !== null && elapsedTime >= this.maxTime; + + // Check if we should give up + if (giveup || maxTriesExceeded || maxTimeExceeded) { + if (this.onGiveup) { + this.onGiveup(state); + } + throw err; // Give up and rethrow the error + } + + let waitTime = waitGenerator.next().value as number; + if (this.jitter) { + waitTime = this.jitter(waitTime); + } + if (this.maxTime !== null) { + waitTime = Math.min(waitTime, this.maxTime - elapsedTime); + } + state.actualWait = waitTime; + + // Call onBackoff handler if defined + if (this.onBackoff) { + this.onBackoff(state); + } + + // Wait for the specified time + await this.waitFunction(waitTime * 1000); + } + } + } +} diff --git a/src/ts/src/lib/serde.js b/src/ts/src/lib/serde.js new file mode 100644 index 000000000000..08f3b0520aa6 --- /dev/null +++ b/src/ts/src/lib/serde.js @@ -0,0 +1,367 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.userConfigValueFromProto = exports.userConfigValueToProto = exports.userConfigFromProto = exports.messageToTaskRes = exports.messageFromTaskIns = exports.propertiesToProto = exports.propertiesFromProto = exports.getPropertiesResToProto = exports.getPropertiesInsFromProto = exports.statusToProto = exports.evaluateResToProto = exports.evaluateInsFromProto = exports.fitResToProto = exports.fitInsFromProto = exports.parameterResToProto = exports.metricsFromProto = exports.metricsToProto = exports.scalarFromProto = exports.scalarToProto = exports.parametersFromProto = exports.parametersToProto = void 0; +exports.recordSetToProto = recordSetToProto; +exports.recordSetFromProto = recordSetFromProto; +const task_1 = require("../protos/flwr/proto/task"); +const node_1 = require("../protos/flwr/proto/node"); +const typing_1 = require("./typing"); +const recordset_1 = require("./recordset"); +// Parameter conversions +const parametersToProto = (params) => { + return { tensors: params.tensors, tensorType: params.tensorType }; +}; +exports.parametersToProto = parametersToProto; +const parametersFromProto = (protoParams) => { + return { tensors: protoParams.tensors, tensorType: protoParams.tensorType }; +}; +exports.parametersFromProto = parametersFromProto; +// Scalar conversions +const scalarToProto = (scalar) => { + if (typeof scalar === "string") { + return { scalar: { oneofKind: "string", string: scalar } }; + } + else if (typeof scalar === "boolean") { + return { scalar: { oneofKind: "bool", bool: scalar } }; + } + else if (typeof scalar === "bigint") { + return { scalar: { oneofKind: "sint64", sint64: scalar } }; + } + else if (typeof scalar === "number") { + return { scalar: { oneofKind: "double", double: scalar } }; + } + else if (scalar instanceof Uint8Array) { + return { scalar: { oneofKind: "bytes", bytes: scalar } }; + } + throw new Error("Unsupported scalar type"); +}; +exports.scalarToProto = scalarToProto; +const scalarFromProto = (protoScalar) => { + switch (protoScalar.scalar?.oneofKind) { + case "double": + return protoScalar.scalar.double; + case "sint64": + return protoScalar.scalar.sint64; + case "bool": + return protoScalar.scalar.bool; + case "string": + return protoScalar.scalar.string; + case "bytes": + return protoScalar.scalar.bytes; + default: + throw new Error("Unknown scalar type"); + } +}; +exports.scalarFromProto = scalarFromProto; +// Metrics conversions +const metricsToProto = (metrics) => { + const protoMetrics = {}; + for (const key in metrics) { + protoMetrics[key] = (0, exports.scalarToProto)(metrics[key]); + } + return protoMetrics; +}; +exports.metricsToProto = metricsToProto; +const metricsFromProto = (protoMetrics) => { + const metrics = {}; + for (const key in protoMetrics) { + metrics[key] = (0, exports.scalarFromProto)(protoMetrics[key]); + } + return metrics; +}; +exports.metricsFromProto = metricsFromProto; +// GetParametersRes conversions +const parameterResToProto = (res) => { + return { + parameters: (0, exports.parametersToProto)(res.parameters), + status: (0, exports.statusToProto)(res.status), + }; +}; +exports.parameterResToProto = parameterResToProto; +// FitIns conversions +const fitInsFromProto = (fitInsMsg) => { + return { + parameters: (0, exports.parametersFromProto)(fitInsMsg.parameters), + config: (0, exports.metricsFromProto)(fitInsMsg.config), + }; +}; +exports.fitInsFromProto = fitInsFromProto; +// FitRes conversions +const fitResToProto = (res) => { + return { + parameters: (0, exports.parametersToProto)(res.parameters), + numExamples: BigInt(res.numExamples), + metrics: Object.keys(res.metrics).length > 0 ? (0, exports.metricsToProto)(res.metrics) : {}, + status: (0, exports.statusToProto)(res.status), + }; +}; +exports.fitResToProto = fitResToProto; +// EvaluateIns conversions +const evaluateInsFromProto = (evaluateInsMsg) => { + return { + parameters: (0, exports.parametersFromProto)(evaluateInsMsg.parameters), + config: (0, exports.metricsFromProto)(evaluateInsMsg.config), + }; +}; +exports.evaluateInsFromProto = evaluateInsFromProto; +// EvaluateRes conversions +const evaluateResToProto = (res) => { + return { + loss: res.loss, + numExamples: BigInt(res.numExamples), + metrics: Object.keys(res.metrics).length > 0 ? (0, exports.metricsToProto)(res.metrics) : {}, + status: (0, exports.statusToProto)(res.status), + }; +}; +exports.evaluateResToProto = evaluateResToProto; +// Status conversions +const statusToProto = (status) => { + return { + code: status.code, + message: status.message, + }; +}; +exports.statusToProto = statusToProto; +// GetPropertiesIns conversions +const getPropertiesInsFromProto = (getPropertiesMsg) => { + return { + config: (0, exports.propertiesFromProto)(getPropertiesMsg.config), + }; +}; +exports.getPropertiesInsFromProto = getPropertiesInsFromProto; +// GetPropertiesRes conversions +const getPropertiesResToProto = (res) => { + return { + properties: (0, exports.propertiesToProto)(res.properties), + status: (0, exports.statusToProto)(res.status), + }; +}; +exports.getPropertiesResToProto = getPropertiesResToProto; +// Properties conversions +const propertiesFromProto = (protoProperties) => { + const properties = {}; + for (const key in protoProperties) { + properties[key] = (0, exports.scalarFromProto)(protoProperties[key]); + } + return properties; +}; +exports.propertiesFromProto = propertiesFromProto; +const propertiesToProto = (properties) => { + const protoProperties = {}; + for (const key in properties) { + protoProperties[key] = (0, exports.scalarToProto)(properties[key]); + } + return protoProperties; +}; +exports.propertiesToProto = propertiesToProto; +function recordValueToProto(value) { + if (typeof value === "number") { + return { value: { oneofKind: "double", double: value } }; + } + else if (typeof value === "bigint") { + return { value: { oneofKind: "sint64", sint64: value } }; + } + else if (typeof value === "boolean") { + return { value: { oneofKind: "bool", bool: value } }; + } + else if (typeof value === "string") { + return { value: { oneofKind: "string", string: value } }; + } + else if (value instanceof Uint8Array) { + return { value: { oneofKind: "bytes", bytes: value } }; + } + else if (Array.isArray(value)) { + if (typeof value[0] === "number") { + return { value: { oneofKind: "doubleList", doubleList: { vals: value } } }; + } + else if (typeof value[0] === "bigint") { + return { value: { oneofKind: "sintList", sintList: { vals: value } } }; + } + else if (typeof value[0] === "boolean") { + return { value: { oneofKind: "boolList", boolList: { vals: value } } }; + } + else if (typeof value[0] === "string") { + return { value: { oneofKind: "stringList", stringList: { vals: value } } }; + } + else if (value[0] instanceof Uint8Array) { + return { value: { oneofKind: "bytesList", bytesList: { vals: value } } }; + } + } + throw new TypeError("Unsupported value type"); +} +// Helper for converting Protobuf messages back into values +function recordValueFromProto(proto) { + switch (proto.value.oneofKind) { + case "double": + return proto.value.double; + case "sint64": + return proto.value.sint64; + case "bool": + return proto.value.bool; + case "string": + return proto.value.string; + case "bytes": + return proto.value.bytes; + case "doubleList": + return proto.value.doubleList.vals; + case "sintList": + return proto.value.sintList.vals; + case "boolList": + return proto.value.boolList.vals; + case "stringList": + return proto.value.stringList.vals; + case "bytesList": + return proto.value.bytesList.vals; + default: + throw new Error("Unknown value kind"); + } +} +function arrayToProto(array) { + return { + dtype: array.dtype, + shape: array.shape, + stype: array.stype, + data: array.data, + }; +} +function arrayFromProto(proto) { + return new recordset_1.ArrayData(proto.dtype, proto.shape, proto.stype, proto.data); +} +function parametersRecordToProto(record) { + return { + dataKeys: Object.keys(record), + dataValues: Object.values(record).map(arrayToProto), + }; +} +function parametersRecordFromProto(proto) { + const arrayDict = Object.fromEntries(proto.dataKeys.map((k, i) => [k, arrayFromProto(proto.dataValues[i])])); + // Create a new instance of ParametersRecord and populate it with the arrayDict + return new recordset_1.ParametersRecord(arrayDict); +} +function metricsRecordToProto(record) { + const data = Object.fromEntries(Object.entries(record).map(([k, v]) => [k, recordValueToProto(v)])); + return { data }; +} +function metricsRecordFromProto(proto) { + const metrics = Object.fromEntries(Object.entries(proto.data).map(([k, v]) => [k, recordValueFromProto(v)])); + return new recordset_1.MetricsRecord(metrics); +} +function configsRecordToProto(record) { + const data = Object.fromEntries(Object.entries(record).map(([k, v]) => [k, recordValueToProto(v)])); + return { data }; +} +function configsRecordFromProto(proto) { + const config = Object.fromEntries(Object.entries(proto.data).map(([k, v]) => [k, recordValueFromProto(v)])); + return new recordset_1.ConfigsRecord(config); +} +function recordSetToProto(recordset) { + const parameters = Object.fromEntries(Object.entries(recordset.parametersRecords).map(([k, v]) => [ + k, + parametersRecordToProto(v), // Nested dictionary (string -> Record) + ])); + const metrics = Object.fromEntries(Object.entries(recordset.metricsRecords).map(([k, v]) => [k, metricsRecordToProto(v)])); + const configs = Object.fromEntries(Object.entries(recordset.configsRecords).map(([k, v]) => [k, configsRecordToProto(v)])); + return { parameters, metrics, configs }; +} +function recordSetFromProto(proto) { + const parametersRecords = Object.fromEntries(Object.entries(proto.parameters).map(([k, v]) => [k, parametersRecordFromProto(v)])); + const metricsRecords = Object.fromEntries(Object.entries(proto.metrics).map(([k, v]) => [k, metricsRecordFromProto(v)])); + const configsRecords = Object.fromEntries(Object.entries(proto.configs).map(([k, v]) => [k, configsRecordFromProto(v)])); + return new recordset_1.RecordSet(parametersRecords, metricsRecords, configsRecords); +} +const messageFromTaskIns = (taskIns) => { + let metadata = { + runId: taskIns.runId, + messageId: taskIns.taskId, + srcNodeId: taskIns.task?.producer?.nodeId, + dstNodeId: taskIns.task?.consumer?.nodeId, + replyToMessage: taskIns.task?.ancestry ? taskIns.task?.ancestry[0] : "", + groupId: taskIns.groupId, + ttl: taskIns.task?.ttl, + messageType: taskIns.task?.taskType, + }; + let message = new typing_1.Message(metadata, taskIns.task?.recordset ? recordSetFromProto(taskIns.task.recordset) : null, taskIns.task?.error ? { code: Number(taskIns.task.error.code), reason: taskIns.task.error.reason } : null); + if (taskIns.task?.createdAt) { + message.metadata.createdAt = taskIns.task?.createdAt; + } + return message; +}; +exports.messageFromTaskIns = messageFromTaskIns; +const messageToTaskRes = (message) => { + const md = message.metadata; + const taskRes = task_1.TaskRes.create(); + taskRes.taskId = "", + taskRes.groupId = md.groupId; + taskRes.runId = md.runId; + let task = task_1.Task.create(); + let producer = node_1.Node.create(); + producer.nodeId = md.srcNodeId; + producer.anonymous = false; + task.producer = producer; + let consumer = node_1.Node.create(); + consumer.nodeId = BigInt(0); + consumer.anonymous = true; + task.consumer = consumer; + task.createdAt = md.createdAt; + task.ttl = md.ttl; + task.ancestry = md.replyToMessage !== "" ? [md.replyToMessage] : []; + task.taskType = md.messageType; + task.recordset = message.content === null ? undefined : recordSetToProto(message.content); + task.error = message.error === null ? undefined : { code: BigInt(message.error.code), reason: message.error.reason }; + taskRes.task = task; + return taskRes; + // return { + // taskId: "", + // groupId: md.groupId, + // runId: md.runId, + // task: { + // producer: { nodeId: md.srcNodeId, anonymous: false } as Node, + // consumer: { nodeId: BigInt(0), anonymous: true } as Node, + // createdAt: md.createdAt, + // ttl: md.ttl, + // ancestry: md.replyToMessage ? [md.replyToMessage] : [], + // taskType: md.messageType, + // recordset: message.content ? recordSetToProto(message.content) : null, + // error: message.error ? ({ code: BigInt(message.error.code), reason: message.error.reason } as ProtoError) : null, + // } as Task, + // } as TaskRes; +}; +exports.messageToTaskRes = messageToTaskRes; +const userConfigFromProto = (proto) => { + let metrics = {}; + Object.entries(proto).forEach(([key, value]) => { + metrics[key] = (0, exports.userConfigValueFromProto)(value); + }); + return metrics; +}; +exports.userConfigFromProto = userConfigFromProto; +const userConfigValueToProto = (userConfigValue) => { + switch (typeof userConfigValue) { + case "string": + return { scalar: { oneofKind: "string", string: userConfigValue } }; + case "number": + return { scalar: { oneofKind: "double", double: userConfigValue } }; + case "bigint": + return { scalar: { oneofKind: "sint64", sint64: userConfigValue } }; + case "boolean": + return { scalar: { oneofKind: "bool", bool: userConfigValue } }; + default: + throw new Error(`Accepted types: {bool, float, int, str} (but not ${typeof userConfigValue})`); + } +}; +exports.userConfigValueToProto = userConfigValueToProto; +const userConfigValueFromProto = (scalarMsg) => { + switch (scalarMsg.scalar.oneofKind) { + case "string": + return scalarMsg.scalar.string; + case "bool": + return scalarMsg.scalar.bool; + case "sint64": + return scalarMsg.scalar.sint64; + case "double": + return scalarMsg.scalar.double; + default: + throw new Error(`Accepted types: {bool, float, int, str} (but not ${scalarMsg.scalar.oneofKind})`); + } +}; +exports.userConfigValueFromProto = userConfigValueFromProto; diff --git a/src/ts/src/lib/serde.test.js b/src/ts/src/lib/serde.test.js new file mode 100644 index 000000000000..67820aaf509d --- /dev/null +++ b/src/ts/src/lib/serde.test.js @@ -0,0 +1,202 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const serde_1 = require("./serde"); +const typing_1 = require("./typing"); +const recordset_1 = require("./recordset"); +let bytes = new Uint8Array(8); +bytes[0] = 256; +// Mock Protobuf messages and local types for testing +const mockProtoParams = { + tensors: [bytes, bytes], + tensorType: "float32", +}; +const mockLocalParams = { + tensors: [bytes, bytes], + tensorType: "float32", +}; +const mockProtoScalar = { + scalar: { oneofKind: "double", double: 1.23 }, +}; +const mockLocalScalar = 1.23; +const mockStatus = { + code: typing_1.Code.OK, + message: "OK", +}; +// Tests for parametersToProto +describe("parametersToProto", () => { + it("should convert local parameters to proto format", () => { + const protoParams = (0, serde_1.parametersToProto)(mockLocalParams); + expect(protoParams).toEqual(mockProtoParams); + }); +}); +// Tests for parametersFromProto +describe("parametersFromProto", () => { + it("should convert proto parameters to local format", () => { + const localParams = (0, serde_1.parametersFromProto)(mockProtoParams); + expect(localParams).toEqual(mockLocalParams); + }); +}); +// Tests for scalarToProto +describe("scalarToProto", () => { + it("should convert local scalar to proto format", () => { + const protoScalar = (0, serde_1.scalarToProto)(mockLocalScalar); + expect(protoScalar).toEqual(mockProtoScalar); + }); +}); +// Tests for scalarFromProto +describe("scalarFromProto", () => { + it("should convert proto scalar to local format", () => { + const localScalar = (0, serde_1.scalarFromProto)(mockProtoScalar); + expect(localScalar).toEqual(mockLocalScalar); + }); +}); +// Tests for metricsToProto +describe("metricsToProto", () => { + it("should convert metrics to proto format", () => { + const localMetrics = { accuracy: 0.95 }; + const expectedProtoMetrics = { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }; + const protoMetrics = (0, serde_1.metricsToProto)(localMetrics); + expect(protoMetrics).toEqual(expectedProtoMetrics); + }); +}); +// Tests for metricsFromProto +describe("metricsFromProto", () => { + it("should convert proto metrics to local format", () => { + const protoMetrics = { + accuracy: { scalar: { oneofKind: "double", double: 0.95 } }, + }; + const expectedLocalMetrics = { accuracy: 0.95 }; + const localMetrics = (0, serde_1.metricsFromProto)(protoMetrics); + expect(localMetrics).toEqual(expectedLocalMetrics); + }); +}); +// Tests for parameterResToProto +describe("parameterResToProto", () => { + it("should convert GetParametersRes to proto format", () => { + const res = { parameters: mockLocalParams, status: mockStatus }; + const protoRes = (0, serde_1.parameterResToProto)(res); + expect(protoRes.parameters).toEqual((0, serde_1.parametersToProto)(res.parameters)); + expect(protoRes.status).toEqual((0, serde_1.statusToProto)(res.status)); + }); +}); +// Tests for fitInsFromProto +describe("fitInsFromProto", () => { + it("should convert proto FitIns to local format", () => { + const protoFitIns = { + parameters: mockProtoParams, + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localFitIns = (0, serde_1.fitInsFromProto)(protoFitIns); + expect(localFitIns.parameters).toEqual(mockLocalParams); + expect(localFitIns.config).toEqual({ accuracy: 0.95 }); + }); +}); +// Tests for fitResToProto +describe("fitResToProto", () => { + it("should convert FitRes to proto format", () => { + const localFitRes = { + parameters: mockLocalParams, + numExamples: 100, + metrics: { accuracy: 0.95 }, + status: mockStatus, + }; + const protoFitRes = (0, serde_1.fitResToProto)(localFitRes); + expect(protoFitRes.parameters).toEqual((0, serde_1.parametersToProto)(localFitRes.parameters)); + expect(protoFitRes.metrics).toEqual((0, serde_1.metricsToProto)(localFitRes.metrics)); + expect(protoFitRes.status).toEqual((0, serde_1.statusToProto)(localFitRes.status)); + }); +}); +// Tests for evaluateInsFromProto +describe("evaluateInsFromProto", () => { + it("should convert proto EvaluateIns to local format", () => { + const protoEvaluateIns = { + parameters: mockProtoParams, + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localEvaluateIns = (0, serde_1.evaluateInsFromProto)(protoEvaluateIns); + expect(localEvaluateIns.parameters).toEqual(mockLocalParams); + expect(localEvaluateIns.config).toEqual({ accuracy: 0.95 }); + }); +}); +// Tests for evaluateResToProto +describe("evaluateResToProto", () => { + it("should convert EvaluateRes to proto format", () => { + const localEvaluateRes = { + loss: 0.05, + numExamples: 100, + metrics: { accuracy: 0.95 }, + status: mockStatus, + }; + const protoEvaluateRes = (0, serde_1.evaluateResToProto)(localEvaluateRes); + expect(protoEvaluateRes.loss).toEqual(localEvaluateRes.loss); + expect(protoEvaluateRes.metrics).toEqual((0, serde_1.metricsToProto)(localEvaluateRes.metrics)); + expect(protoEvaluateRes.status).toEqual((0, serde_1.statusToProto)(localEvaluateRes.status)); + }); +}); +// Tests for statusToProto +describe("statusToProto", () => { + it("should convert local status to proto format", () => { + const protoStatus = (0, serde_1.statusToProto)(mockStatus); + expect(protoStatus).toEqual({ code: typing_1.Code.OK, message: "OK" }); + }); +}); +// Tests for getPropertiesInsFromProto +describe("getPropertiesInsFromProto", () => { + it("should convert proto GetPropertiesIns to local format", () => { + const protoGetPropertiesIns = { + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localGetPropertiesIns = (0, serde_1.getPropertiesInsFromProto)(protoGetPropertiesIns); + expect(localGetPropertiesIns.config).toEqual({ accuracy: 0.95 }); + }); +}); +// Tests for getPropertiesResToProto +describe("getPropertiesResToProto", () => { + it("should convert GetPropertiesRes to proto format", () => { + const localGetPropertiesRes = { properties: { accuracy: 0.95 }, status: mockStatus }; + const protoGetPropertiesRes = (0, serde_1.getPropertiesResToProto)(localGetPropertiesRes); + expect(protoGetPropertiesRes.properties).toEqual((0, serde_1.propertiesToProto)(localGetPropertiesRes.properties)); + expect(protoGetPropertiesRes.status).toEqual((0, serde_1.statusToProto)(localGetPropertiesRes.status)); + }); +}); +// Tests for recordSetToProto and recordSetFromProto +describe("recordSetToProto and recordSetFromProto", () => { + it("should convert local record set to proto and back", () => { + const localRecordSet = new recordset_1.RecordSet({ + parametersRecord1: new recordset_1.ParametersRecord({ + tensor1: new recordset_1.ArrayData("float32", [1, 2], "NDArray", new Uint8Array([1, 2])), + }), + }, {}, {}); + const protoRecordSet = (0, serde_1.recordSetToProto)(localRecordSet); + const recoveredRecordSet = (0, serde_1.recordSetFromProto)(protoRecordSet); + expect(recoveredRecordSet).toEqual(localRecordSet); + }); +}); +// Tests for messageFromTaskIns and messageToTaskRes +describe("messageFromTaskIns and messageToTaskRes", () => { + it("should convert taskIns to message and back to taskRes", () => { + const mockTaskIns = { + runId: BigInt(1), + taskId: "task1", + groupId: "group1", + task: { + consumer: { nodeId: BigInt(1), anonymous: false }, + producer: { nodeId: BigInt(2), anonymous: false }, + taskType: "train", + ttl: 10, + }, + }; + const message = (0, serde_1.messageFromTaskIns)(mockTaskIns); + const taskRes = (0, serde_1.messageToTaskRes)(message); + expect(taskRes.task?.taskType).toEqual(mockTaskIns.task?.taskType); + }); +}); +// Tests for userConfigFromProto and userConfigValueToProto +describe("userConfigFromProto and userConfigValueToProto", () => { + it("should convert user config from proto and back", () => { + const protoConfig = { key1: { scalar: { oneofKind: "double", double: 0.95 } } }; + const localConfig = { key1: 0.95 }; + const recoveredConfig = (0, serde_1.userConfigFromProto)(protoConfig); + expect(recoveredConfig).toEqual(localConfig); + }); +}); diff --git a/src/ts/src/lib/serde.test.ts b/src/ts/src/lib/serde.test.ts new file mode 100644 index 000000000000..8c7b8853683d --- /dev/null +++ b/src/ts/src/lib/serde.test.ts @@ -0,0 +1,260 @@ +import { + parametersToProto, + parametersFromProto, + scalarToProto, + scalarFromProto, + metricsToProto, + metricsFromProto, + parameterResToProto, + fitInsFromProto, + fitResToProto, + evaluateInsFromProto, + evaluateResToProto, + statusToProto, + getPropertiesInsFromProto, + getPropertiesResToProto, + propertiesToProto, + recordSetToProto, + recordSetFromProto, + messageFromTaskIns, + messageToTaskRes, + userConfigFromProto, +} from "./serde"; +import { + Scalar as ProtoScalar, + Status as ProtoStatus, + ServerMessage_EvaluateIns, + ServerMessage_FitIns, + ServerMessage_GetPropertiesIns, +} from "../protos/flwr/proto/transport"; +import { Code as LocalCode, FitRes, EvaluateRes } from "./typing"; +import { RecordSet, ParametersRecord, ArrayData } from "./recordset"; +import { Task, TaskIns } from "../protos/flwr/proto/task"; + +let bytes = new Uint8Array(8); +bytes[0] = 256; + +// Mock Protobuf messages and local types for testing +const mockProtoParams = { + tensors: [bytes, bytes], + tensorType: "float32", +}; + +const mockLocalParams = { + tensors: [bytes, bytes], + tensorType: "float32", +}; + +const mockProtoScalar = { + scalar: { oneofKind: "double", double: 1.23 }, +}; + +const mockLocalScalar = 1.23; + +const mockStatus = { + code: LocalCode.OK, + message: "OK", +}; + +// Tests for parametersToProto +describe("parametersToProto", () => { + it("should convert local parameters to proto format", () => { + const protoParams = parametersToProto(mockLocalParams); + expect(protoParams).toEqual(mockProtoParams); + }); +}); + +// Tests for parametersFromProto +describe("parametersFromProto", () => { + it("should convert proto parameters to local format", () => { + const localParams = parametersFromProto(mockProtoParams); + expect(localParams).toEqual(mockLocalParams); + }); +}); + +// Tests for scalarToProto +describe("scalarToProto", () => { + it("should convert local scalar to proto format", () => { + const protoScalar = scalarToProto(mockLocalScalar); + expect(protoScalar).toEqual(mockProtoScalar); + }); +}); + +// Tests for scalarFromProto +describe("scalarFromProto", () => { + it("should convert proto scalar to local format", () => { + const localScalar = scalarFromProto(mockProtoScalar as ProtoScalar); + expect(localScalar).toEqual(mockLocalScalar); + }); +}); + +// Tests for metricsToProto +describe("metricsToProto", () => { + it("should convert metrics to proto format", () => { + const localMetrics = { accuracy: 0.95 }; + const expectedProtoMetrics = { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }; + const protoMetrics = metricsToProto(localMetrics); + expect(protoMetrics).toEqual(expectedProtoMetrics); + }); +}); + +// Tests for metricsFromProto +describe("metricsFromProto", () => { + it("should convert proto metrics to local format", () => { + const protoMetrics: Record = { + accuracy: { scalar: { oneofKind: "double", double: 0.95 } }, + }; + const expectedLocalMetrics = { accuracy: 0.95 }; + const localMetrics = metricsFromProto(protoMetrics); + expect(localMetrics).toEqual(expectedLocalMetrics); + }); +}); + +// Tests for parameterResToProto +describe("parameterResToProto", () => { + it("should convert GetParametersRes to proto format", () => { + const res = { parameters: mockLocalParams, status: mockStatus }; + const protoRes = parameterResToProto(res); + expect(protoRes.parameters).toEqual(parametersToProto(res.parameters)); + expect(protoRes.status).toEqual(statusToProto(res.status)); + }); +}); + +// Tests for fitInsFromProto +describe("fitInsFromProto", () => { + it("should convert proto FitIns to local format", () => { + const protoFitIns: ServerMessage_FitIns = { + parameters: mockProtoParams, + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localFitIns = fitInsFromProto(protoFitIns); + expect(localFitIns.parameters).toEqual(mockLocalParams); + expect(localFitIns.config).toEqual({ accuracy: 0.95 }); + }); +}); + +// Tests for fitResToProto +describe("fitResToProto", () => { + it("should convert FitRes to proto format", () => { + const localFitRes: FitRes = { + parameters: mockLocalParams, + numExamples: 100, + metrics: { accuracy: 0.95 }, + status: mockStatus, + }; + const protoFitRes = fitResToProto(localFitRes); + expect(protoFitRes.parameters).toEqual(parametersToProto(localFitRes.parameters)); + expect(protoFitRes.metrics).toEqual(metricsToProto(localFitRes.metrics)); + expect(protoFitRes.status).toEqual(statusToProto(localFitRes.status)); + }); +}); + +// Tests for evaluateInsFromProto +describe("evaluateInsFromProto", () => { + it("should convert proto EvaluateIns to local format", () => { + const protoEvaluateIns: ServerMessage_EvaluateIns = { + parameters: mockProtoParams, + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localEvaluateIns = evaluateInsFromProto(protoEvaluateIns); + expect(localEvaluateIns.parameters).toEqual(mockLocalParams); + expect(localEvaluateIns.config).toEqual({ accuracy: 0.95 }); + }); +}); + +// Tests for evaluateResToProto +describe("evaluateResToProto", () => { + it("should convert EvaluateRes to proto format", () => { + const localEvaluateRes: EvaluateRes = { + loss: 0.05, + numExamples: 100, + metrics: { accuracy: 0.95 }, + status: mockStatus, + }; + const protoEvaluateRes = evaluateResToProto(localEvaluateRes); + expect(protoEvaluateRes.loss).toEqual(localEvaluateRes.loss); + expect(protoEvaluateRes.metrics).toEqual(metricsToProto(localEvaluateRes.metrics)); + expect(protoEvaluateRes.status).toEqual(statusToProto(localEvaluateRes.status)); + }); +}); + +// Tests for statusToProto +describe("statusToProto", () => { + it("should convert local status to proto format", () => { + const protoStatus = statusToProto(mockStatus); + expect(protoStatus).toEqual({ code: LocalCode.OK, message: "OK" } as ProtoStatus); + }); +}); + +// Tests for getPropertiesInsFromProto +describe("getPropertiesInsFromProto", () => { + it("should convert proto GetPropertiesIns to local format", () => { + const protoGetPropertiesIns: ServerMessage_GetPropertiesIns = { + config: { accuracy: { scalar: { oneofKind: "double", double: 0.95 } } }, + }; + const localGetPropertiesIns = getPropertiesInsFromProto(protoGetPropertiesIns); + expect(localGetPropertiesIns.config).toEqual({ accuracy: 0.95 }); + }); +}); + +// Tests for getPropertiesResToProto +describe("getPropertiesResToProto", () => { + it("should convert GetPropertiesRes to proto format", () => { + const localGetPropertiesRes = { properties: { accuracy: 0.95 }, status: mockStatus }; + const protoGetPropertiesRes = getPropertiesResToProto(localGetPropertiesRes); + expect(protoGetPropertiesRes.properties).toEqual( + propertiesToProto(localGetPropertiesRes.properties), + ); + expect(protoGetPropertiesRes.status).toEqual(statusToProto(localGetPropertiesRes.status)); + }); +}); + +// Tests for recordSetToProto and recordSetFromProto +describe("recordSetToProto and recordSetFromProto", () => { + it("should convert local record set to proto and back", () => { + const localRecordSet = new RecordSet( + { + parametersRecord1: new ParametersRecord({ + tensor1: new ArrayData("float32", [1, 2], "NDArray", new Uint8Array([1, 2])), + }), + }, + {}, + {}, + ); + + const protoRecordSet = recordSetToProto(localRecordSet); + const recoveredRecordSet = recordSetFromProto(protoRecordSet); + + expect(recoveredRecordSet).toEqual(localRecordSet); + }); +}); + +// Tests for messageFromTaskIns and messageToTaskRes +describe("messageFromTaskIns and messageToTaskRes", () => { + it("should convert taskIns to message and back to taskRes", () => { + const mockTaskIns: TaskIns = { + runId: BigInt(1), + taskId: "task1", + groupId: "group1", + task: { + consumer: { nodeId: BigInt(1), anonymous: false }, + producer: { nodeId: BigInt(2), anonymous: false }, + taskType: "train", + ttl: 10, + } as Task, + }; + const message = messageFromTaskIns(mockTaskIns); + const taskRes = messageToTaskRes(message); + expect(taskRes.task?.taskType).toEqual(mockTaskIns.task?.taskType); + }); +}); + +// Tests for userConfigFromProto and userConfigValueToProto +describe("userConfigFromProto and userConfigValueToProto", () => { + it("should convert user config from proto and back", () => { + const protoConfig = { key1: { scalar: { oneofKind: "double", double: 0.95 } } }; + const localConfig = { key1: 0.95 }; + const recoveredConfig = userConfigFromProto(protoConfig); + expect(recoveredConfig).toEqual(localConfig); + }); +}); diff --git a/src/ts/src/lib/serde.ts b/src/ts/src/lib/serde.ts new file mode 100644 index 000000000000..0657696e3fc6 --- /dev/null +++ b/src/ts/src/lib/serde.ts @@ -0,0 +1,464 @@ +import { + Parameters as ProtoParams, + Scalar as ProtoScalar, + Status as ProtoStatus, + ClientMessage_GetPropertiesRes as ProtoClientMessage_GetPropertiesRes, + ClientMessage_GetParametersRes as ProtoClientMessage_GetParametersRes, + ClientMessage_FitRes as ProtoClientMessage_FitRes, + ClientMessage_EvaluateRes as ProtoClientMessage_EvaluateRes, + ServerMessage_GetPropertiesIns as ProtoServerMessage_GetPropertiesIns, + ServerMessage_FitIns as ProtoServerMessage_FitIns, + ServerMessage_EvaluateIns as ProtoServerMessage_EvaluateIns, + Scalar, +} from "../protos/flwr/proto/transport"; +import { Error as ProtoError } from "../protos/flwr/proto/error"; +import { Task, TaskIns, TaskRes } from "../protos/flwr/proto/task"; +import { Node } from "../protos/flwr/proto/node"; +import { + Parameters as LocalParams, + Scalar as LocalScalar, + Status as LocalStatus, + GetParametersRes as LocalGetParametersRes, + FitIns as LocalFitIns, + FitRes as LocalFitRes, + EvaluateIns as LocalEvaluateIns, + EvaluateRes as LocalEvaluateRes, + GetPropertiesIns as LocalGetPropertiesIns, + GetPropertiesRes as LocalGetPropertiesRes, + Properties as LocalProperties, + Message, + Metadata, + Error as LocalError, + UserConfigValue, + UserConfig, +} from "./typing"; +import { + RecordSet, + ConfigsRecord, + MetricsRecord, + MetricsRecordValue, + ParametersRecord, + ArrayData, +} from "./recordset"; +import { + RecordSet as ProtoRecordSet, + ConfigsRecord as ProtoConfigsRecord, + ConfigsRecordValue as ProtoConfigsRecordValue, + MetricsRecord as ProtoMetricsRecord, + MetricsRecordValue as ProtoMetricsRecordValue, + ParametersRecord as ProtoParametersRecord, + Array$ as ProtoArray, +} from "../protos/flwr/proto/recordset"; + +// Parameter conversions +export const parametersToProto = (params: LocalParams): ProtoParams => { + return { tensors: params.tensors, tensorType: params.tensorType } as ProtoParams; +}; + +export const parametersFromProto = (protoParams: ProtoParams): LocalParams => { + return { tensors: protoParams.tensors, tensorType: protoParams.tensorType } as LocalParams; +}; + +// Scalar conversions +export const scalarToProto = (scalar: LocalScalar): ProtoScalar => { + if (typeof scalar === "string") { + return { scalar: { oneofKind: "string", string: scalar } } as ProtoScalar; + } else if (typeof scalar === "boolean") { + return { scalar: { oneofKind: "bool", bool: scalar } } as ProtoScalar; + } else if (typeof scalar === "bigint") { + return { scalar: { oneofKind: "sint64", sint64: scalar } } as ProtoScalar; + } else if (typeof scalar === "number") { + return { scalar: { oneofKind: "double", double: scalar } } as ProtoScalar; + } else if (scalar instanceof Uint8Array) { + return { scalar: { oneofKind: "bytes", bytes: scalar } } as ProtoScalar; + } + throw new Error("Unsupported scalar type"); +}; + +export const scalarFromProto = (protoScalar: ProtoScalar): LocalScalar => { + switch (protoScalar.scalar?.oneofKind) { + case "double": + return protoScalar.scalar.double as number; + case "sint64": + return protoScalar.scalar.sint64 as bigint; + case "bool": + return protoScalar.scalar.bool as boolean; + case "string": + return protoScalar.scalar.string as string; + case "bytes": + return protoScalar.scalar.bytes as Uint8Array; + default: + throw new Error("Unknown scalar type"); + } +}; + +// Metrics conversions +export const metricsToProto = ( + metrics: Record, +): Record => { + const protoMetrics: Record = {}; + for (const key in metrics) { + protoMetrics[key] = scalarToProto(metrics[key]); + } + return protoMetrics; +}; + +export const metricsFromProto = ( + protoMetrics: Record, +): Record => { + const metrics: Record = {}; + for (const key in protoMetrics) { + metrics[key] = scalarFromProto(protoMetrics[key]); + } + return metrics; +}; + +// GetParametersRes conversions +export const parameterResToProto = ( + res: LocalGetParametersRes, +): ProtoClientMessage_GetParametersRes => { + return { + parameters: parametersToProto(res.parameters), + status: statusToProto(res.status), + }; +}; + +// FitIns conversions +export const fitInsFromProto = (fitInsMsg: ProtoServerMessage_FitIns): LocalFitIns => { + return { + parameters: parametersFromProto(fitInsMsg.parameters!), + config: metricsFromProto(fitInsMsg.config), + }; +}; + +// FitRes conversions +export const fitResToProto = (res: LocalFitRes): ProtoClientMessage_FitRes => { + return { + parameters: parametersToProto(res.parameters), + numExamples: BigInt(res.numExamples), + metrics: Object.keys(res.metrics).length > 0 ? metricsToProto(res.metrics) : {}, + status: statusToProto(res.status), + }; +}; + +// EvaluateIns conversions +export const evaluateInsFromProto = ( + evaluateInsMsg: ProtoServerMessage_EvaluateIns, +): LocalEvaluateIns => { + return { + parameters: parametersFromProto(evaluateInsMsg.parameters!), + config: metricsFromProto(evaluateInsMsg.config), + }; +}; + +// EvaluateRes conversions +export const evaluateResToProto = (res: LocalEvaluateRes): ProtoClientMessage_EvaluateRes => { + return { + loss: res.loss, + numExamples: BigInt(res.numExamples), + metrics: Object.keys(res.metrics).length > 0 ? metricsToProto(res.metrics) : {}, + status: statusToProto(res.status), + }; +}; + +// Status conversions +export const statusToProto = (status: LocalStatus): ProtoStatus => { + return { + code: status.code, + message: status.message, + }; +}; + +// GetPropertiesIns conversions +export const getPropertiesInsFromProto = ( + getPropertiesMsg: ProtoServerMessage_GetPropertiesIns, +): LocalGetPropertiesIns => { + return { + config: propertiesFromProto(getPropertiesMsg.config), + }; +}; + +// GetPropertiesRes conversions +export const getPropertiesResToProto = ( + res: LocalGetPropertiesRes, +): ProtoClientMessage_GetPropertiesRes => { + return { + properties: propertiesToProto(res.properties), + status: statusToProto(res.status), + }; +}; + +// Properties conversions +export const propertiesFromProto = ( + protoProperties: Record, +): LocalProperties => { + const properties: LocalProperties = {}; + for (const key in protoProperties) { + properties[key] = scalarFromProto(protoProperties[key]); + } + return properties; +}; + +export const propertiesToProto = (properties: LocalProperties): Record => { + const protoProperties: Record = {}; + for (const key in properties) { + protoProperties[key] = scalarToProto(properties[key]); + } + return protoProperties; +}; + +function recordValueToProto(value: any): ProtoMetricsRecordValue | ProtoConfigsRecordValue { + if (typeof value === "number") { + return { value: { oneofKind: "double", double: value } }; + } else if (typeof value === "bigint") { + return { value: { oneofKind: "sint64", sint64: value } }; + } else if (typeof value === "boolean") { + return { value: { oneofKind: "bool", bool: value } }; + } else if (typeof value === "string") { + return { value: { oneofKind: "string", string: value } }; + } else if (value instanceof Uint8Array) { + return { value: { oneofKind: "bytes", bytes: value } }; + } else if (Array.isArray(value)) { + if (typeof value[0] === "number") { + return { value: { oneofKind: "doubleList", doubleList: { vals: value } } }; + } else if (typeof value[0] === "bigint") { + return { value: { oneofKind: "sintList", sintList: { vals: value } } }; + } else if (typeof value[0] === "boolean") { + return { value: { oneofKind: "boolList", boolList: { vals: value } } }; + } else if (typeof value[0] === "string") { + return { value: { oneofKind: "stringList", stringList: { vals: value } } }; + } else if (value[0] instanceof Uint8Array) { + return { value: { oneofKind: "bytesList", bytesList: { vals: value } } }; + } + } + throw new TypeError("Unsupported value type"); +} + +// Helper for converting Protobuf messages back into values +function recordValueFromProto(proto: ProtoMetricsRecordValue | ProtoConfigsRecordValue): any { + switch (proto.value.oneofKind) { + case "double": + return proto.value.double; + case "sint64": + return proto.value.sint64; + case "bool": + return proto.value.bool; + case "string": + return proto.value.string; + case "bytes": + return proto.value.bytes; + case "doubleList": + return proto.value.doubleList.vals; + case "sintList": + return proto.value.sintList.vals; + case "boolList": + return proto.value.boolList.vals; + case "stringList": + return proto.value.stringList.vals; + case "bytesList": + return proto.value.bytesList.vals; + default: + throw new Error("Unknown value kind"); + } +} + +function arrayToProto(array: ArrayData): ProtoArray { + return { + dtype: array.dtype, + shape: array.shape, + stype: array.stype, + data: array.data, + }; +} + +function arrayFromProto(proto: ProtoArray): ArrayData { + return new ArrayData(proto.dtype, proto.shape, proto.stype, proto.data); +} + +function parametersRecordToProto(record: ParametersRecord): ProtoParametersRecord { + return { + dataKeys: Object.keys(record), + dataValues: Object.values(record).map(arrayToProto), + }; +} + +function parametersRecordFromProto(proto: ProtoParametersRecord): ParametersRecord { + const arrayDict = Object.fromEntries( + proto.dataKeys.map((k, i) => [k, arrayFromProto(proto.dataValues[i])]), + ); + + // Create a new instance of ParametersRecord and populate it with the arrayDict + return new ParametersRecord(arrayDict); +} + +function metricsRecordToProto(record: MetricsRecord): ProtoMetricsRecord { + const data = Object.fromEntries( + Object.entries(record).map(([k, v]) => [k, recordValueToProto(v) as ProtoMetricsRecordValue]), + ); + return { data }; +} + +function metricsRecordFromProto(proto: ProtoMetricsRecord): MetricsRecord { + const metrics = Object.fromEntries( + Object.entries(proto.data).map(([k, v]) => [k, recordValueFromProto(v) as MetricsRecordValue]), + ); + return new MetricsRecord(metrics); +} + +function configsRecordToProto(record: ConfigsRecord): ProtoConfigsRecord { + const data = Object.fromEntries( + Object.entries(record).map(([k, v]) => [k, recordValueToProto(v) as ProtoConfigsRecordValue]), + ); + return { data }; +} + +function configsRecordFromProto(proto: ProtoConfigsRecord): ConfigsRecord { + const config = Object.fromEntries( + Object.entries(proto.data).map(([k, v]) => [k, recordValueFromProto(v)]), + ); + return new ConfigsRecord(config); +} + +export function recordSetToProto(recordset: RecordSet): ProtoRecordSet { + const parameters = Object.fromEntries( + Object.entries(recordset.parametersRecords).map(([k, v]) => [ + k, + parametersRecordToProto(v), // Nested dictionary (string -> Record) + ]), + ); + const metrics = Object.fromEntries( + Object.entries(recordset.metricsRecords).map(([k, v]) => [k, metricsRecordToProto(v)]), + ); + const configs = Object.fromEntries( + Object.entries(recordset.configsRecords).map(([k, v]) => [k, configsRecordToProto(v)]), + ); + return { parameters, metrics, configs }; +} + +export function recordSetFromProto(proto: ProtoRecordSet): RecordSet { + const parametersRecords = Object.fromEntries( + Object.entries(proto.parameters).map(([k, v]) => [k, parametersRecordFromProto(v)]), + ); + const metricsRecords = Object.fromEntries( + Object.entries(proto.metrics).map(([k, v]) => [k, metricsRecordFromProto(v)]), + ); + const configsRecords = Object.fromEntries( + Object.entries(proto.configs).map(([k, v]) => [k, configsRecordFromProto(v)]), + ); + return new RecordSet(parametersRecords, metricsRecords, configsRecords); +} + +export const messageFromTaskIns = (taskIns: TaskIns): Message => { + let metadata = { + runId: taskIns.runId, + messageId: taskIns.taskId, + srcNodeId: taskIns.task?.producer?.nodeId, + dstNodeId: taskIns.task?.consumer?.nodeId, + replyToMessage: taskIns.task?.ancestry ? taskIns.task?.ancestry[0] : "", + groupId: taskIns.groupId, + ttl: taskIns.task?.ttl, + messageType: taskIns.task?.taskType, + } as Metadata; + + let message = new Message( + metadata, + taskIns.task?.recordset ? recordSetFromProto(taskIns.task.recordset) : null, + taskIns.task?.error ? ({ code: Number(taskIns.task.error.code), reason: taskIns.task.error.reason } as LocalError) : null, + ); + + if (taskIns.task?.createdAt) { + message.metadata.createdAt = taskIns.task?.createdAt; + } + return message; +}; + +export const messageToTaskRes = (message: Message): TaskRes => { + const md = message.metadata; + const taskRes = TaskRes.create(); + taskRes.taskId = "", + taskRes.groupId = md.groupId; + taskRes.runId = md.runId; + + let task = Task.create(); + + let producer = Node.create(); + producer.nodeId = md.srcNodeId; + producer.anonymous = false; + task.producer = producer; + + let consumer = Node.create(); + consumer.nodeId = BigInt(0); + consumer.anonymous = true; + task.consumer = consumer; + + task.createdAt = md.createdAt; + task.ttl = md.ttl; + task.ancestry = md.replyToMessage !== "" ? [md.replyToMessage] : []; + task.taskType = md.messageType; + task.recordset = message.content === null ? undefined : recordSetToProto(message.content); + task.error = message.error === null ? undefined : ({ code: BigInt(message.error.code), reason: message.error.reason } as ProtoError); + + taskRes.task = task; + return taskRes; + + + // return { + // taskId: "", + // groupId: md.groupId, + // runId: md.runId, + // task: { + // producer: { nodeId: md.srcNodeId, anonymous: false } as Node, + // consumer: { nodeId: BigInt(0), anonymous: true } as Node, + // createdAt: md.createdAt, + // ttl: md.ttl, + // ancestry: md.replyToMessage ? [md.replyToMessage] : [], + // taskType: md.messageType, + // recordset: message.content ? recordSetToProto(message.content) : null, + // error: message.error ? ({ code: BigInt(message.error.code), reason: message.error.reason } as ProtoError) : null, + // } as Task, + // } as TaskRes; +}; + +export const userConfigFromProto = (proto: Record): UserConfig => { + let metrics: UserConfig = {}; + + Object.entries(proto).forEach(([key, value]: [string, Scalar]) => { + metrics[key] = userConfigValueFromProto(value); + }); + + return metrics; +}; + +export const userConfigValueToProto = (userConfigValue: UserConfigValue): Scalar => { + switch (typeof userConfigValue) { + case "string": + return { scalar: { oneofKind: "string", string: userConfigValue } } as Scalar; + case "number": + return { scalar: { oneofKind: "double", double: userConfigValue } } as Scalar; + case "bigint": + return { scalar: { oneofKind: "sint64", sint64: userConfigValue } } as Scalar; + case "boolean": + return { scalar: { oneofKind: "bool", bool: userConfigValue } } as Scalar; + default: + throw new Error( + `Accepted types: {bool, float, int, str} (but not ${typeof userConfigValue})`, + ); + } +}; + +export const userConfigValueFromProto = (scalarMsg: Scalar): UserConfigValue => { + switch (scalarMsg.scalar.oneofKind) { + case "string": + return scalarMsg.scalar.string as UserConfigValue; + case "bool": + return scalarMsg.scalar.bool as UserConfigValue; + case "sint64": + return scalarMsg.scalar.sint64 as UserConfigValue; + case "double": + return scalarMsg.scalar.double as UserConfigValue; + default: + throw new Error( + `Accepted types: {bool, float, int, str} (but not ${scalarMsg.scalar.oneofKind})`, + ); + } +}; diff --git a/src/ts/src/lib/start.js b/src/ts/src/lib/start.js new file mode 100644 index 000000000000..ee41a7ff2063 --- /dev/null +++ b/src/ts/src/lib/start.js @@ -0,0 +1,174 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.startClientInternal = startClientInternal; +const connection_1 = require("./connection"); +const message_handler_1 = require("./message_handler"); +const retry_invoker_1 = require("./retry_invoker"); +const grpc_1 = require("./grpc"); +const typing_1 = require("./typing"); +const client_app_1 = require("./client_app"); +const address_1 = require("./address"); +const node_state_1 = require("./node_state"); +class StopIteration extends Error { +} +; +class AppStateTracker { + interrupt = false; + isConnected = false; + constructor() { + this.registerSignalHandler(); + } + // Register handlers for exit signals (SIGINT and SIGTERM) + registerSignalHandler() { + const signalHandler = (signal) => { + console.log(`Received ${signal}. Exiting...`); + this.interrupt = true; + throw new StopIteration(); + }; + // Listen for SIGINT and SIGTERM signals + process.on("SIGINT", () => signalHandler("SIGINT")); + process.on("SIGTERM", () => signalHandler("SIGTERM")); + } +} +async function startClientInternal(serverAddress, nodeConfig, grpcMaxMessageLength = grpc_1.GRPC_MAX_MESSAGE_LENGTH, loadClientAppFn = null, clientFn = null, insecure = null, maxRetries = null, maxWaitTime = null, client = null, rootCertificates = null, flwrPath = null) { + if (insecure === null) { + insecure = rootCertificates === null; + } + if (loadClientAppFn === null) { + if (clientFn === null && client === null) { + throw Error("Both \`client_fn\` and \`client\` are \`None\`, but one is required"); + } + if (clientFn !== null && client !== null) { + throw Error("Both \`client_fn\` and \`client\` are provided, but only one is allowed"); + } + } + if (clientFn === null) { + function singleClientFactory(_context) { + if (client === null) { + throw Error("Both \`client_fn\` and \`client\` are \`None\`, but one is required"); + } + return client; + } + clientFn = singleClientFactory; + } + function _loadClientApp(_fabId, _fabVersion) { + return new client_app_1.ClientApp(clientFn); + } + loadClientAppFn = _loadClientApp; + let appStateTracker = new AppStateTracker(); + function onSuccess(retryState) { + appStateTracker.isConnected = true; + if (retryState.tries > 1) { + console.log(`Connection successful after ${retryState.elapsedTime} seconds and ${retryState.tries} tries.`); + } + } + function onBackoff(retryState) { + appStateTracker.isConnected = false; + if (retryState.tries === 1) { + console.warn("Connection attempt failed, retrying..."); + } + else { + console.warn(`Connection attempt failed, retrying in ${retryState.actualWait} seconds`); + } + } + function onGiveup(retryState) { + if (retryState.tries > 1) { + console.warn(`Giving up reconnection after ${retryState.elapsedTime} seconds and ${retryState.tries} tries.`); + } + } + const retryInvoker = new retry_invoker_1.RetryInvoker(retry_invoker_1.exponential, Error, maxRetries ? maxRetries + 1 : null, maxWaitTime, { + onSuccess, + onBackoff, + onGiveup, + }); + const parsedAdress = (0, address_1.parseAddress)(serverAddress); + if (parsedAdress === null) { + process.exit(`Server address ${serverAddress} cannot be parsed.`); + } + const address = parsedAdress.version ? `[${parsedAdress.host}]:${parsedAdress.port}` : `${parsedAdress.host}:${parsedAdress.port}`; + let nodeState = null; + let runs = {}; + while (!appStateTracker.interrupt) { + let sleepDuration = 0; + const [receive, send, createNode, deleteNode, getRun, _getFab] = await (0, connection_1.grpcRequestResponse)(address, rootCertificates === null, retryInvoker, grpcMaxMessageLength, rootCertificates ? rootCertificates : undefined, null); + if (nodeState === null) { + const nodeId = await createNode(); + if (nodeId === null) { + throw new Error("Node registration failed"); + } + nodeState = new node_state_1.NodeState(nodeId, nodeConfig); + } + appStateTracker.registerSignalHandler(); + while (!appStateTracker.interrupt) { + try { + const message = await receive(); + if (message === null) { + console.log("Pulling..."); + await (0, retry_invoker_1.sleep)(3); + continue; + } + console.log(""); + if (message.metadata.groupId.length > 0) { + console.log(`[RUN ${message.metadata.runId}, ROUND ${message.metadata.groupId}]`); + } + console.log(`Received: ${message.metadata.messageType} message ${message.metadata.messageId}]`); + let outMessage; + [outMessage, sleepDuration] = (0, message_handler_1.handleControlMessage)(message); + if (outMessage) { + await send(outMessage); + break; + } + const runId = message.metadata.runId; + const nRunId = Number(runId); + if (!(nRunId in runs)) { + runs[nRunId] = await getRun(runId); + } + const run = runs[nRunId]; + const fab = null; + nodeState.registerContext(nRunId, run, flwrPath, fab); + let context = nodeState.retrieveContext(nRunId); + let replyMessage = message.createErrorReply({ code: typing_1.ErrorCode.UNKNOWN, reason: "Unknown" }); + try { + const clientApp = loadClientAppFn(run.fabId, run.fabVersion); + replyMessage = clientApp.call(message, context); + } + catch (err) { + let errorCode = typing_1.ErrorCode.CLIENT_APP_RAISED_EXCEPTION; + let reason = `${typeof err}:<'${err}'>`; + let excEntity = "ClientApp"; + if (err instanceof client_app_1.LoadClientAppError) { + reason = "An exception was raised when attempting to load `ClientApp`"; + err = typing_1.ErrorCode.LOAD_CLIENT_APP_EXCEPTION; + excEntity = "SuperNode"; + } + if (!appStateTracker.interrupt) { + // TODO Add excInfo=err + console.error(`${excEntity} raised an exception`); + } + replyMessage = message.createErrorReply({ code: errorCode, reason }); + } + context.runConfig = {}; + nodeState.updateContext(nRunId, context); + await send(replyMessage); + console.log("Sent reply"); + } + catch (err) { + if (err instanceof StopIteration) { + sleepDuration = 0; + break; + } + else { + console.log(err); + await (0, retry_invoker_1.sleep)(3); + } + } + } + await deleteNode(); + if (sleepDuration === 0) { + console.log("Disconnect and shut down"); + break; + } + console.log(`Disconnect, then re-establish connection after ${sleepDuration} second(s)`); + await (0, retry_invoker_1.sleep)(sleepDuration); + } +} diff --git a/src/ts/src/lib/start.ts b/src/ts/src/lib/start.ts new file mode 100644 index 000000000000..c86b1e4e8378 --- /dev/null +++ b/src/ts/src/lib/start.ts @@ -0,0 +1,222 @@ +import { grpcRequestResponse } from "./connection"; +import { Client } from "./client"; +import { handleControlMessage } from "./message_handler"; +import { RetryInvoker, RetryState, exponential, sleep } from "./retry_invoker"; +import { GRPC_MAX_MESSAGE_LENGTH } from "./grpc"; +import { + Context, + ClientFnExt, + UserConfig, + Run, + Message, + Error as LocalError, + ErrorCode, +} from "./typing"; +import { ClientApp, LoadClientAppError } from "./client_app"; +import { parseAddress } from "./address"; +import { NodeState } from "./node_state"; +import { PathLike } from "fs"; + + +class StopIteration extends Error { }; + +class AppStateTracker { + public interrupt: boolean = false; + public isConnected: boolean = false; + + constructor() { + this.registerSignalHandler(); + } + + // Register handlers for exit signals (SIGINT and SIGTERM) + registerSignalHandler(): void { + const signalHandler = (signal: string): void => { + console.log(`Received ${signal}. Exiting...`); + this.interrupt = true; + throw new StopIteration(); + }; + + // Listen for SIGINT and SIGTERM signals + process.on("SIGINT", () => signalHandler("SIGINT")); + process.on("SIGTERM", () => signalHandler("SIGTERM")); + } +} + +export async function startClientInternal( + serverAddress: string, + nodeConfig: UserConfig, + grpcMaxMessageLength: number = GRPC_MAX_MESSAGE_LENGTH, + loadClientAppFn: ((_1: string, _2: string) => ClientApp) | null = null, + clientFn: ClientFnExt | null = null, + insecure: boolean | null = null, + maxRetries: number | null = null, + maxWaitTime: number | null = null, + client: Client | null = null, + rootCertificates: string | null = null, + flwrPath: PathLike | null = null, +): Promise { + + if (insecure === null) { + insecure = rootCertificates === null; + } + + if (loadClientAppFn === null) { + if (clientFn === null && client === null) { + throw Error("Both \`client_fn\` and \`client\` are \`None\`, but one is required") + } + if (clientFn !== null && client !== null) { + throw Error("Both \`client_fn\` and \`client\` are provided, but only one is allowed") + } + } + + if (clientFn === null) { + function singleClientFactory(_context: Context): Client { + if (client === null) { + throw Error("Both \`client_fn\` and \`client\` are \`None\`, but one is required"); + } + return client; + } + clientFn = singleClientFactory; + } + function _loadClientApp(_fabId: string, _fabVersion: string): ClientApp { + return new ClientApp(clientFn!); + } + loadClientAppFn = _loadClientApp; + + let appStateTracker = new AppStateTracker(); + + function onSuccess(retryState: RetryState): void { + appStateTracker.isConnected = true; + if (retryState.tries > 1) { + console.log(`Connection successful after ${retryState.elapsedTime} seconds and ${retryState.tries} tries.`) + } + } + + function onBackoff(retryState: RetryState): void { + appStateTracker.isConnected = false; + if (retryState.tries === 1) { + console.warn("Connection attempt failed, retrying...") + } else { + console.warn(`Connection attempt failed, retrying in ${retryState.actualWait} seconds`) + } + } + + function onGiveup(retryState: RetryState): void { + if (retryState.tries > 1) { + console.warn(`Giving up reconnection after ${retryState.elapsedTime} seconds and ${retryState.tries} tries.`) + } + } + + const retryInvoker = new RetryInvoker( + exponential, Error, maxRetries ? maxRetries + 1 : null, maxWaitTime, { + onSuccess, + onBackoff, + onGiveup, + } + ); + + const parsedAdress = parseAddress(serverAddress) + if (parsedAdress === null) { + process.exit(`Server address ${serverAddress} cannot be parsed.`); + } + const address = parsedAdress.version ? `[${parsedAdress.host}]:${parsedAdress.port}` : `${parsedAdress.host}:${parsedAdress.port}`; + + let nodeState: NodeState | null = null; + let runs: { [runId: number]: Run } = {}; + + + while (!appStateTracker.interrupt) { + let sleepDuration = 0; + const [receive, send, createNode, deleteNode, getRun, _getFab] = await grpcRequestResponse( + address, + rootCertificates === null, + retryInvoker, + grpcMaxMessageLength, + rootCertificates ? rootCertificates : undefined, + null, + ); + if (nodeState === null) { + const nodeId = await createNode(); + if (nodeId === null) { + throw new Error("Node registration failed"); + } + nodeState = new NodeState(nodeId, nodeConfig); + } + appStateTracker.registerSignalHandler(); + while (!appStateTracker.interrupt) { + try { + const message = await receive(); + if (message === null) { + console.log("Pulling...") + await sleep(3); + continue; + } + + console.log("") + if (message.metadata.groupId.length > 0) { + console.log(`[RUN ${message.metadata.runId}, ROUND ${message.metadata.groupId}]`) + } + console.log(`Received: ${message.metadata.messageType} message ${message.metadata.messageId}]`) + + let outMessage: Message | null; + [outMessage, sleepDuration] = handleControlMessage(message); + if (outMessage) { + await send(outMessage); + break; + } + + const runId = message.metadata.runId; + const nRunId = Number(runId); + if (!(nRunId in runs)) { + runs[nRunId] = await getRun(runId); + } + + const run = runs[nRunId]; + const fab = null; + + nodeState.registerContext(nRunId, run, flwrPath, fab); + let context = nodeState.retrieveContext(nRunId); + let replyMessage = message.createErrorReply({ code: ErrorCode.UNKNOWN, reason: "Unknown" } as LocalError); + try { + const clientApp = loadClientAppFn(run.fabId, run.fabVersion); + replyMessage = clientApp.call(message, context); + } catch (err) { + let errorCode = ErrorCode.CLIENT_APP_RAISED_EXCEPTION; + let reason = `${typeof err}:<'${err}'>`; + let excEntity = "ClientApp"; + if (err instanceof LoadClientAppError) { + reason = "An exception was raised when attempting to load `ClientApp`"; + err = ErrorCode.LOAD_CLIENT_APP_EXCEPTION; + excEntity = "SuperNode"; + } + if (!appStateTracker.interrupt) { + // TODO Add excInfo=err + console.error(`${excEntity} raised an exception`) + } + replyMessage = message.createErrorReply({ code: errorCode, reason }); + } + context.runConfig = {}; + nodeState.updateContext(nRunId, context); + await send(replyMessage); + console.log("Sent reply"); + } catch (err) { + if (err instanceof StopIteration) { + sleepDuration = 0; + break; + } else { + console.log(err); + await sleep(3); + } + } + } + await deleteNode(); + + if (sleepDuration === 0) { + console.log("Disconnect and shut down") + break; + } + + console.log(`Disconnect, then re-establish connection after ${sleepDuration} second(s)`) + await sleep(sleepDuration); + } +} diff --git a/src/ts/src/lib/task_handler.js b/src/ts/src/lib/task_handler.js new file mode 100644 index 000000000000..d5954f3964f3 --- /dev/null +++ b/src/ts/src/lib/task_handler.js @@ -0,0 +1,17 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.getTaskIns = exports.validateTaskIns = void 0; +const validateTaskIns = (taskIns) => { + if (!(taskIns.task && taskIns.task.recordset)) { + return false; + } + return true; +}; +exports.validateTaskIns = validateTaskIns; +const getTaskIns = (pullTaskInsResponse) => { + if (pullTaskInsResponse.taskInsList.length === 0) { + return null; + } + return pullTaskInsResponse.taskInsList[0]; +}; +exports.getTaskIns = getTaskIns; diff --git a/src/ts/src/lib/task_handler.test.js b/src/ts/src/lib/task_handler.test.js new file mode 100644 index 000000000000..274420a7304a --- /dev/null +++ b/src/ts/src/lib/task_handler.test.js @@ -0,0 +1,57 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +const task_handler_1 = require("./task_handler"); // Adjust the path as necessary +const recordset_1 = require("./recordset"); // Assuming RecordSet is in the same file +const task_1 = require("../protos/flwr/proto/task"); // Adjust the import paths for Protobuf +const fleet_1 = require("../protos/flwr/proto/fleet"); // Assuming PullTaskInsResponse is here +const serde_1 = require("./serde"); +// Test for validateTaskIns: No task inside TaskIns +describe("validateTaskIns - No task", () => { + it("should return false when task is null", () => { + const taskIns = task_1.TaskIns.create({ task: {} }); + expect((0, task_handler_1.validateTaskIns)(taskIns)).toBe(false); + }); +}); +// Test for validateTaskIns: No content inside Task +describe("validateTaskIns - No content", () => { + it("should return false when recordset is null", () => { + const taskIns = task_1.TaskIns.create({ task: task_1.Task.create() }); + expect((0, task_handler_1.validateTaskIns)(taskIns)).toBe(false); + }); +}); +// Test for validateTaskIns: Valid TaskIns +describe("validateTaskIns - Valid TaskIns", () => { + it("should return true when task contains a valid recordset", () => { + const recordSet = new recordset_1.RecordSet(); + const taskIns = task_1.TaskIns.create({ task: { recordset: (0, serde_1.recordSetToProto)(recordSet) } }); + expect((0, task_handler_1.validateTaskIns)(taskIns)).toBe(true); + }); +}); +// Test for getTaskIns: Empty response +describe("getTaskIns - Empty response", () => { + it("should return null when task_ins_list is empty", () => { + const res = fleet_1.PullTaskInsResponse.create({ taskInsList: [] }); + const taskIns = (0, task_handler_1.getTaskIns)(res); + expect(taskIns).toBeNull(); + }); +}); +// Test for getTaskIns: Single TaskIns in response +describe("getTaskIns - Single TaskIns", () => { + it("should return the task ins when task_ins_list contains one task", () => { + const expectedTaskIns = task_1.TaskIns.create({ taskId: "123", task: task_1.Task.create() }); + const res = fleet_1.PullTaskInsResponse.create({ taskInsList: [expectedTaskIns] }); + const actualTaskIns = (0, task_handler_1.getTaskIns)(res); + expect(actualTaskIns).toEqual(expectedTaskIns); + }); +}); +// Test for getTaskIns: Multiple TaskIns in response +describe("getTaskIns - Multiple TaskIns", () => { + it("should return the first task ins when task_ins_list contains multiple tasks", () => { + const expectedTaskIns = task_1.TaskIns.create({ taskId: "123", task: task_1.Task.create() }); + const res = fleet_1.PullTaskInsResponse.create({ + taskInsList: [expectedTaskIns, task_1.TaskIns.create(), task_1.TaskIns.create()], + }); + const actualTaskIns = (0, task_handler_1.getTaskIns)(res); + expect(actualTaskIns).toEqual(expectedTaskIns); + }); +}); diff --git a/src/ts/src/lib/task_handler.test.ts b/src/ts/src/lib/task_handler.test.ts new file mode 100644 index 000000000000..cdf83fc44244 --- /dev/null +++ b/src/ts/src/lib/task_handler.test.ts @@ -0,0 +1,61 @@ +import { getTaskIns, validateTaskIns } from "./task_handler"; // Adjust the path as necessary +import { RecordSet } from "./recordset"; // Assuming RecordSet is in the same file +import { TaskIns, Task } from "../protos/flwr/proto/task"; // Adjust the import paths for Protobuf +import { PullTaskInsResponse } from "../protos/flwr/proto/fleet"; // Assuming PullTaskInsResponse is here +import { recordSetToProto } from "./serde"; + +// Test for validateTaskIns: No task inside TaskIns +describe("validateTaskIns - No task", () => { + it("should return false when task is null", () => { + const taskIns = TaskIns.create({ task: {} }); + expect(validateTaskIns(taskIns)).toBe(false); + }); +}); + +// Test for validateTaskIns: No content inside Task +describe("validateTaskIns - No content", () => { + it("should return false when recordset is null", () => { + const taskIns = TaskIns.create({ task: Task.create() }); + expect(validateTaskIns(taskIns)).toBe(false); + }); +}); + +// Test for validateTaskIns: Valid TaskIns +describe("validateTaskIns - Valid TaskIns", () => { + it("should return true when task contains a valid recordset", () => { + const recordSet = new RecordSet(); + const taskIns = TaskIns.create({ task: { recordset: recordSetToProto(recordSet) } }); + expect(validateTaskIns(taskIns)).toBe(true); + }); +}); + +// Test for getTaskIns: Empty response +describe("getTaskIns - Empty response", () => { + it("should return null when task_ins_list is empty", () => { + const res = PullTaskInsResponse.create({ taskInsList: [] }); + const taskIns = getTaskIns(res); + expect(taskIns).toBeNull(); + }); +}); + +// Test for getTaskIns: Single TaskIns in response +describe("getTaskIns - Single TaskIns", () => { + it("should return the task ins when task_ins_list contains one task", () => { + const expectedTaskIns = TaskIns.create({ taskId: "123", task: Task.create() }); + const res = PullTaskInsResponse.create({ taskInsList: [expectedTaskIns] }); + const actualTaskIns = getTaskIns(res); + expect(actualTaskIns).toEqual(expectedTaskIns); + }); +}); + +// Test for getTaskIns: Multiple TaskIns in response +describe("getTaskIns - Multiple TaskIns", () => { + it("should return the first task ins when task_ins_list contains multiple tasks", () => { + const expectedTaskIns = TaskIns.create({ taskId: "123", task: Task.create() }); + const res = PullTaskInsResponse.create({ + taskInsList: [expectedTaskIns, TaskIns.create(), TaskIns.create()], + }); + const actualTaskIns = getTaskIns(res); + expect(actualTaskIns).toEqual(expectedTaskIns); + }); +}); diff --git a/src/ts/src/lib/task_handler.ts b/src/ts/src/lib/task_handler.ts new file mode 100644 index 000000000000..e8ce106474cf --- /dev/null +++ b/src/ts/src/lib/task_handler.ts @@ -0,0 +1,18 @@ +import { TaskIns as ProtoTaskIns } from "../protos/flwr/proto/task"; + +import { PullTaskInsResponse as ProtoPullTaskInsResponse } from "../protos/flwr/proto/fleet"; + +export const validateTaskIns = (taskIns: ProtoTaskIns): boolean => { + if (!(taskIns.task && taskIns.task.recordset)) { + return false; + } + return true; +}; + +export const getTaskIns = (pullTaskInsResponse: ProtoPullTaskInsResponse): ProtoTaskIns | null => { + if (pullTaskInsResponse.taskInsList.length === 0) { + return null; + } + + return pullTaskInsResponse.taskInsList[0]; +}; diff --git a/src/ts/src/lib/typing.js b/src/ts/src/lib/typing.js new file mode 100644 index 000000000000..d732e0e86833 --- /dev/null +++ b/src/ts/src/lib/typing.js @@ -0,0 +1,78 @@ +"use strict"; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.Message = exports.MessageType = exports.ErrorCode = exports.Code = void 0; +const DEFAULT_TTL = 3600; +var Code; +(function (Code) { + Code[Code["OK"] = 0] = "OK"; + Code[Code["GET_PROPERTIES_NOT_IMPLEMENTED"] = 1] = "GET_PROPERTIES_NOT_IMPLEMENTED"; + Code[Code["GET_PARAMETERS_NOT_IMPLEMENTED"] = 2] = "GET_PARAMETERS_NOT_IMPLEMENTED"; + Code[Code["FIT_NOT_IMPLEMENTED"] = 3] = "FIT_NOT_IMPLEMENTED"; + Code[Code["EVALUATE_NOT_IMPLEMENTED"] = 4] = "EVALUATE_NOT_IMPLEMENTED"; +})(Code || (exports.Code = Code = {})); +var ErrorCode; +(function (ErrorCode) { + ErrorCode[ErrorCode["UNKNOWN"] = 0] = "UNKNOWN"; + ErrorCode[ErrorCode["LOAD_CLIENT_APP_EXCEPTION"] = 1] = "LOAD_CLIENT_APP_EXCEPTION"; + ErrorCode[ErrorCode["CLIENT_APP_RAISED_EXCEPTION"] = 2] = "CLIENT_APP_RAISED_EXCEPTION"; + ErrorCode[ErrorCode["NODE_UNAVAILABLE"] = 3] = "NODE_UNAVAILABLE"; +})(ErrorCode || (exports.ErrorCode = ErrorCode = {})); +var MessageType; +(function (MessageType) { + MessageType["TRAIN"] = "train"; + MessageType["EVALUATE"] = "evaluate"; + MessageType["QUERY"] = "query"; +})(MessageType || (exports.MessageType = MessageType = {})); +class Message { + metadata; + content; + error; + constructor(metadata, content, error) { + if (!content && !error) { + throw "Either `content` or `error` must be set, but not both."; + } + // Here we divide by 1000 because Python's time.time() is in s while + // here it is in ms by default + metadata.createdAt = (new Date()).valueOf() / 1000; + this.metadata = metadata; + this.content = content; + this.error = error; + } + createErrorReply = (error, ttl = null) => { + if (ttl) { + console.warn("A custom TTL was set, but note that the SuperLink does not enforce the TTL yet. The SuperLink will start enforcing the TTL in a future version of Flower."); + } + const ttl_ = ttl ? ttl : DEFAULT_TTL; + let message = new Message(createReplyMetadata(this, ttl_), null, error); + if (!ttl) { + ttl = this.metadata.ttl - (message.metadata.createdAt - this.metadata.createdAt); + message.metadata.ttl = ttl; + } + return message; + }; + createReply = (content, ttl = null) => { + if (ttl) { + console.warn("A custom TTL was set, but note that the SuperLink does not enforce the TTL yet. The SuperLink will start enforcing the TTL in a future version of Flower."); + } + const ttl_ = ttl !== null ? ttl : DEFAULT_TTL; + let message = new Message(createReplyMetadata(this, ttl_), content, null); + if (!ttl) { + ttl = this.metadata.ttl - (message.metadata.createdAt - this.metadata.createdAt); + message.metadata.ttl = ttl; + } + return message; + }; +} +exports.Message = Message; +const createReplyMetadata = (msg, ttl) => { + return { + runId: msg.metadata.runId, + messageId: "", + srcNodeId: msg.metadata.dstNodeId, + dstNodeId: msg.metadata.srcNodeId, + replyToMessage: msg.metadata.messageId, + groupId: msg.metadata.groupId, + ttl: ttl, + messageType: msg.metadata.messageType, + }; +}; diff --git a/src/ts/src/lib/typing.ts b/src/ts/src/lib/typing.ts new file mode 100644 index 000000000000..5a7421505a1b --- /dev/null +++ b/src/ts/src/lib/typing.ts @@ -0,0 +1,206 @@ +import { RecordSet } from "./recordset"; +import { Client } from "./client"; + +const DEFAULT_TTL = 3600; + +export type Scalar = boolean | number | bigint | string | Uint8Array; +export type Config = { [index: string]: Scalar }; +export type Properties = { [index: string]: Scalar }; +export type Metrics = { [index: string]: Scalar }; + +export type UserConfigValue = boolean | bigint | number | string; +export type UserConfig = { [index: string]: UserConfigValue }; + +export type ClientFn = (cid: string) => Client; +export type ClientFnExt = (context: Context) => Client; + +export type ClientAppCallable = (msg: Message, context: Context) => Message; +export type Mod = (msg: Message, context: Context, call_next: ClientAppCallable) => Message; + +export enum Code { + OK = 0, + GET_PROPERTIES_NOT_IMPLEMENTED = 1, + GET_PARAMETERS_NOT_IMPLEMENTED = 2, + FIT_NOT_IMPLEMENTED = 3, + EVALUATE_NOT_IMPLEMENTED = 4, +} + +export enum ErrorCode { + UNKNOWN = 0, + LOAD_CLIENT_APP_EXCEPTION = 1, + CLIENT_APP_RAISED_EXCEPTION = 2, + NODE_UNAVAILABLE = 3, +} + +export interface Status { + code: Code; + message: string; +} + +export interface Parameters { + tensors: Uint8Array[]; + tensorType: string; +} + +export interface GetPropertiesIns { + config: Config; +} + +export interface GetPropertiesRes { + status: Status; + properties: Properties; +} + +export interface GetParametersIns { + config: Config; +} + +export interface GetParametersRes { + status: Status; + parameters: Parameters; +} + +export interface FitIns { + parameters: Parameters; + config: Config; +} + +export interface FitRes { + status: Status; + parameters: Parameters; + numExamples: number; + metrics: Metrics; +} + +export interface EvaluateIns { + parameters: Parameters; + config: Config; +} + +export interface EvaluateRes { + status: Status; + loss: number; + numExamples: number; + metrics: Metrics; +} + +export interface ServerMessage { + getPropertiesIns: GetPropertiesIns | null; + getParametersIns: GetParametersIns | null; + fitIns: FitIns | null; + evaluateIns: EvaluateIns | null; +} + +export interface ClientMessage { + getPropertiesRes: GetPropertiesRes | null; + getParametersRes: GetParametersRes | null; + fitRes: FitRes | null; + evaluateRes: EvaluateRes | null; +} + +export interface Run { + runId: bigint; + fabId: string; + fabVersion: string; + fabHash: string; + overrideConfig: UserConfig; +} + +export interface Fab { + hashStr: string; + content: Uint8Array; +} + +export interface Context { + nodeId: bigint; + nodeConfig: UserConfig; + state: RecordSet; + runConfig: UserConfig; +} + +export interface Metadata { + runId: bigint; + messageId: string; + srcNodeId: bigint; + dstNodeId: bigint; + replyToMessage: string; + groupId: string; + ttl: number; + messageType: string; + createdAt: number; +} + +export interface Error { + code: number; + reason: string | null; +} + + +export enum MessageType { + TRAIN = "train", + EVALUATE = "evaluate", + QUERY = "query", +} + +export class Message { + metadata: Metadata; + content: RecordSet | null; + error: Error | null; + + constructor(metadata: Metadata, content: RecordSet | null, error: Error | null) { + if (!content && !error) { + throw "Either `content` or `error` must be set, but not both."; + } + // Here we divide by 1000 because Python's time.time() is in s while + // here it is in ms by default + metadata.createdAt = (new Date()).valueOf() / 1000; + this.metadata = metadata; + this.content = content; + this.error = error; + } + + createErrorReply = (error: Error, ttl: number | null = null) => { + if (ttl) { + console.warn( + "A custom TTL was set, but note that the SuperLink does not enforce the TTL yet. The SuperLink will start enforcing the TTL in a future version of Flower.", + ); + } + const ttl_ = ttl ? ttl : DEFAULT_TTL; + let message = new Message(createReplyMetadata(this, ttl_), null, error); + + if (!ttl) { + ttl = this.metadata.ttl - (message.metadata.createdAt - this.metadata.createdAt); + message.metadata.ttl = ttl; + } + return message; + }; + + createReply = (content: RecordSet, ttl: number | null = null) => { + if (ttl) { + console.warn( + "A custom TTL was set, but note that the SuperLink does not enforce the TTL yet. The SuperLink will start enforcing the TTL in a future version of Flower.", + ); + } + const ttl_ = ttl !== null ? ttl : DEFAULT_TTL; + let message = new Message(createReplyMetadata(this, ttl_), content, null); + + if (!ttl) { + ttl = this.metadata.ttl - (message.metadata.createdAt - this.metadata.createdAt); + message.metadata.ttl = ttl; + } + return message; + }; +} + +const createReplyMetadata = (msg: Message, ttl: number) => { + return { + runId: msg.metadata.runId, + messageId: "", + srcNodeId: msg.metadata.dstNodeId, + dstNodeId: msg.metadata.srcNodeId, + replyToMessage: msg.metadata.messageId, + groupId: msg.metadata.groupId, + ttl: ttl, + messageType: msg.metadata.messageType, + } as Metadata; +};