Skip to content

Commit

Permalink
discojs-core: don't reexport tfjs
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Feb 28, 2024
1 parent 675c9f9 commit b425596
Show file tree
Hide file tree
Showing 65 changed files with 215 additions and 120 deletions.
4 changes: 3 additions & 1 deletion cli/src/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ import { Range } from 'immutable'
import fs from 'node:fs'
import fs_promises from 'fs/promises'
import path from 'node:path'
import tf from '@tensorflow/tfjs-node'

import { tf, node, data, type Task } from '@epfml/discojs-node'
import type { Task } from '@epfml/discojs-node'
import { node, data } from '@epfml/discojs-node'

function filesFromFolder (dir: string, folder: string, fractionToKeep: number): string[] {
const f = fs.readdirSync(dir + folder)
Expand Down
5 changes: 3 additions & 2 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { type client, type Task, type tf, type AsyncInformant } from '..'

import { List, Map, Set } from 'immutable'
import type tf from '@tensorflow/tfjs'

import type { client, Task, AsyncInformant } from '..'

export enum AggregationStep {
ADD,
Expand Down
6 changes: 4 additions & 2 deletions discojs/discojs-core/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { assert, expect } from 'chai'
import { type Map } from 'immutable'
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { aggregator, defaultTasks, type client, type Task, type tf } from '..'
import type { client, Task } from '..'
import { aggregator, defaultTasks } from '..'
import { AggregationStep } from './base'

const task = defaultTasks.titanic.getTask()
Expand Down
6 changes: 4 additions & 2 deletions discojs/discojs-core/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { type Map } from 'immutable'
import type { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import { type Task, type WeightsContainer, aggregation, type tf, type client } from '..'
import type { Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
Expand Down
9 changes: 5 additions & 4 deletions discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { AggregationStep, Base as Aggregator } from './base'
import { tf, aggregation, type Task, type WeightsContainer, type client } from '..'

import * as crypto from 'crypto'

import { Map, List, Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import { AggregationStep, Base as Aggregator } from './base'
import type { Task, WeightsContainer, client } from '..'
import { aggregation } from '..'

/**
* Aggregator implementing secure multi-party computation for decentralized learning.
Expand Down
12 changes: 7 additions & 5 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import axios from 'axios'
import { type Set } from 'immutable'
import type { Set } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { type tf, type Task, type TrainingInformant, serialization, type WeightsContainer } from '..'
import { type NodeID } from './types'
import { type EventConnection } from './event_connection'
import { type Aggregator } from '../aggregator'
import type { Task, TrainingInformant, WeightsContainer } from '..'
import { serialization } from '..'
import type { NodeID } from './types'
import type { EventConnection } from './event_connection'
import type { Aggregator } from '../aggregator'

/**
* Main, abstract, class representing a Disco client in a network, which handles
Expand Down
9 changes: 5 additions & 4 deletions discojs/discojs-core/src/dataset/data/data.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { type tf, type Task } from '../..'
import { type Dataset } from '../dataset'
import { type PreprocessingFunction } from './preprocessing/base'
import type tf from '@tensorflow/tfjs'
import type { List } from 'immutable'

import { type List } from 'immutable'
import type { Task } from '../..'
import type { Dataset } from '../dataset'
import type { PreprocessingFunction } from './preprocessing/base'

/**
* Abstract class representing an immutable Disco dataset, including a TF.js dataset,
Expand Down
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/dataset/data/image_data.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { assert, expect } from 'chai'
import tf from '@tensorflow/tfjs'

import { ImageData } from './image_data'
import { tf, type Task } from '../..'
import type { Task } from '../..'

describe('image data checks', () => {
const simplefaceMock: Task = {
Expand Down
6 changes: 4 additions & 2 deletions discojs/discojs-core/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { type tf, type Task } from '../..'
import { type Dataset } from '../dataset'
import type tf from '@tensorflow/tfjs'

import type { Task } from '../..'
import type { Dataset } from '../dataset'
import { Data } from './data'
import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing'

Expand Down
10 changes: 6 additions & 4 deletions discojs/discojs-core/src/dataset/data/preprocessing/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { type tf, type Task } from '../../..'
import { type ImagePreprocessing } from './image_preprocessing'
import { type TabularPreprocessing } from './tabular_preprocessing'
import { type TextPreprocessing } from './text_preprocessing'
import type tf from '@tensorflow/tfjs'

import type { Task } from '../../..'
import type { ImagePreprocessing } from './image_preprocessing'
import type { TabularPreprocessing } from './tabular_preprocessing'
import type { TextPreprocessing } from './text_preprocessing'

/**
* All available preprocessing type enums.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { type Task, tf } from '../../..'
import { type PreprocessingFunction } from './base'

import { List } from 'immutable'
import tf from '@tensorflow/tfjs'

import type { Task } from '../../..'
import type { PreprocessingFunction } from './base'

/**
* Available image preprocessing types.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { type Task, type tf } from '../../..'
import type tf from '@tensorflow/tfjs'
import { List } from 'immutable'
import { type PreprocessingFunction } from './base'

import type { Task } from '../../..'
import type { PreprocessingFunction } from './base'

/**
* Available tabular preprocessing types.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { type Task, tf } from '../../..'
import { type PreprocessingFunction } from './base'

import { List } from 'immutable'
import tf from '@tensorflow/tfjs'

import type { Task } from '../../..'
import type { PreprocessingFunction } from './base'

/**
* Available text preprocessing types.
Expand Down
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/dataset/data/tabular_data.spec.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { assert, expect } from 'chai'
import { Map, Set } from 'immutable'
import tf from '@tensorflow/tfjs'

import { TabularData } from './tabular_data'
import { tf, type Task } from '../..'
import type { Task } from '../..'

describe('tabular data checks', () => {
const titanicMock: Task = {
Expand Down
10 changes: 6 additions & 4 deletions discojs/discojs-core/src/dataset/data_loader/image_loader.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import { tf } from '../..'
import { type Dataset } from '../dataset'
import { type Data, ImageData, type DataSplit } from '../data'
import { DataLoader, type DataConfig } from '../data_loader'
import type { Dataset } from '../dataset'
import type { Data, DataSplit } from '../data'
import { ImageData } from '../data'
import type { DataConfig } from '../data_loader'
import { DataLoader } from '../data_loader'

/**
* Image data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/dataset/dataset.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type tf } from '..'
import type tf from '@tensorflow/tfjs'

/**
* Convenient type for the common dataset type used in TF.js.
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/default_tasks/cifar10.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { tf, type Task, data, type TaskProvider } from '..'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'

export const cifar10: TaskProvider = {
getTask (): Task {
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/default_tasks/geotags.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { tf, type Task, data, type TaskProvider } from '..'
import { Range } from 'immutable'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'
import { LabelTypeEnum } from '../task/label_type'

export const geotags: TaskProvider = {
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/default_tasks/lus_covid.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { tf, type Task, data, type TaskProvider } from '..'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'

export const lusCovid: TaskProvider = {
getTask (): Task {
Expand Down
4 changes: 3 additions & 1 deletion discojs/discojs-core/src/default_tasks/mnist.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { tf, type Task, type TaskProvider } from '..'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'

export const mnist: TaskProvider = {
getTask (): Task {
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/default_tasks/simple_face.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { type tf, type Task, data, type TaskProvider } from '..'
import type tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'

export const simpleFace: TaskProvider = {
getTask (): Task {
Expand Down
4 changes: 3 additions & 1 deletion discojs/discojs-core/src/default_tasks/skin_mnist.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { tf, data } from '..'
import { data } from '..'

export const skinMnist: TaskProvider = {
getTask (): Task {
Expand Down
5 changes: 4 additions & 1 deletion discojs/discojs-core/src/default_tasks/titanic.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { tf, type Task, type TaskProvider, data } from '..'
import tf from '@tensorflow/tfjs'

import type { Task, TaskProvider } from '..'
import { data } from '..'

export const titanic: TaskProvider = {
getTask (): Task {
Expand Down
2 changes: 0 additions & 2 deletions discojs/discojs-core/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export * as tf from '@tensorflow/tfjs'

export * as data from './dataset'
export * as serialization from './serialization'
export * as training from './training'
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/logging/trainer_logger.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { List } from 'immutable'
import tf from '@tensorflow/tfjs'

import { tf } from '..'
import { ConsoleLogger } from '.'

export class TrainerLog {
Expand Down
6 changes: 3 additions & 3 deletions discojs/discojs-core/src/memory/base.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// only used browser-side
// TODO: replace IO type
import type * as tf from '@tensorflow/tfjs'
import type tf from '@tensorflow/tfjs'

import { type TaskID } from '..'
import { type ModelType } from './model_type'
import type { TaskID } from '..'
import type { ModelType } from './model_type'

/**
* Model path which uniquely identifies a model in memory.
Expand Down
5 changes: 3 additions & 2 deletions discojs/discojs-core/src/memory/empty.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { type tf } from '..'
import type tf from '@tensorflow/tfjs'

import { Memory, type ModelInfo, type Path } from './base'
import type { ModelInfo, Path } from './base'
import { Memory } from './base'

/**
* Represents an empty model memory.
Expand Down
4 changes: 3 additions & 1 deletion discojs/discojs-core/src/privacy.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { tf, type Task, type WeightsContainer } from '.'
import tf from '@tensorflow/tfjs'

import type { Task, WeightsContainer } from '.'

/**
* Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
Expand Down
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/serialization/model.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { assert } from 'chai'
import tf from '@tensorflow/tfjs'

import { tf, serialization } from '..'
import { serialization } from '..'

async function getRawWeights (model: tf.LayersModel): Promise<Array<[number, Float32Array]>> {
return Array.from(
Expand Down
2 changes: 1 addition & 1 deletion discojs/discojs-core/src/serialization/model.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { tf } from '..'
import tf from '@tensorflow/tfjs'
import msgpack from 'msgpack-lite'

export type Encoded = number[]
Expand Down
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/serialization/weights.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as msgpack from 'msgpack-lite'
import tf from '@tensorflow/tfjs'

import { tf, WeightsContainer } from '..'
import { WeightsContainer } from '..'

interface Serialized {
shape: number[]
Expand Down
6 changes: 4 additions & 2 deletions discojs/discojs-core/src/task/task_handler.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import axios from 'axios'
import { Map } from 'immutable'
import type tf from '@tensorflow/tfjs'

import { serialization, type tf, WeightsContainer } from '..'
import { isTask, type Task, type TaskID } from './task'
import { serialization, WeightsContainer } from '..'
import type { Task, TaskID } from './task'
import { isTask } from './task'

const TASK_ENDPOINT = 'tasks'

Expand Down
4 changes: 3 additions & 1 deletion discojs/discojs-core/src/task/task_provider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { type tf, type Task } from '..'
import type tf from '@tensorflow/tfjs'

import type { Task } from '..'

export interface TaskProvider {
getTask: () => Task
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { type tf, type Memory, type Task, type TrainingInformant, type TrainingFunction, WeightsContainer, type client as clients } from '../..'
import { type Aggregator } from '../../aggregator'
import type tf from '@tensorflow/tfjs'

import type { Memory, Task, TrainingInformant, TrainingFunction, client as clients } from '../..'
import { WeightsContainer } from '../..'
import type { Aggregator } from '../../aggregator'
import { Trainer } from './trainer'

/**
Expand Down
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/training/trainer/local_trainer.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { type tf } from '../..'
import type tf from '@tensorflow/tfjs'

import { Trainer } from './trainer'

/** Class whose role is to locally (alone) train a model on a given dataset,
Expand Down
8 changes: 6 additions & 2 deletions discojs/discojs-core/src/training/trainer/trainer.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { type tf, type Memory, type Task, type TrainingInformant, type TrainingFunction, fitModelFunctions } from '../..'
import type tf from '@tensorflow/tfjs'

import type { Memory, Task, TrainingInformant, TrainingFunction } from '../..'
import { fitModelFunctions } from '../..'

import { RoundTracker } from './round_tracker'
import { TrainerLogger, type TrainerLog } from '../../logging/trainer_logger'
import type { TrainerLog } from '../../logging/trainer_logger'
import { TrainerLogger } from '../../logging/trainer_logger'

/** Abstract class whose role is to train a model with a given dataset. This can be either done
* locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
Expand Down
9 changes: 6 additions & 3 deletions discojs/discojs-core/src/training/trainer/trainer_builder.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { type tf, type client as clients, type Task, type TrainingInformant, type TrainingFunction, type Memory, ModelType, type ModelInfo } from '../..'
import { type Aggregator } from '../../aggregator'
import type tf from '@tensorflow/tfjs'

import type { client as clients, Task, TrainingInformant, TrainingFunction, ModelInfo, Memory } from '../..'
import { ModelType } from '../..'
import type { Aggregator } from '../../aggregator'

import { DistributedTrainer } from './distributed_trainer'
import { LocalTrainer } from './local_trainer'
import { type Trainer } from './trainer'
import type { Trainer } from './trainer'

/**
* A class that helps build the Trainer and auxiliary classes.
Expand Down
Loading

0 comments on commit b425596

Please sign in to comment.