Skip to content

Commit

Permalink
*: Provide initial tasks to the server when calling serve instead o…
Browse files Browse the repository at this point in the history
…f the constructor and get rid of the static async server constructor
  • Loading branch information
JulienVig committed Oct 9, 2024
1 parent 5d20ac9 commit e2314c6
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 32 deletions.
3 changes: 1 addition & 2 deletions cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
contextLength, batchSize, modelPath } = args

// Launch a server instance
const disco = await Server.of(defaultTasks.wikitext);
const [server, url] = await disco.serve();
const [server, url] = await new Server().serve(undefined, defaultTasks.wikitext);

// Fetch the wikitext task from the server
const tasks = await fetchTasks(url)
Expand Down
3 changes: 1 addition & 2 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ async function main (provider: TaskProvider, numberOfUsers: number): Promise<voi
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

const discoServer = await Server.of(provider)
const [server, url] = await discoServer.serve()
const [server, url] = await new Server().serve(undefined, provider)

const data = await getTaskData(task)

Expand Down
12 changes: 5 additions & 7 deletions docs/examples/custom_task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,7 @@ const customTask: TaskProvider = {

async function runServer (): Promise<void> {
// Create server
const server = await DiscoServer.of(
// with some tasks provided by Disco
defaultTasks.titanic,
// or your own custom task
customTask,
)
const server = new DiscoServer()

// You can also provide your own task object containing the URL of the model

Expand All @@ -86,7 +81,10 @@ async function runServer (): Promise<void> {
// await server.addTask(customTask.getTask(), new URL('https://example.com/path/to/your/model.json'))

// Start the server
await server.serve()
await server.serve(8080,
defaultTasks.titanic, // with some tasks provided by Disco
// or your own custom task
customTask,)
}

runServer().catch(console.error)
3 changes: 1 addition & 2 deletions docs/examples/training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ async function main (): Promise<void> {
const NAME: string = 'titanic'

// Launch a server instance
const discoServer = await Server.of(defaultTasks.simpleFace, defaultTasks.titanic)
const [server, url] = await discoServer.serve()
const [server, url] = await new Server().serve(undefined, defaultTasks.simpleFace, defaultTasks.titanic)

// Get all pre-defined tasks
const tasks = await fetchTasks(url)
Expand Down
5 changes: 2 additions & 3 deletions server/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import { Server } from "./server.js";
const PORT = 8080;

const providers = Object.values(defaultTasks);
// Init the server with default tasks
const server = await Server.of(...providers);

console.info("Server loaded the tasks below");
console.table(
Expand All @@ -26,5 +24,6 @@ console.table(
})),
);

const [_, serverURL] = await server.serve(PORT);
// Init the server with default tasks
const [_, serverURL] = await new Server().serve(PORT, ...providers);
console.log(`Disco Server listening on ${serverURL.toString()}`);
14 changes: 6 additions & 8 deletions server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ const debug = createDebug("server");
export class Server {
readonly #taskSet = new TaskSet();

// Static method to asynchronously init the Server
static async of(...tasks: TaskProvider[]): Promise<Server> {
const ret = new Server();
await Promise.all(tasks.map((t) => ret.addTask(t)));
return ret;
}

async addTask(taskProvider: TaskProvider): Promise<void> {
await this.#taskSet.addTask(taskProvider);
}
Expand All @@ -38,8 +31,11 @@ export class Server {
* start server
*
* @param port where to start, if not given, choose a random one
* @param tasks list of initial tasks to serve
* @returns a tuple with the server instance and the URL
*
**/
async serve(port?: number): Promise<[http.Server, URL]> {
async serve(port?: number, ...tasks: TaskProvider[]): Promise<[http.Server, URL]> {
const wsApplier = expressWS(express(), undefined, {
leaveRouterUntouched: true,
});
Expand All @@ -53,6 +49,8 @@ export class Server {
const taskRouter = new TaskRouter(this.#taskSet)
const federatedRouter = new TrainingRouter('federated', wsApplier, this.#taskSet)
const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#taskSet)
// Add tasks to the server
await Promise.all(tasks.map((t) => this.addTask(t)));

wsApplier.getWss().on('connection', (ws, req) => {
if (!federatedRouter.isValidUrl(req.url) && !decentralizedRouter.isValidUrl(req.url)) {
Expand Down
3 changes: 1 addition & 2 deletions server/tests/client/decentralized.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ function test (
let server: http.Server
let url: URL
beforeEach(async () => {
const disco = await Server.of(TASK);
[server, url] = await disco.serve();
[server, url] = await new Server().serve(undefined, TASK);
});
afterEach(() => { server?.close() })

Expand Down
3 changes: 1 addition & 2 deletions server/tests/client/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ describe("federated client", () => {
let server: http.Server;
let url: URL;
beforeEach(async () => {
const disco = await Server.of(TASK_PROVIDER);
[server, url] = await disco.serve();
[server, url] = await new Server().serve(undefined, TASK_PROVIDER);
});
afterEach(() => {
server?.close();
Expand Down
3 changes: 1 addition & 2 deletions server/tests/e2e/decentralized.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ describe('end-to-end decentralized', function () {
let server: http.Server
let url: URL
beforeEach(async () => {
const disco = await Server.of(defaultTasks.cifar10, defaultTasks.lusCovid);
[server, url] = await disco.serve();
[server, url] = await new Server().serve(undefined, defaultTasks.cifar10, defaultTasks.lusCovid);
});
afterEach(() => { server?.close() })

Expand Down
5 changes: 3 additions & 2 deletions server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ describe("end-to-end federated", () => {
let url: URL;
beforeEach(async function () {
this.timeout("10s");
[server, url] = await Server.of(
[server, url] = await new Server().serve(
undefined,
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.titanic,
defaultTasks.wikitext,
).then((s) => s.serve());
);
});
afterEach(() => {
server?.close();
Expand Down

0 comments on commit e2314c6

Please sign in to comment.