diff --git a/.github/Architecture.md b/.github/Architecture.md index d19920460..c140163df 100644 --- a/.github/Architecture.md +++ b/.github/Architecture.md @@ -7,228 +7,240 @@ | |- 📂 training: | | |- 📂 routes: | | | |- 📂 tabular: -| | | | |- 📜 schemas.py -| | | | |- 📜 __init__.py | | | | |- 📜 tabular.py -| | | |- 📂 datasets: -| | | | |- 📂 default: -| | | | | |- 📜 schemas.py -| | | | | |- 📜 __init__.py -| | | | | |- 📜 columns.py | | | | |- 📜 __init__.py +| | | | |- 📜 schemas.py | | | |- 📂 image: +| | | | |- 📜 __init__.py | | | | |- 📜 image.py | | | | |- 📜 schemas.py +| | | |- 📂 audio: | | | | |- 📜 __init__.py +| | | | |- 📜 audio.py +| | | | |- 📜 schemas.py +| | | |- 📂 datasets: +| | | | |- 📂 default: +| | | | | |- 📜 columns.py +| | | | | |- 📜 __init__.py +| | | | | |- 📜 schemas.py +| | | | |- 📜 __init__.py +| | | |- 📜 __init__.py | | | |- 📜 schemas.py +| | |- 📂 middleware: | | | |- 📜 __init__.py +| | | |- 📜 health_check_middleware.py | | |- 📂 core: -| | | |- 📜 criterion.py -| | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop -| | | |- 📜 dataset.py : read in the dataset through URL or file upload +| | | |- 📂 celery: +| | | | |- 📜 trainer.py +| | | | |- 📜 criterion.py +| | | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop +| | | | |- 📜 train_types.py +| | | | |- 📜 dataset.py : read in the dataset through URL or file upload +| | | | |- 📜 __init__.py +| | | | |- 📜 Dockerfile +| | | | |- 📜 worker.py +| | | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now) | | | |- 📜 __init__.py | | | |- 📜 authenticator.py -| | | |- 📜 trainer.py -| | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now) -| | |- 📂 middleware: -| | | |- 📜 __init__.py -| | | |- 📜 health_check_middleware.py +| | |- 📜 asgi.py +| | |- 📜 constants.py : list of helpful constants +| | |- 📜 celery_app.py | | |- 📜 settings.py -| | |- 📜 urls.py | | |- 📜 __init__.py | | |- 📜 wsgi.py -| | |- 📜 asgi.py +| | |- 📜 urls.py +| | |- 📜 celeryconfig.py +| |- 📜 README.md | |- 📜 docker-compose.yml -| |- 📜 docker-compose.prod.yml +| |- 📜 cli.py | |- 📜 pyproject.toml -| |- 📜 README.md | |- 📜 poetry.lock -| |- 📜 cli.py -| |- 📜 environment.yml -| |- 📜 Dockerfile.prod +| |- 📜 pytest.ini | |- 📜 Dockerfile | |- 📜 manage.py -| |- 📜 pytest.ini +| |- 📜 environment.yml +| |- 📜 docker-compose.prod.yml ``` ## Frontend Architecture ``` 📦 frontend -| |- 📂 layer_docs: -| | |- 📜 Linear.md : Doc for Linear layer -| | |- 📜 Softmax.md : Doc for Softmax layer -| | |- 📜 softmax_equation.png : PNG file of Softmax equation -| | |- 📜 ReLU.md : Doc for ReLU later | |- 📂 public: | | |- 📂 images: -| | | |- 📂 wiki_images: -| | | | |- 📜 maxpool2d.gif -| | | | |- 📜 conv2d.gif -| | | | |- 📜 softmax_equation.png : PNG file of Softmax equation -| | | | |- 📜 tanh_equation.png -| | | | |- 📜 dropout_diagram.png -| | | | |- 📜 batchnorm_diagram.png -| | | | |- 📜 tanh_plot.png -| | | | |- 📜 conv2d2.gif -| | | | |- 📜 sigmoid_equation.png -| | | | |- 📜 avgpool_maxpool.gif -| | | |- 📂 learn_mod_images: -| | | | |- 📜 lossExampleEquation.png -| | | | |- 📜 sigmoidactivation.png -| | | | |- 📜 neuralnet.png -| | | | |- 📜 binarystepactivation.png -| | | | |- 📜 lossExample.png -| | | | |- 📜 LeakyReLUactivation.png -| | | | |- 📜 lossExampleTable.png -| | | | |- 📜 tanhactivation.png -| | | | |- 📜 robotImage.jpg -| | | | |- 📜 ReLUactivation.png -| | | | |- 📜 neuron.png -| | | | |- 📜 neuronWithEquation.png -| | | | |- 📜 sigmoidfunction.png | | | |- 📂 logos: | | | | |- 📂 dlp_branding: -| | | | | |- 📜 dlp-logo.svg : DLP Logo, duplicate of files in public, but essential as the frontend can't read public | | | | | |- 📜 dlp-logo.png : DLP Logo, duplicate of files in public, but essential as the frontend can't read public -| | | | |- 📜 google.png -| | | | |- 📜 pytorch-logo.png +| | | | | |- 📜 dlp-logo.svg : DLP Logo, duplicate of files in public, but essential as the frontend can't read public +| | | | |- 📜 dsgt-logo-white-back.png | | | | |- 📜 python-logo.png -| | | | |- 📜 dsgt-logo-dark.png -| | | | |- 📜 aws-logo.png -| | | | |- 📜 github.png +| | | | |- 📜 google.png | | | | |- 📜 pandas-logo.png | | | | |- 📜 react-logo.png -| | | | |- 📜 dsgt-logo-white-back.png | | | | |- 📜 flask-logo.png +| | | | |- 📜 aws-logo.png +| | | | |- 📜 github.png +| | | | |- 📜 dsgt-logo-dark.png | | | | |- 📜 dsgt-logo-light.png +| | | | |- 📜 pytorch-logo.png +| | | |- 📂 learn_mod_images: +| | | | |- 📜 neuron.png +| | | | |- 📜 ReLUactivation.png +| | | | |- 📜 LeakyReLUactivation.png +| | | | |- 📜 lossExampleEquation.png +| | | | |- 📜 lossExampleTable.png +| | | | |- 📜 robotImage.jpg +| | | | |- 📜 neuralnet.png +| | | | |- 📜 sigmoidfunction.png +| | | | |- 📜 lossExample.png +| | | | |- 📜 tanhactivation.png +| | | | |- 📜 binarystepactivation.png +| | | | |- 📜 sigmoidactivation.png +| | | | |- 📜 neuronWithEquation.png +| | | |- 📂 wiki_images: +| | | | |- 📜 sigmoid_equation.png +| | | | |- 📜 conv2d.gif +| | | | |- 📜 conv2d2.gif +| | | | |- 📜 avgpool_maxpool.gif +| | | | |- 📜 softmax_equation.png : PNG file of Softmax equation +| | | | |- 📜 dropout_diagram.png +| | | | |- 📜 batchnorm_diagram.png +| | | | |- 📜 maxpool2d.gif +| | | | |- 📜 tanh_equation.png +| | | | |- 📜 tanh_plot.png | | | |- 📜 demo_video.gif : GIF tutorial of a simple classification training session -| | |- 📜 robots.txt +| | |- 📜 dlp-logo.ico : DLP Logo | | |- 📜 manifest.json : Default React file for choosing icon based on | | |- 📜 index.html : Base HTML file that will be initially rendered -| | |- 📜 dlp-logo.ico : DLP Logo +| | |- 📜 robots.txt +| |- 📂 layer_docs: +| | |- 📜 Softmax.md : Doc for Softmax layer +| | |- 📜 Linear.md : Doc for Linear layer +| | |- 📜 softmax_equation.png : PNG file of Softmax equation +| | |- 📜 ReLU.md : Doc for ReLU later | |- 📂 src: -| | |- 📂 pages: -| | | |- 📂 train: -| | | | |- 📜 [train_space_id].tsx -| | | | |- 📜 index.tsx -| | | |- 📜 dashboard.tsx -| | | |- 📜 learn.tsx -| | | |- 📜 settings.tsx -| | | |- 📜 about.tsx -| | | |- 📜 feedback.tsx -| | | |- 📜 wiki.tsx -| | | |- 📜 forgot.tsx -| | | |- 📜 _document.tsx -| | | |- 📜 LearnContent.tsx -| | | |- 📜 _app.tsx -| | | |- 📜 login.tsx +| | |- 📂 __tests__: +| | | |- 📂 common: +| | | | |- 📂 components: +| | | | | |- 📜 TitleText.test.tsx +| | |- 📂 backend_outputs: +| | | |- 📜 my_deep_learning_model.onnx : Last ONNX file output +| | | |- 📜 model.pkl +| | | |- 📜 model.pt : Last model.pt output | | |- 📂 common: -| | | |- 📂 utils: -| | | | |- 📜 dateFormat.ts -| | | | |- 📜 dndHelpers.ts -| | | | |- 📜 firebase.ts -| | | |- 📂 components: -| | | | |- 📜 Spacer.tsx -| | | | |- 📜 DlpTooltip.tsx -| | | | |- 📜 TitleText.tsx -| | | | |- 📜 EmailInput.tsx -| | | | |- 📜 HtmlTooltip.tsx -| | | | |- 📜 ClientOnlyPortal.tsx -| | | | |- 📜 Footer.tsx -| | | | |- 📜 NavBarMain.tsx | | | |- 📂 styles: | | | | |- 📜 Home.module.css | | | | |- 📜 globals.css | | | |- 📂 redux: | | | | |- 📜 hooks.ts | | | | |- 📜 store.ts +| | | | |- 📜 train.ts | | | | |- 📜 backendApi.ts | | | | |- 📜 userLogin.ts -| | | | |- 📜 train.ts +| | | |- 📂 utils: +| | | | |- 📜 dateFormat.ts +| | | | |- 📜 firebase.ts +| | | | |- 📜 dndHelpers.ts +| | | |- 📂 components: +| | | | |- 📜 EmailInput.tsx +| | | | |- 📜 HtmlTooltip.tsx +| | | | |- 📜 NavBarMain.tsx +| | | | |- 📜 Footer.tsx +| | | | |- 📜 DlpTooltip.tsx +| | | | |- 📜 ClientOnlyPortal.tsx +| | | | |- 📜 Spacer.tsx +| | | | |- 📜 TitleText.tsx | | |- 📂 features: +| | | |- 📂 LearnMod: +| | | | |- 📜 MCQuestion.tsx +| | | | |- 📜 ModulesSideBar.tsx +| | | | |- 📜 ImageComponent.tsx +| | | | |- 📜 ClassCard.tsx +| | | | |- 📜 FRQuestion.tsx +| | | | |- 📜 Exercise.tsx +| | | | |- 📜 LearningModulesContent.tsx +| | | |- 📂 OpenAi: +| | | | |- 📜 openAiUtils.ts +| | | |- 📂 Dashboard: +| | | | |- 📂 redux: +| | | | | |- 📜 dashboardApi.ts +| | | | |- 📂 components: +| | | | | |- 📜 TrainBarChart.tsx +| | | | | |- 📜 TrainDoughnutChart.tsx +| | | | | |- 📜 TrainDataGrid.tsx | | | |- 📂 Train: +| | | | |- 📂 redux: +| | | | | |- 📜 trainspaceSlice.ts +| | | | | |- 📜 trainspaceApi.ts +| | | | |- 📂 types: +| | | | | |- 📜 trainTypes.ts | | | | |- 📂 constants: | | | | | |- 📜 trainConstants.ts -| | | | |- 📂 components: -| | | | | |- 📜 CreateTrainspace.tsx -| | | | | |- 📜 TrainspaceLayout.tsx -| | | | | |- 📜 DatasetStepLayout.tsx | | | | |- 📂 features: -| | | | | |- 📂 Image: -| | | | | | |- 📂 constants: -| | | | | | | |- 📜 imageConstants.ts -| | | | | | |- 📂 components: -| | | | | | | |- 📜 ImageTrainspace.tsx -| | | | | | | |- 📜 ImageParametersStep.tsx -| | | | | | | |- 📜 ImageFlow.tsx -| | | | | | | |- 📜 ImageDatasetStep.tsx -| | | | | | | |- 📜 ImageReviewStep.tsx +| | | | | |- 📂 Tabular: | | | | | | |- 📂 redux: -| | | | | | | |- 📜 imageActions.ts -| | | | | | | |- 📜 imageApi.ts +| | | | | | | |- 📜 tabularActions.ts +| | | | | | | |- 📜 tabularApi.ts | | | | | | |- 📂 types: -| | | | | | | |- 📜 imageTypes.ts -| | | | | | |- 📜 index.ts -| | | | | |- 📂 Tabular: +| | | | | | | |- 📜 tabularTypes.ts | | | | | | |- 📂 constants: | | | | | | | |- 📜 tabularConstants.ts | | | | | | |- 📂 components: -| | | | | | | |- 📜 TabularDatasetStep.tsx -| | | | | | | |- 📜 TabularReviewStep.tsx -| | | | | | | |- 📜 TabularFlow.tsx | | | | | | | |- 📜 TabularTrainspace.tsx +| | | | | | | |- 📜 TabularReviewStep.tsx | | | | | | | |- 📜 TabularParametersStep.tsx +| | | | | | | |- 📜 TabularDatasetStep.tsx +| | | | | | | |- 📜 TabularFlow.tsx +| | | | | | |- 📜 index.ts +| | | | | |- 📂 Image: | | | | | | |- 📂 redux: -| | | | | | | |- 📜 tabularApi.ts -| | | | | | | |- 📜 tabularActions.ts +| | | | | | | |- 📜 imageApi.ts +| | | | | | | |- 📜 imageActions.ts | | | | | | |- 📂 types: -| | | | | | | |- 📜 tabularTypes.ts +| | | | | | | |- 📜 imageTypes.ts +| | | | | | |- 📂 constants: +| | | | | | | |- 📜 imageConstants.ts +| | | | | | |- 📂 components: +| | | | | | | |- 📜 ImageReviewStep.tsx +| | | | | | | |- 📜 ImageTrainspace.tsx +| | | | | | | |- 📜 ImageFlow.tsx +| | | | | | | |- 📜 ImageParametersStep.tsx +| | | | | | | |- 📜 ImageDatasetStep.tsx | | | | | | |- 📜 index.ts -| | | | |- 📂 redux: -| | | | | |- 📜 trainspaceApi.ts -| | | | | |- 📜 trainspaceSlice.ts -| | | | |- 📂 types: -| | | | | |- 📜 trainTypes.ts +| | | | |- 📂 components: +| | | | | |- 📜 CreateTrainspace.tsx +| | | | | |- 📜 DatasetStepLayout.tsx +| | | | | |- 📜 TrainspaceLayout.tsx | | | |- 📂 Feedback: | | | | |- 📂 redux: | | | | | |- 📜 feedbackApi.ts -| | | |- 📂 Dashboard: -| | | | |- 📂 components: -| | | | | |- 📜 TrainDoughnutChart.tsx -| | | | | |- 📜 TrainBarChart.tsx -| | | | | |- 📜 TrainDataGrid.tsx -| | | | |- 📂 redux: -| | | | | |- 📜 dashboardApi.ts -| | | |- 📂 LearnMod: -| | | | |- 📜 FRQuestion.tsx -| | | | |- 📜 MCQuestion.tsx -| | | | |- 📜 LearningModulesContent.tsx -| | | | |- 📜 Exercise.tsx -| | | | |- 📜 ClassCard.tsx -| | | | |- 📜 ImageComponent.tsx -| | | | |- 📜 ModulesSideBar.tsx -| | | |- 📂 OpenAi: -| | | | |- 📜 openAiUtils.ts -| | |- 📂 backend_outputs: -| | | |- 📜 my_deep_learning_model.onnx : Last ONNX file output -| | | |- 📜 model.pkl -| | | |- 📜 model.pt : Last model.pt output -| | |- 📂 __tests__: -| | | |- 📂 common: -| | | | |- 📂 components: -| | | | | |- 📜 TitleText.test.tsx -| | |- 📜 next-env.d.ts +| | |- 📂 pages: +| | | |- 📂 train: +| | | | |- 📜 [train_space_id].tsx +| | | | |- 📜 metrics_to_charts.tsx +| | | | |- 📜 index.tsx +| | | |- 📜 _app.tsx +| | | |- 📜 forgot.tsx +| | | |- 📜 about.tsx +| | | |- 📜 settings.tsx +| | | |- 📜 _document.tsx +| | | |- 📜 feedback.tsx +| | | |- 📜 dashboard.tsx +| | | |- 📜 learn.tsx +| | | |- 📜 LearnContent.tsx +| | | |- 📜 login.tsx +| | | |- 📜 wiki.tsx +| | |- 📜 constants.ts | | |- 📜 iris.csv : Sample CSV data | | |- 📜 GlobalStyle.ts -| | |- 📜 constants.ts -| |- 📜 next-env.d.ts -| |- 📜 .eslintignore -| |- 📜 jest.config.ts +| | |- 📜 next-env.d.ts | |- 📜 pnpm-lock.yaml -| |- 📜 .eslintrc.json -| |- 📜 next.config.js | |- 📜 tsconfig.json | |- 📜 package.json +| |- 📜 .eslintrc.json +| |- 📜 next.config.js +| |- 📜 next-env.d.ts +| |- 📜 jest.config.ts +| |- 📜 .eslintignore ``` diff --git a/dlp-terraform/ecs/s3.tf b/dlp-terraform/ecs/s3.tf new file mode 100644 index 000000000..2631fc1d5 --- /dev/null +++ b/dlp-terraform/ecs/s3.tf @@ -0,0 +1,15 @@ +resource "aws_s3_bucket" "s3bucket_executions" { + bucket = "dlp-executions" + + tags = { + Name = "Execution data" + } +} +resource "aws_s3_bucket_public_access_block" "access_block_uploads" { + bucket = aws_s3_bucket.s3bucket_executions.id + + block_public_acls = true + block_public_policy = true + ignore_public_acls = true + restrict_public_buckets = true +} diff --git a/dlp-terraform/ecs/sqs.tf b/dlp-terraform/ecs/sqs.tf new file mode 100644 index 000000000..d5ca20179 --- /dev/null +++ b/dlp-terraform/ecs/sqs.tf @@ -0,0 +1,27 @@ +resource "aws_sqs_queue" "training_queue" { + name = "training-queue.fifo" + fifo_queue = true + message_retention_seconds = 60*24 + + redrive_policy = jsonencode({ + deadLetterTargetArn = aws_sqs_queue.training_queue_deadletter.arn + maxReceiveCount = 4 + }) +} + +resource "aws_sqs_queue" "training_queue_deadletter" { + name = "training-deadletter-queue" +} + +resource "aws_sqs_queue_redrive_allow_policy" "training_queue_redrive_allow_policy" { + queue_url = aws_sqs_queue.training_queue_deadletter.id + + redrive_allow_policy = jsonencode({ + redrivePermission = "byQueue", + sourceQueueArns = [aws_sqs_queue.training_queue.arn] + }) +} + +output "sqs_queue_url" { + value = aws_sqs_queue.training_queue.url +} \ No newline at end of file diff --git a/frontend/next.config.js b/frontend/next.config.js index 089c75cb2..180c1257b 100644 --- a/frontend/next.config.js +++ b/frontend/next.config.js @@ -19,7 +19,7 @@ const nextConfig = { { source: "/api/lambda/:path*", destination: - "https://em9iri9g4j.execute-api.us-west-2.amazonaws.com/:path*", + "https://qt6nzp3sjd.execute-api.us-east-1.amazonaws.com/:path*", }, { source: "/api/training/:path*", diff --git a/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx b/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx index f01623925..ea548121f 100644 --- a/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx +++ b/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx @@ -19,6 +19,7 @@ import { import { useTrainImageMutation } from "../redux/imageApi"; import { useRouter } from "next/router"; import { removeTrainspaceData } from "@/features/Train/redux/trainspaceSlice"; +import { useCreateTrainspaceMutation } from "@/features/Train/redux/trainspaceApi"; const ImageTrainspace = () => { const trainspace = useAppSelector( @@ -93,6 +94,7 @@ const TrainspaceStepInner = ({ const Component = STEP_SETTINGS[TRAINSPACE_SETTINGS.steps[step]].component; const [isStepModified, setIsStepModified] = useState(false); const [train] = useTrainImageMutation(); + const [createTrainspace] = useCreateTrainspaceMutation(); const[isButtonClicked, setIsButtonClicked] = useState(false); const dispatch = useAppDispatch(); const router = useRouter(); @@ -100,13 +102,17 @@ const TrainspaceStepInner = ({ if (trainspace.step < TRAINSPACE_SETTINGS.steps.length) setStep(trainspace.step); else { - train(trainspace) - .unwrap() - .then(({ trainspaceId }) => { - router.push({ pathname: `/train/${trainspaceId}` }).then(() => { - dispatch(removeTrainspaceData()); - }); + const inner = async () => { + const { trainspaceId } = await createTrainspace(trainspace).unwrap(); + await train({ + trainspaceData: trainspace, + trainspaceId: trainspaceId, + }).unwrap(); + router.push({ pathname: `/train/${trainspaceId}` }).then(() => { + dispatch(removeTrainspaceData()); }); + }; + inner(); } }, [trainspace]); if (!Component) return <>; diff --git a/frontend/src/features/Train/features/Image/redux/imageApi.ts b/frontend/src/features/Train/features/Image/redux/imageApi.ts index 7cd5cfb70..56b64f6bf 100644 --- a/frontend/src/features/Train/features/Image/redux/imageApi.ts +++ b/frontend/src/features/Train/features/Image/redux/imageApi.ts @@ -5,12 +5,13 @@ const imageApi = backendApi.injectEndpoints({ endpoints: (builder) => ({ trainImage: builder.mutation< { trainspaceId: string }, - TrainspaceData<"TRAIN"> + { trainspaceData: TrainspaceData<"TRAIN">; trainspaceId: string } >({ - query: (trainspaceData) => ({ + query: ({ trainspaceData, trainspaceId }) => ({ url: "/api/train/img-run", method: "POST", body: { + trainspace_id: trainspaceId, name: trainspaceData.name, data_source: trainspaceData.dataSource, dataset_data: { diff --git a/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx b/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx index 40aef3e08..ab6adc3a1 100644 --- a/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx +++ b/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx @@ -19,6 +19,7 @@ import { import { useTrainTabularMutation } from "../redux/tabularApi"; import { useRouter } from "next/router"; import { removeTrainspaceData } from "@/features/Train/redux/trainspaceSlice"; +import { useCreateTrainspaceMutation } from "@/features/Train/redux/trainspaceApi"; const TabularTrainspace = () => { const trainspace = useAppSelector( @@ -93,6 +94,7 @@ const TrainspaceStepInner = ({ const Component = STEP_SETTINGS[TRAINSPACE_SETTINGS.steps[step]].component; const [isStepModified, setIsStepModified] = useState(false); const [isButtonClicked, setIsButtonClicked] = useState(false); + const [createTrainspace] = useCreateTrainspaceMutation(); const [train] = useTrainTabularMutation(); const dispatch = useAppDispatch(); const router = useRouter(); @@ -112,16 +114,20 @@ const TrainspaceStepInner = ({ if (trainspace.step < TRAINSPACE_SETTINGS.steps.length) setStep(trainspace.step); else { - train(trainspace) - .unwrap() - .then(({ trainspaceId }) => { - router.push({ pathname: `/train/${trainspaceId}` }).then(() => { - dispatch(removeTrainspaceData()); - }); + const inner = async () => { + const { trainspaceId } = await createTrainspace(trainspace).unwrap(); + await train({ + trainspaceData: trainspace, + trainspaceId: trainspaceId, + }).unwrap(); + router.push({ pathname: `/train/${trainspaceId}` }).then(() => { + dispatch(removeTrainspaceData()); }); + }; + inner(); } }, [trainspace]); - + if (!Component) return null; return ( ({ trainTabular: builder.mutation< { trainspaceId: string }, - TrainspaceData<"TRAIN"> + { trainspaceData: TrainspaceData<"TRAIN">; trainspaceId: string } >({ - query: (trainspaceData) => ({ + query: ({ trainspaceData, trainspaceId }) => ({ url: "/api/training/tabular", method: "POST", body: { + trainspace_id: trainspaceId, name: trainspaceData.name, data_source: trainspaceData.dataSource, target: trainspaceData.parameterData.targetCol, diff --git a/frontend/src/features/Train/redux/trainspaceApi.ts b/frontend/src/features/Train/redux/trainspaceApi.ts index e9dad9959..2f6e6ca8c 100644 --- a/frontend/src/features/Train/redux/trainspaceApi.ts +++ b/frontend/src/features/Train/redux/trainspaceApi.ts @@ -2,9 +2,12 @@ import { backendApi } from "@/common/redux/backendApi"; import { DATA_SOURCE, DatasetData, + DetailedTrainResultsData, FileUploadData, } from "@/features/Train/types/trainTypes"; import { fetchBaseQuery } from "@reduxjs/toolkit/dist/query"; +import { TrainspaceData as TabularTrainspaceData } from "../features/Tabular/types/tabularTypes"; +import { TrainspaceData as ImageTrainspaceData } from "../features/Image/types/imageTypes"; const trainspaceApi = backendApi .enhanceEndpoints({ addTagTypes: ["UserDatasetFilesData"] }) @@ -90,6 +93,37 @@ const trainspaceApi = backendApi return response.data; }, }), + createTrainspace: builder.mutation< + { trainspaceId: string }, + TabularTrainspaceData<"TRAIN"> | ImageTrainspaceData<"TRAIN"> + >({ + query: (trainspaceData) => ({ + url: "/api/lambda/trainspace", + method: "POST", + body: { + name: trainspaceData.name, + data_source: trainspaceData.dataSource, + dataset_data: trainspaceData.datasetData, + review_data: trainspaceData.reviewData, + // TODO: add model_id + }, + }), + }), + getTrainspace: builder.query< + { + config: unknown; + detailedTrainResultsData: DetailedTrainResultsData | undefined; + }, + { trainspaceId: string; withResults: boolean } + >({ + query: ({ trainspaceId, withResults }) => ({ + url: `/api/lambda/trainspace/${trainspaceId}`, + method: "GET", + params: { + with_results: withResults, + }, + }), + }), }), overrideExisting: true, }); @@ -98,4 +132,6 @@ export const { useGetDatasetFilesDataQuery, useUploadDatasetFileMutation, useLazyGetColumnsFromDatasetQuery, + useCreateTrainspaceMutation, + useGetTrainspaceQuery, } = trainspaceApi; diff --git a/frontend/src/features/Train/types/trainTypes.ts b/frontend/src/features/Train/types/trainTypes.ts index 886c796db..d390388c8 100644 --- a/frontend/src/features/Train/types/trainTypes.ts +++ b/frontend/src/features/Train/types/trainTypes.ts @@ -1,5 +1,6 @@ import { DATA_SOURCE_ARR } from "../constants/trainConstants"; +// keep in sync with schemas.py export type DATA_SOURCE = typeof DATA_SOURCE_ARR[number]; export type TRAIN_STATUS = @@ -16,16 +17,61 @@ export interface BaseTrainspaceData { step: number; } +// basic information, used on dashboard export interface TrainResultsData { name: string; - trainspaceId: number; + trainspaceId: string; dataSource: DATA_SOURCE; status: TRAIN_STATUS; created: Date; - step: string; uid: string; } +export type CHART_TYPE = "LINE" | "AUC/ROC" | "CONFUSION_MATRIX" + +export type Chart = TimeSeriesChart | AucRocChart | ConfusionMatrixChart + +export interface TimeSeriesMetric { + x_name: string; + y_name: string; + + x_values: number[]; + y_values: number[]; +} + +export interface TimeSeriesChart { + name: string; + + time_series: TimeSeriesMetric[] + chart_type: "LINE" + graph_index: number; +} + +export interface AucRocChart { + name: string; + + values: [number[], number[], number][]; + + chart_type: "AUC/ROC" + graph_index: number; +} + +export interface ConfusionMatrixChart { + name: string; + + values: number[][]; + + chart_type: "CONFUSION_MATRIX" + graph_index: number; +} + +// more detailed information, used when viewing a run +export interface DetailedTrainResultsData { + basic_info: TrainResultsData + + all_metrics: Chart[] +} + export interface FileUploadData { name: string; lastModified: string; diff --git a/frontend/src/pages/train/[train_space_id].tsx b/frontend/src/pages/train/[train_space_id].tsx index e228cb8b4..f22d507bc 100644 --- a/frontend/src/pages/train/[train_space_id].tsx +++ b/frontend/src/pages/train/[train_space_id].tsx @@ -2,281 +2,83 @@ import Footer from "@/common/components/Footer"; import NavbarMain from "@/common/components/NavBarMain"; import { useAppSelector } from "@/common/redux/hooks"; import { isSignedIn } from "@/common/redux/userLogin"; +import { useGetTrainspaceQuery } from "@/features/Train/redux/trainspaceApi"; +import { DetailedTrainResultsData } from "@/features/Train/types/trainTypes"; import Container from "@mui/material/Container"; import Grid from "@mui/material/Grid"; import Paper from "@mui/material/Paper"; -import dynamic from "next/dynamic"; import { useRouter } from "next/router"; -import { Data, XAxisName, YAxisName } from "plotly.js"; import React, { useEffect } from "react"; -const Plot = dynamic(() => import("react-plotly.js"), { ssr: false }); +import { + mapMetricToLinePlot, + mapMetricToAucRocPlot, + mapMetricToConfusionMatrixPlot, +} from "./metrics_to_charts"; + +const mapTrainResultsDataToCharts = ( + detailedTrainResultsData: DetailedTrainResultsData +) => { + // sort by graph_index asc and ignore negative graph indices + const sortedData = detailedTrainResultsData.all_metrics + .filter((metric) => metric.graph_index >= 0) + .sort((a, b) => a.graph_index - b.graph_index); + const charts = []; + let i = 0; + while (i < sortedData.length) { + const metric = sortedData[i]; + if (metric.chart_type === "LINE") { + charts.push(mapMetricToLinePlot(metric)); + } else if (metric.chart_type === "AUC/ROC") { + charts.push(mapMetricToAucRocPlot(metric)); + } else if (metric.chart_type === "CONFUSION_MATRIX") { + charts.push(mapMetricToConfusionMatrixPlot(metric)); + } else { + throw Error("Undefined chart type received"); + } + i += 1; + } + + return charts; +}; const TrainSpace = () => { const { train_space_id } = useRouter().query; - const data = { - success: true, - message: "Dataset trained and results outputted successfully", - dl_results: [ - { - epoch: 1, - train_time: 0.029964923858642578, - train_loss: 1.1126993695894878, - test_loss: 1.1082043647766113, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 2, - train_time: 0.0221712589263916, - train_loss: 1.1002190907796223, - test_loss: 1.100191593170166, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 3, - train_time: 0.0680840015411377, - train_loss: 1.0896958708763123, - test_loss: 1.0933666229248047, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 4, - train_time: 0.007375478744506836, - train_loss: 1.0802951455116272, - test_loss: 1.0868618488311768, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 5, - train_time: 0.008754491806030273, - train_loss: 1.071365197499593, - test_loss: 1.080164909362793, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - ], - auxiliary_outputs: { - confusion_matrix: [ - [0, 0, 6], - [0, 0, 8], - [0, 0, 6], - ], - AUC_ROC_curve_data: [ - [ - [0.0, 0.0, 0.0, 0.07142857142857142, 0.07142857142857142, 1.0], - [ - 0.0, 0.16666666666666666, 0.8333333333333334, 0.8333333333333334, - 1.0, 1.0, - ], - 0.9880952380952381, - ], - [ - [ - 0.0, 0.08333333333333333, 0.5, 0.5, 0.5833333333333334, - 0.5833333333333334, 0.6666666666666666, 0.6666666666666666, 1.0, - ], - [0.0, 0.0, 0.0, 0.75, 0.75, 0.875, 0.875, 1.0, 1.0], - 0.46875, - ], - [ - [0.0, 0.0, 0.0, 0.07142857142857142, 0.07142857142857142, 1.0], - [ - 0.0, 0.16666666666666666, 0.8333333333333334, 0.8333333333333334, - 1.0, 1.0, - ], - 0.9880952380952381, - ], - ], - }, - status: 200, - }; + const { data, isLoading, refetch, error } = useGetTrainspaceQuery({ + trainspaceId: train_space_id, + withResults: true, + }); + const user = useAppSelector((state) => state.currentUser.user); const router = useRouter(); useEffect(() => { if (router.isReady && !user) { + console.log("redirect to login"); router.replace({ pathname: "/login" }); } }, [user, router.isReady]); - if (!isSignedIn(user)) { + + if (error) { + setTimeout(() => refetch(), 3000); + } + + if (!isSignedIn(user) || !data || isLoading) { return <>; } + + const charts = mapTrainResultsDataToCharts( + data.trainspace.detailedTrainResultsData + ); return (

{train_space_id}

- - - x.epoch), - y: data.dl_results.map((x) => x["train_acc"]), - type: "scatter", - mode: "markers", - marker: { color: "red", size: 10 }, - }, - { - name: "Test accuracy", - x: data.dl_results.map((x) => x.epoch), - y: data.dl_results.map((x) => x["val/test acc"]), - type: "scatter", - mode: "markers", - marker: { color: "blue", size: 10 }, - }, - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "Epoch Number" }, - yaxis: { title: "Accuracy" }, - title: "Train vs. Test Accuracy for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - x.epoch), - y: data.dl_results.map((x) => x.train_loss), - type: "scatter", - mode: "markers", - marker: { color: "red", size: 10 }, - }, - { - name: "Test loss", - x: data.dl_results.map((x) => x.epoch), - y: data.dl_results.map((x) => x.test_loss), - type: "scatter", - mode: "markers", - marker: { color: "blue", size: 10 }, - }, - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "Epoch Number" }, - yaxis: { title: "Loss" }, - title: "Train vs. Test Loss for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - ({ - name: `(AUC: ${x[2]})`, - x: x[0] as number[], - y: x[1] as number[], - type: "scatter", - })) as Data[]), - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "False Positive Rate" }, - yaxis: { title: "True Positive Rate" }, - title: "AUC/ROC Curves for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - - row.map((_, j) => ({ - xref: "x1" as XAxisName, - yref: "y1" as YAxisName, - x: j, - y: - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length, - text: data.auxiliary_outputs.confusion_matrix[ - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length - ][j].toString(), - font: { - color: - data.auxiliary_outputs.confusion_matrix[ - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length - ][j] > 0 - ? "white" - : "black", - }, - showarrow: false, - })) - ) - .flat(), - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - /> - - + {charts.map((chart) => ( + + {chart} + + ))}