Skip to content

Commit

Permalink
fixup! discojs: simplify types
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Sep 24, 2024
1 parent e3a55cb commit a9d1035
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 32 deletions.
4 changes: 2 additions & 2 deletions webapp/src/components/dataset_input/DataDescription.vue
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@
</template>

<script setup lang="ts">
import type { Task } from '@epfml/discojs'
import type { DataType, Task } from "@epfml/discojs";
import DropdownCard from '@/components/containers/DropdownCard.vue'
interface Props {
task: Task
task: Task<DataType>
}
const _ = defineProps<Props>()
</script>
8 changes: 4 additions & 4 deletions webapp/src/components/pages/TaskList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ import { VueSpinner } from 'vue3-spinners';
import { List } from "immutable";
import type { Task } from '@epfml/discojs'
import type { DataType, Task } from "@epfml/discojs";
import { useTasksStore } from '@/store/tasks'
import { useTrainingStore } from '@/store/training'
Expand All @@ -106,7 +106,7 @@ const sortedTasks = computed(() => [...tasks.value.values()].sort(
(task1, task2) => task1.displayInformation.taskTitle.localeCompare(task2.displayInformation.taskTitle)
))
function getSchemeColor(task: Task): string {
function getSchemeColor(task: Task<DataType>): string {
switch (task.trainingInformation.scheme) {
case 'decentralized':
return 'bg-orange-200'
Expand All @@ -116,7 +116,7 @@ function getSchemeColor(task: Task): string {
return 'bg-blue-200'
}
}
function getDataTypeColor(task: Task): string {
function getDataTypeColor(task: Task<DataType>): string {
switch (task.trainingInformation.dataType) {
case 'image':
return 'bg-yellow-200'
Expand All @@ -127,7 +127,7 @@ function getDataTypeColor(task: Task): string {
}
}
const toTask = (task: Task): void => {
function toTask(task: Task<DataType>): void {
trainingStore.setTask(task.id)
trainingStore.setStep(1)
router.push(`/${task.id}`)
Expand Down
4 changes: 2 additions & 2 deletions webapp/src/components/task_creation_form/TaskForm.vue
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ import { Form as VeeForm, ErrorMessage } from 'vee-validate'
import { List, Map } from 'immutable'
import * as tf from '@tensorflow/tfjs'
import type { Task } from '@epfml/discojs'
import type { DataType, Task } from "@epfml/discojs";
import { models, pushTask } from '@epfml/discojs'
import type { FormDependency, FormField, FormSection } from '@/task_creation_form'
Expand Down Expand Up @@ -275,7 +275,7 @@ const onSubmit = async (rawTask: any): Promise<void> => {
.map((section) => formatSection(section, rawTask))
)
.set('id', rawTask.taskID)
.toObject() as unknown as Task
.toObject() as unknown as Task<DataType>
let model
try {
Expand Down
4 changes: 2 additions & 2 deletions webapp/src/components/testing/TestSteps.vue
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ const toaster = useToaster();
const validationStore = useValidationStore();
const props = defineProps<{
task: Task;
model: Model;
task: Task<D>;
model: Model<D>;
}>();
interface Tested {
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/components/testing/Testing.vue
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ onActivated(() => {
selectModel(validationStore.modelID, "test");
});
async function downloadModel(task: Task): Promise<void> {
async function downloadModel(task: Task<DataType>): Promise<void> {
try {
toaster.info("Downloading model...");
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/components/testing/__tests__/Testing.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { useTasksStore } from "@/store/tasks";

import Testing from "../Testing.vue";

const TASK: Task = {
const TASK: Task<"text"> = {
id: "task",
displayInformation: {
taskTitle: "task title",
Expand Down
4 changes: 2 additions & 2 deletions webapp/src/components/training/Description.vue
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
<script lang="ts" setup>
import { computed } from 'vue'

import type { Task } from '@epfml/discojs'
import type { DataType, Task } from "@epfml/discojs";

import type { FormDependency, FormField, FormSection } from '@/task_creation_form'
import { trainingInformation, privacyParameters } from '@/task_creation_form'
Expand All @@ -66,7 +66,7 @@ import DropdownCard from '@/components/containers/DropdownCard.vue'
import ModelIcon from '@/assets/svg/ModelIcon.vue'

interface Props {
task: Task
task: Task<DataType>
}
const props = defineProps<Props>()

Expand Down
8 changes: 4 additions & 4 deletions webapp/src/components/training/Finished.vue
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
</div>
</template>

<script setup lang="ts">
<script setup lang="ts" generic="D extends DataType">
import { List } from "immutable";
import { ref, toRaw, watch } from "vue";
import { useRouter } from "vue-router";
import type { Model, Task } from "@epfml/discojs";
import type { DataType, Model, Task } from "@epfml/discojs";
import { useToaster } from "@/composables/toaster";
import type { ModelID } from "@/store/models";
Expand All @@ -39,8 +39,8 @@ const router = useRouter();
const toaster = useToaster();
const props = defineProps<{
task: Task;
model?: Model;
task: Task<D>;
model?: Model<D>;
}>();
const saved = ref<ModelID>();
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/components/training/Trainer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ const props = defineProps<{
dataset?: Dataset<Raw[D]>;
}>();
const emit = defineEmits<{
model: [Model];
model: [Model<D>];
}>();
const trainingGenerator =
Expand Down
15 changes: 9 additions & 6 deletions webapp/src/components/training/Training.vue
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
import { computed, onMounted, ref, toRaw, watch } from "vue";
import { useRouter, useRoute } from "vue-router";
import type { Dataset, DataType, Model, Raw, TaskID } from "@epfml/discojs";
import type {
Dataset,
DataType,
Model,
Raw,
Task,
TaskID,
} from "@epfml/discojs";
import { useTrainingStore } from "@/store/training";
import { useTasksStore } from "@/store/tasks";
Expand Down Expand Up @@ -58,7 +65,7 @@ function setupTrainingStore() {
}
// Init the task once the taskStore has been loaded successfully
// If it is not we redirect to the task list
const task = computed(() => {
const task = computed<Task<DataType> | undefined>(() => {
console.log("training: recompute task");
if (tasksStore.status == "success") {
Expand Down Expand Up @@ -98,10 +105,6 @@ const unamedDataset = computed<Dataset<Raw[DataType]> | undefined>(() => {
case "tabular":
case "text":
return dataset.value as Dataset<Raw["tabular" | "text"]>;
default: {
const _: never = task.value.trainingInformation;
throw new Error("should never happen");
}
}
});
const trainedModel = ref<Model<DataType>>();
Expand Down
9 changes: 6 additions & 3 deletions webapp/src/store/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { StateTree } from "pinia";
import { defineStore } from "pinia";
import { computed, ref, toRaw } from "vue";

import type { Model } from "@epfml/discojs";
import type { DataType, Model } from "@epfml/discojs";
import { serialization } from "@epfml/discojs";

export type ModelID = number;
Expand Down Expand Up @@ -33,14 +33,17 @@ export const useModelsStore = defineStore(
})),
);

async function get(id: ModelID): Promise<Model | undefined> {
async function get(id: ModelID): Promise<Model<DataType> | undefined> {
const infos = idToModel.value.get(id);
if (infos === undefined) return undefined;

return await serialization.model.decode(toRaw(infos.encoded));
}

async function add(taskID: string, model: Model): Promise<ModelID> {
async function add(
taskID: string,
model: Model<DataType>,
): Promise<ModelID> {
const dateSaved = new Date();
const id = dateSaved.getTime();

Expand Down
10 changes: 6 additions & 4 deletions webapp/src/store/tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { defineStore } from 'pinia'
import { shallowRef, ref } from 'vue'
import { Map } from 'immutable'

import type { TaskID, Task } from '@epfml/discojs'
import type { TaskID, Task, DataType } from "@epfml/discojs";
import { fetchTasks } from '@epfml/discojs'

import { useToaster } from '@/composables/toaster'
Expand All @@ -15,13 +15,13 @@ const debug = createDebug("webapp:store");
export const useTasksStore = defineStore('tasks', () => {
const trainingStore = useTrainingStore()

const tasks = shallowRef<Map<TaskID, Task>>(Map())
const tasks = shallowRef<Map<TaskID, Task<DataType>>>(Map())

// 3-state variable used to test whether the tasks have been retrieved successfully,
// if the retrieving failed, or if they are currently being loaded
const status = ref<'success' | 'failed' | 'loading'>('loading')

function addTask (task: Task): void {
function addTask (task: Task<DataType>): void {
trainingStore.setTask(task.id);
trainingStore.setStep(0);
tasks.value = tasks.value.set(task.id, task)
Expand All @@ -32,7 +32,9 @@ export const useTasksStore = defineStore('tasks', () => {
async function initTasks (): Promise<void> {
try {
status.value = 'loading'
const tasks = (await fetchTasks(CONFIG.serverUrl)).filter((t: Task) => !TASKS_TO_FILTER_OUT.includes(t.id))
const tasks = (await fetchTasks(CONFIG.serverUrl)).filter(
(t: Task<DataType>) => !TASKS_TO_FILTER_OUT.includes(t.id),
);

tasks.forEach(addTask)
status.value = 'success'
Expand Down

0 comments on commit a9d1035

Please sign in to comment.