Skip to content

Commit

Permalink
webapp/task-creation: revamp
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Jan 8, 2025
1 parent 56458f8 commit 6e3cf2a
Show file tree
Hide file tree
Showing 18 changed files with 744 additions and 1,329 deletions.
31 changes: 27 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

119 changes: 119 additions & 0 deletions webapp/cypress/e2e/task-creation.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import type { Task } from "@epfml/discojs";

import * as tf from "@tensorflow/tfjs";

it("submits with tabular task", () => {
cy.intercept(
{ hostname: "server", pathname: "tasks", method: "POST" },
{ statusCode: 200 },
).as("posted");

cy.visit("/#/create");

cy.get("input[name='id']").type("id");
cy.get("select[name='dataType']").select("tabular");

cy.get("input[name='displayInformation.title']").type("simple");
cy.get("input[name='displayInformation.summary.preview']").type("preview");
cy.get("textarea[name='displayInformation.summary.overview']").type(
"overview",
);
cy.contains("Example tabular data").within(() => {
cy.contains("add example").click();
cy.get("input[name='displayInformation.dataExample[0].name']").type("name");
cy.get("input[name='displayInformation.dataExample[0].data']").type("data");
});

cy.get("select[name='trainingInformation.scheme']").select("federated");
cy.get("input[name='trainingInformation.epochs']").type("10");
cy.get("input[name='trainingInformation.batchSize']").type("5");
cy.get("input[name='trainingInformation.roundDuration']").type("2");
cy.get("input[name='trainingInformation.validationSplit']").type("0");
cy.get("input[name='trainingInformation.minNbOfParticipants']").type("2");
cy.contains("Input columns names").within(() => {
cy.contains("add column").click();
cy.get("input[name='trainingInformation.inputColumns[0]']").type("input");
});
cy.get("input[name='trainingInformation.outputColumn']").type("output");

const model = tf.sequential({
layers: [
tf.layers.conv2d({
inputShape: [32, 32, 3],
kernelSize: 3,
filters: 16,
activation: "relu",
}),
],
});
model.compile({ loss: "hinge", optimizer: "sgd" });
cy.wrap(getArtifacts(model))
.then((artifacts) => JSON.stringify(artifacts))
.then((json) => new TextEncoder().encode(json))
.then((raw) =>
cy.get("input[type='file']").selectFile(new Uint8Array(raw), {
force: true, // input is hidden
}),
);
cy.get("input[name='model.loss']").type("hinge");
cy.get("input[name='model.optimizer']").type("sgd");

cy.get("button[type='submit']").click();

cy.wait("@posted")
.its("request.body")
.then((body) => JSON.parse(body))
.its("task")
.should("deep.equal", {
id: "id",
dataType: "tabular",
displayInformation: {
title: "simple",
summary: {
preview: "preview",
overview: "overview",
},
dataExample: [{ name: "name", data: "data" }],
},
trainingInformation: {
scheme: "federated",
epochs: 10,
batchSize: 5,
roundDuration: 2,
validationSplit: 0,
minNbOfParticipants: 2,
inputColumns: ["input"],
outputColumn: "output",
tensorBackend: "tfjs",
},
} satisfies Task<"tabular">);
});

async function getArtifacts(
model: tf.LayersModel,
): Promise<tf.io.ModelArtifacts & { weightsManifest: never[] }> {
let resolveArtifacts: (_: tf.io.ModelArtifacts) => void;
const ret = new Promise<tf.io.ModelArtifacts>((resolve) => {
resolveArtifacts = resolve;
});

await model.save(
{
save: (artifacts) => {
resolveArtifacts(artifacts);
return Promise.resolve({
modelArtifactsInfo: {
dateSaved: new Date(),
modelTopologyType: "JSON",
},
});
},
},
{ includeOptimizer: true },
);

return {
...(await ret),
weightsManifest: [], // required by tf.loadLayersModel
};
}
2 changes: 2 additions & 0 deletions webapp/cypress/support/e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@ beforeEach(() =>
.getDirectory()
.then((root) => root.removeEntry("models", { recursive: true })),
);

before(() => (localStorage.debug = "discojs*,webapp*"));
2 changes: 1 addition & 1 deletion webapp/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
"@epfml/discojs": "*",
"@epfml/discojs-web": "*",
"@msgpack/msgpack": "^3.0.0-beta2",
"@vee-validate/zod": "4",
"apexcharts": "3",
"cypress": "13",
"d3": "7",
"immutable": "4",
"pinia": "<2.2.3",
"pinia-plugin-persistedstate-2": "2",
"vee-validate": "4",
"vue": "3",
"vue-router": "4",
"vue-tippy": "6",
Expand Down
23 changes: 23 additions & 0 deletions webapp/src/components/task_creation_form/FormField.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<template>
<Field
:name
class="p-2 bg-gray-100 dark:bg-slate-700 border rounded-md dark:border-slate-400 text-gray-700 dark:text-gray-200 focus:outline-none focus:border-disco-cyan"
v-model="model"
v-bind="$attrs"
>
<!-- forward slots, https://stackoverflow.com/questions/50891858/vue-how-to-pass-down-slots-inside-wrapper-component -->
<template v-for="(_, slot) in $slots" v-slot:[slot]="scope">
<slot :name="slot" v-bind="scope" />
</template>
</Field>

<ErrorMessage class="text-red-600" :name />
</template>

<script lang="ts" setup>
import { ErrorMessage, Field } from "vee-validate";
defineProps<{ name: string }>();
const model = defineModel<string>();
</script>
32 changes: 32 additions & 0 deletions webapp/src/components/task_creation_form/FormLabel.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<template>
<label class="root">
<span class="text-slate-600 dark:text-slate-200">
{{ label }}
</span>

<slot />
</label>
</template>

<style lang="css" scoped>
.root {
display: flex;
flex-direction: column;
flex: auto;
}
.root > * {
margin: 0.25rem 0.5rem;
}
.root > :first-child {
margin-left: 0;
}
.root > :first-child {
font-weight: bold;
}
</style>

<script lang="ts" setup>
defineProps<{ label: string }>();
</script>
Loading

0 comments on commit 6e3cf2a

Please sign in to comment.