diff --git a/frontend/src/Home.js b/frontend/src/Home.js index 6c87faad3..8eb22b2d5 100644 --- a/frontend/src/Home.js +++ b/frontend/src/Home.js @@ -45,7 +45,7 @@ const Home = () => { // input responses const [customModelName, setCustomModelName] = useState( - `Model ${new Date().toLocaleString()}` + `Model ${new Date().toLocaleString()}`, ); const [fileURL, setFileURL] = useState(""); const [notificationPhoneNumber, setNotificationPhoneNumber] = useState(); @@ -55,11 +55,11 @@ const Home = () => { const [features, setFeatures] = useState([]); const [problemType, setProblemType] = useState(PROBLEM_TYPES[0]); const [criterion, setCriterion] = useState( - problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0] + problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0], ); const [optimizerName, setOptimizerName] = useState(OPTIMIZER_NAMES[0]); const [usingDefaultDataset, setUsingDefaultDataset] = useState( - DEFAULT_DATASETS[0] + DEFAULT_DATASETS[0], ); const [shuffle, setShuffle] = useState(BOOL_OPTIONS[1]); const [epochs, setEpochs] = useState(5); @@ -69,7 +69,7 @@ const Home = () => { uploadedColumns.map((e, i) => ({ label: e.name, value: i, - })) + })), ); const [activeColumns, setActiveColumns] = useState([]); const [beginnerMode, setBeginnerMode] = useState(true); @@ -151,7 +151,7 @@ const Home = () => { { queryText: "Criterion", options: CRITERIONS.filter((crit) => - crit.problem_type.includes(problemType.value) + crit.problem_type.includes(problemType.value), ), onChange: setCriterion, defaultValue: criterion, @@ -193,7 +193,7 @@ const Home = () => { useEffect(() => { setCriterion( - problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0] + problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0], ); setInputKey((e) => e + 1); }, [problemType]); @@ -356,7 +356,7 @@ const Home = () => { problemType={problemType} /> ), - [dlpBackendResponse, problemType] + [dlpBackendResponse, problemType], ); return ( diff --git a/frontend/src/components_old/ClassicalMLModel/ClassicalMLModel.js b/frontend/src/components_old/ClassicalMLModel/ClassicalMLModel.js index 4b3158e00..823b36d75 100644 --- a/frontend/src/components_old/ClassicalMLModel/ClassicalMLModel.js +++ b/frontend/src/components_old/ClassicalMLModel/ClassicalMLModel.js @@ -31,11 +31,11 @@ import DataTable from "react-data-table-component"; const ClassicalMLModel = () => { const [customModelName, setCustomModelName] = useState( - `Model ${new Date().toLocaleString()}` + `Model ${new Date().toLocaleString()}`, ); const [addedLayers, setAddedLayers] = useState([]); const [usingDefaultDataset, setUsingDefaultDataset] = useState( - DEFAULT_DATASETS[0] + DEFAULT_DATASETS[0], ); const [shuffle, setShuffle] = useState(BOOL_OPTIONS[1]); const [email, setEmail] = useState(""); @@ -56,7 +56,7 @@ const ClassicalMLModel = () => { uploadedColumns.map((e, i) => ({ label: e.name, value: i, - })) + })), ); const input_responses = { shuffle: shuffle?.value, @@ -138,7 +138,7 @@ const ClassicalMLModel = () => { choice="classicalml" /> ), - [dlpBackendResponse, PROBLEM_TYPES[0]] + [dlpBackendResponse, PROBLEM_TYPES[0]], ); const onClick = () => { diff --git a/frontend/src/components_old/ImageModels/ImageModels.js b/frontend/src/components_old/ImageModels/ImageModels.js index 8031061a4..1b5233bce 100644 --- a/frontend/src/components_old/ImageModels/ImageModels.js +++ b/frontend/src/components_old/ImageModels/ImageModels.js @@ -38,7 +38,7 @@ import { const ImageModels = () => { const [customModelName, setCustomModelName] = useState( - `Model ${new Date().toLocaleString()}` + `Model ${new Date().toLocaleString()}`, ); const [addedLayers, setAddedLayers] = useState(DEFAULT_IMG_LAYERS); const [trainTransforms, setTrainTransforms] = useState(DEFAULT_TRANSFORMS); @@ -121,7 +121,7 @@ const ImageModels = () => { problemType={PROBLEM_TYPES[0]} /> ), - [dlpBackendResponse, PROBLEM_TYPES[0]] + [dlpBackendResponse, PROBLEM_TYPES[0]], ); const onClick = () => { diff --git a/frontend/src/components_old/LearnMod/Exercise.js b/frontend/src/components_old/LearnMod/Exercise.js index 946f58f07..cf3239eaa 100644 --- a/frontend/src/components_old/LearnMod/Exercise.js +++ b/frontend/src/components_old/LearnMod/Exercise.js @@ -40,11 +40,11 @@ const Exercise = (props) => { const [features, setFeatures] = useState([]); const [problemType, setProblemType] = useState(PROBLEM_TYPES[0]); const [criterion, setCriterion] = useState( - problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0] + problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0], ); const [optimizerName, setOptimizerName] = useState(OPTIMIZER_NAMES[0]); const [usingDefaultDataset, setUsingDefaultDataset] = useState( - DEFAULT_DATASETS[0] + DEFAULT_DATASETS[0], ); const [shuffle, setShuffle] = useState(BOOL_OPTIONS[1]); const [epochs, setEpochs] = useState(5); @@ -54,7 +54,7 @@ const Exercise = (props) => { uploadedColumns.map((e, i) => ({ label: e.name, value: i, - })) + })), ); const [activeColumns, setActiveColumns] = useState([]); const [finalAccuracy, setFinalAccuracy] = useState(""); @@ -128,7 +128,7 @@ const Exercise = (props) => { { queryText: "Criterion", options: CRITERIONS.filter((crit) => - crit.problem_type.includes(problemType.value) + crit.problem_type.includes(problemType.value), ), onChange: setCriterion, defaultValue: criterion, @@ -182,7 +182,7 @@ const Exercise = (props) => { useEffect(() => { setCriterion( - problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0] + problemType === PROBLEM_TYPES[0] ? CRITERIONS[3] : CRITERIONS[0], ); setInputKey((e) => e + 1); }, [problemType]); @@ -220,7 +220,7 @@ const Exercise = (props) => { useEffect(() => { if (dlpBackendResponse != null) { setFinalAccuracy( - dlpBackendResponse["dl_results"][epochs - 1]["train_acc"] + dlpBackendResponse["dl_results"][epochs - 1]["train_acc"], ); } @@ -303,7 +303,7 @@ const Exercise = (props) => { /> ), - [dlpBackendResponse, problemType] + [dlpBackendResponse, problemType], ); return ( diff --git a/frontend/src/components_old/ObjectDetection/ObjectDetection.js b/frontend/src/components_old/ObjectDetection/ObjectDetection.js index f71907993..c69d07f27 100644 --- a/frontend/src/components_old/ObjectDetection/ObjectDetection.js +++ b/frontend/src/components_old/ObjectDetection/ObjectDetection.js @@ -28,7 +28,7 @@ import FilerobotImageEditor, { const ObjectDetection = () => { const [customModelName, setCustomModelName] = useState( - `Model ${new Date().toLocaleString()}` + `Model ${new Date().toLocaleString()}`, ); const [problemType, setProblemType] = useState(""); const [detectionType, setDetectionType] = useState(""); @@ -37,7 +37,7 @@ const ObjectDetection = () => { const [inputKey, setInputKey] = useState(0); const [uploadFile, setUploadFile] = useState(null); const [imageTransforms, setImageTransforms] = useState( - DEFAULT_DETECTION_TRANSFORMS + DEFAULT_DETECTION_TRANSFORMS, ); const input_responses = { @@ -71,7 +71,7 @@ const ObjectDetection = () => { choice="objectdetection" /> ), - [dlpBackendResponse, OBJECT_DETECTION_PROBLEM_TYPES[0]] + [dlpBackendResponse, OBJECT_DETECTION_PROBLEM_TYPES[0]], ); const onClick = () => { @@ -126,7 +126,7 @@ const ObjectDetection = () => { onSave={(editedImageObject) => { const file = dataURLtoFile( editedImageObject.imageBase64, - editedImageObject.fullName + editedImageObject.fullName, ); setUploadFile(file); }} diff --git a/frontend/src/features/Train/features/Image/components/ImageParametersStep.tsx b/frontend/src/features/Train/features/Image/components/ImageParametersStep.tsx index 58b51e9ad..94c50e7b3 100644 --- a/frontend/src/features/Train/features/Image/components/ImageParametersStep.tsx +++ b/frontend/src/features/Train/features/Image/components/ImageParametersStep.tsx @@ -1,6 +1,10 @@ import React, { useEffect, useMemo, useState } from "react"; import { useGetColumnsFromDatasetQuery } from "@/features/Train/redux/trainspaceApi"; import { useAppDispatch, useAppSelector } from "@/common/redux/hooks"; +import { styled } from "@mui/material/styles"; +import Tooltip, { tooltipClasses } from "@mui/material/Tooltip"; +import InfoIcon from "@mui/icons-material/Info"; + import { Autocomplete, Card, @@ -53,1172 +57,1221 @@ import { import ClientOnlyPortal from "@/common/components/ClientOnlyPortal"; import { updateImageTrainspaceData } from "../redux/imageActions"; -const ImageParametersStep = ({ - renderStepperButtons, - setIsModified, +const HtmlTooltip = styled( + ({ + className, + title, + children, + ...props }: { - renderStepperButtons: ( - submitTrainspace: (data: TrainspaceData<"PARAMETERS">) => void - ) => React.ReactNode; - setIsModified: React.Dispatch>; - }) => { - const trainspace = useAppSelector( - (state) => - state.trainspace.current as TrainspaceData<"PARAMETERS"> | undefined - ); - const dispatch = useAppDispatch(); - const { - handleSubmit, - formState: { errors, isDirty }, - control, - } = useForm({ - defaultValues: { - criterion: trainspace?.parameterData?.criterion ?? "CELOSS", - optimizerName: trainspace?.parameterData?.optimizerName ?? "SGD", - shuffle: trainspace?.parameterData?.shuffle ?? true, - epochs: trainspace?.parameterData?.epochs ?? 5, - batchSize: trainspace?.parameterData?.batchSize ?? 20, - trainTransforms: trainspace?.parameterData?.trainTransforms ?? [ - { - value: "GRAYSCALE", - parameters: [], - }, - { - value: "TO_TENSOR", - parameters: [], - }, - { - value: "RESIZE", - parameters: [32, 32], - }, - ], - testTransforms: trainspace?.parameterData?.testTransforms ?? [ - { - value: "GRAYSCALE", - parameters: [], - }, - { - value: "TO_TENSOR", - parameters: [], - }, - { - value: "RESIZE", - parameters: [32, 32], - }, - ], - layers: trainspace?.parameterData?.layers ?? [ - { - value: "CONV2D", - parameters: [1, 5, 3, 1, 1], - }, - { - value: "MAXPOOL2D", - parameters: [3, 1], - }, - { - value: "FLATTEN", - parameters: [1, -1], - }, - { - value: "LINEAR", - parameters: [500, 10], - }, - { - value: "SIGMOID", - parameters: [], - }, - ], - }, - }); - useEffect(() => { - setIsModified(isDirty); - }, [isDirty]); - if (!trainspace) return <>; - return ( - - ( - - {STEP_SETTINGS["PARAMETERS"].criterions.map((criterion) => ( - - {criterion.label} - - ))} - - )} - /> - ( - - {STEP_SETTINGS["PARAMETERS"].optimizers.map((optimizer) => ( - - {optimizer.label} - - ))} - - )} - /> - - ( - - } - label="Shuffle" - /> - )} - /> - - ( - - )} - /> + className?: string; + children: React.ReactElement; + title: React.ReactNode; + }) => ( + + {children} + + ) +)(({ theme }) => ({ + [`& .${tooltipClasses.tooltip}`]: { + backgroundColor: "rgba(255, 255, 255, 0.95)", + color: "rgba(0, 0, 0, 0.87)", + maxWidth: 220, + fontSize: theme.typography.pxToRem(12), + border: "none", + }, +})); + +const ImageParametersStep = ({ + renderStepperButtons, + setIsModified, +}: { + renderStepperButtons: ( + submitTrainspace: (data: TrainspaceData<"PARAMETERS">) => void + ) => React.ReactNode; + setIsModified: React.Dispatch>; +}) => { + const trainspace = useAppSelector( + (state) => + state.trainspace.current as TrainspaceData<"PARAMETERS"> | undefined + ); + const dispatch = useAppDispatch(); + const { + handleSubmit, + formState: { errors, isDirty }, + control, + } = useForm({ + defaultValues: { + criterion: trainspace?.parameterData?.criterion ?? "CELOSS", + optimizerName: trainspace?.parameterData?.optimizerName ?? "SGD", + shuffle: trainspace?.parameterData?.shuffle ?? true, + epochs: trainspace?.parameterData?.epochs ?? 5, + batchSize: trainspace?.parameterData?.batchSize ?? 20, + trainTransforms: trainspace?.parameterData?.trainTransforms ?? [ + { + value: "GRAYSCALE", + parameters: [], + }, + { + value: "TO_TENSOR", + parameters: [], + }, + { + value: "RESIZE", + parameters: [32, 32], + }, + ], + testTransforms: trainspace?.parameterData?.testTransforms ?? [ + { + value: "GRAYSCALE", + parameters: [], + }, + { + value: "TO_TENSOR", + parameters: [], + }, + { + value: "RESIZE", + parameters: [32, 32], + }, + ], + layers: trainspace?.parameterData?.layers ?? [ + { + value: "CONV2D", + parameters: [1, 5, 3, 1, 1], + }, + { + value: "MAXPOOL2D", + parameters: [3, 1], + }, + { + value: "FLATTEN", + parameters: [1, -1], + }, + { + value: "LINEAR", + parameters: [500, 10], + }, + { + value: "SIGMOID", + parameters: [], + }, + ], + }, + }); + useEffect(() => { + setIsModified(isDirty); + }, [isDirty]); + if (!trainspace) return <>; + return ( + + ( + + {STEP_SETTINGS["PARAMETERS"].criterions.map((criterion) => ( + + {criterion.label} + + ))} + + )} + /> + ( + + {STEP_SETTINGS["PARAMETERS"].optimizers.map((optimizer) => ( + + {optimizer.label} + + ))} + + )} + /> + ( - + } + label="Shuffle" /> )} /> - - - - {renderStepperButtons((trainspaceData) => { - handleSubmit((data) => { - dispatch( - updateImageTrainspaceData({ - current: { - ...trainspaceData, - parameterData: data, - reviewData: undefined, - }, - stepLabel: "PARAMETERS", - }) - ); - })(); - })} - - ); - }; - - const LayersDnd = ({ - control, - errors, - }: { - control: Control; - errors: FieldErrors; - }) => { - const { fields, move, insert, remove } = useFieldArray({ - control: control, - name: "layers", - }); - const genLayerInvIds = () => - Object.fromEntries( - STEP_SETTINGS.PARAMETERS.layerValues.map((layerValue) => [ - layerValue, - Math.floor(Math.random() * Date.now()), - ]) - ); - const [layerInvIds, setLayerInvIds] = useState<{ - [layerValue: string]: number; - }>(genLayerInvIds()); - const [dndActive, setDndActive] = useState(null); - const [invHovering, setInvHovering] = useState(false); - - const dndActiveItem = useMemo(() => { - if (!dndActive) return; - if (dndActive.data.current && "inventory" in dndActive.data.current) { - const value = dndActive.data.current.inventory - .value as (typeof STEP_SETTINGS.PARAMETERS.layerValues)[number]; - return { - id: layerInvIds[value], - value: value, - parameters: STEP_SETTINGS.PARAMETERS.layers[value].parameters.map( - () => "" - ) as ""[], - }; - } else if (dndActive.data.current && "sortable" in dndActive.data.current) { - return fields[dndActive.data.current.sortable.index]; - } - }, [dndActive]); - const sensors = useSensors( - useCustomPointerSensor(), - useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) + + ( + + )} + /> + ( + + )} + /> + + + + {renderStepperButtons((trainspaceData) => { + handleSubmit((data) => { + dispatch( + updateImageTrainspaceData({ + current: { + ...trainspaceData, + parameterData: data, + reviewData: undefined, + }, + stepLabel: "PARAMETERS", + }) + ); + })(); + })} + + ); +}; + +const LayersDnd = ({ + control, + errors, +}: { + control: Control; + errors: FieldErrors; +}) => { + const { fields, move, insert, remove } = useFieldArray({ + control: control, + name: "layers", + }); + const genLayerInvIds = () => + Object.fromEntries( + STEP_SETTINGS.PARAMETERS.layerValues.map((layerValue) => [ + layerValue, + Math.floor(Math.random() * Date.now()), + ]) ); - return ( - { - if (dndActive !== null) return; - setDndActive(active); - }} - onDragOver={({ over }) => { - if (!over || !over.data.current) { - setInvHovering(false); - return; - } - if (!invHovering) { - setInvHovering(true); - } - }} - onDragEnd={({ active, over }) => { - if (dndActive && dndActive.data.current && dndActiveItem) { - if ( - "inventory" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - insert(over.data.current.sortable.index, { - value: dndActiveItem.value, - parameters: dndActiveItem.parameters as number[], - }); - } else if ( - "sortable" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - move( - fields.findIndex((field) => field.id === active.id), - fields.findIndex((field) => field.id === over.id) - ); - } - } - setLayerInvIds(genLayerInvIds()); + const [layerInvIds, setLayerInvIds] = useState<{ + [layerValue: string]: number; + }>(genLayerInvIds()); + const [dndActive, setDndActive] = useState(null); + const [invHovering, setInvHovering] = useState(false); + + const dndActiveItem = useMemo(() => { + if (!dndActive) return; + if (dndActive.data.current && "inventory" in dndActive.data.current) { + const value = dndActive.data.current.inventory + .value as (typeof STEP_SETTINGS.PARAMETERS.layerValues)[number]; + return { + id: layerInvIds[value], + value: value, + parameters: STEP_SETTINGS.PARAMETERS.layers[value].parameters.map( + () => "" + ) as ""[], + }; + } else if (dndActive.data.current && "sortable" in dndActive.data.current) { + return fields[dndActive.data.current.sortable.index]; + } + }, [dndActive]); + const sensors = useSensors( + useCustomPointerSensor(), + useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) + ); + return ( + { + if (dndActive !== null) return; + setDndActive(active); + }} + onDragOver={({ over }) => { + if (!over || !over.data.current) { setInvHovering(false); - setDndActive(null); - }} - onDragCancel={({ active }) => { - if (active.data.current && "inventory" in active.data.current) { - const index = fields.findIndex((field) => field.id === active.id); - if (index !== -1) { - remove(fields.findIndex((field) => field.id === active.id)); - } + return; + } + if (!invHovering) { + setInvHovering(true); + } + }} + onDragEnd={({ active, over }) => { + if (dndActive && dndActive.data.current && dndActiveItem) { + if ( + "inventory" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + insert(over.data.current.sortable.index, { + value: dndActiveItem.value, + parameters: dndActiveItem.parameters as number[], + }); + } else if ( + "sortable" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + move( + fields.findIndex((field) => field.id === active.id), + fields.findIndex((field) => field.id === over.id) + ); } - setLayerInvIds(genLayerInvIds()); - setInvHovering(false); - setDndActive(null); - }} - > - - - - Layers - - - {STEP_SETTINGS.PARAMETERS.layerValues.map((value) => ( - - ))} - + } + setLayerInvIds(genLayerInvIds()); + setInvHovering(false); + setDndActive(null); + }} + onDragCancel={({ active }) => { + if (active.data.current && "inventory" in active.data.current) { + const index = fields.findIndex((field) => field.id === active.id); + if (index !== -1) { + remove(fields.findIndex((field) => field.id === active.id)); + } + } + setLayerInvIds(genLayerInvIds()); + setInvHovering(false); + setDndActive(null); + }} + > + + + + Layers + + + {STEP_SETTINGS.PARAMETERS.layerValues.map((value) => ( + + ))} - - - - - + + + + + + {fields.length > 0 ? ( + [ dndActiveItem && dndActive?.data.current && - "inventory" in dndActive.data.current - ? [dndActiveItem, ...fields] - : fields - } - strategy={verticalListSortingStrategy} - > - {fields.length > 0 ? ( - [ - dndActiveItem && - dndActive?.data.current && - "inventory" in dndActive.data.current && - invHovering ? ( - - ) : null, - ...fields.map((field, index) => ( - remove(index), - }} - /> - )), - ] - ) : ( - This is Unimplemented - )} - - - - - - - {dndActiveItem ? ( - dndActive?.data.current && "sortable" in dndActive.data.current ? ( - - ) : ( - - ) - ) : null} - - - - ); + "inventory" in dndActive.data.current && + invHovering ? ( + + ) : null, + ...fields.map((field, index) => ( + remove(index), + }} + /> + )), + ] + ) : ( + This is Unimplemented + )} + + + + + + + {dndActiveItem ? ( + dndActive?.data.current && "sortable" in dndActive.data.current ? ( + + ) : ( + + ) + ) : null} + + + + ); +}; + +const LayerComponent = ({ + id, + data, + formProps, +}: { + id?: string | number; + data: ParameterData["layers"][number]; + formProps?: { + index: number; + control: Control; + errors: FieldErrors; + remove?: () => void; }; - - const LayerComponent = ({ - id, - data, - formProps, - }: { - id?: string | number; - data: ParameterData["layers"][number]; - formProps?: { - index: number; - control: Control; - errors: FieldErrors; - remove?: () => void; - }; - }) => { - const { - attributes, - listeners, - setNodeRef, - isDragging, - transform, - transition, - } = id - ? useSortable({ id }) - : { - attributes: undefined, - listeners: undefined, - setNodeRef: undefined, - isDragging: undefined, - transform: undefined, - transition: undefined, - }; - const style = transform - ? { - opacity: isDragging ? 0.4 : undefined, - transform: CSS.Transform.toString(transform), - transition: transition, - } - : undefined; - return ( -
- { + const { + attributes, + listeners, + setNodeRef, + isDragging, + transform, + transition, + } = id + ? useSortable({ id }) + : { + attributes: undefined, + listeners: undefined, + setNodeRef: undefined, + isDragging: undefined, + transform: undefined, + transition: undefined, + }; + const style = transform + ? { + opacity: isDragging ? 0.4 : undefined, + transform: CSS.Transform.toString(transform), + transition: transition, + } + : undefined; + return ( +
+ + - + + {STEP_SETTINGS.PARAMETERS.layers[data.value].label} + + {STEP_SETTINGS.PARAMETERS.layers[data.value].description} + + } > - - {STEP_SETTINGS.PARAMETERS.layers[data.value].label} - - Info + + + {STEP_SETTINGS.PARAMETERS.layers[data.value].label} + + + } > - } - > - {STEP_SETTINGS.PARAMETERS.layers[data.value].parameters.map( - (parameter, index) => ( -
- {formProps ? ( - ( - - )} - /> - ) : ( - - )} -
- ) - )} -
-
- - - -
+ {STEP_SETTINGS.PARAMETERS.layers[data.value].parameters.map( + (parameter, index) => ( +
+ {formProps ? ( + ( + + )} + /> + ) : ( + + )} +
+ ) + )}
+
+ + + +
-
-
- ); - }; - - const LayerInventoryComponent = ({ - id, - value, - }: { - id: number; - value: (typeof STEP_SETTINGS.PARAMETERS.layerValues)[number]; - }) => { - const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ - id: id, - data: { - inventory: { - value, - }, + +
+
+ ); +}; + +const LayerInventoryComponent = ({ + id, + value, +}: { + id: number; + value: (typeof STEP_SETTINGS.PARAMETERS.layerValues)[number]; +}) => { + const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ + id: id, + data: { + inventory: { + value, }, - }); - - const style = { - opacity: isDragging ? 0.4 : undefined, - }; - return ( -
- - {STEP_SETTINGS.PARAMETERS.layers[value].label} - -
- ); + }, + }); + + const style = { + opacity: isDragging ? 0.4 : undefined, }; + return ( +
+ + {STEP_SETTINGS.PARAMETERS.layers[value].label} + +
+ ); +}; - const TrainTransformsDnd = ({ - control, - errors, - }: { - control: Control; - errors: FieldErrors; - }) => { - const { fields, move, insert, remove } = useFieldArray({ - control: control, - name: "trainTransforms", - }); - const genTransformInvIds = () => - Object.fromEntries( - STEP_SETTINGS.PARAMETERS.transformValues.map((transformValue) => [ - transformValue, - Math.floor(Math.random() * Date.now()), - ]) - ); - const [transformInvIds, setTransformInvIds] = useState<{ - [transformValue: string]: number; - }>(genTransformInvIds()); - const [dndActive, setDndActive] = useState(null); - const [invHovering, setInvHovering] = useState(false); - - const dndActiveItem = useMemo(() => { - if (!dndActive) return; - if (dndActive.data.current && "inventory" in dndActive.data.current) { - const value = dndActive.data.current.inventory - .value as (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; - return { - id: transformInvIds[value], - value: value, - parameters: STEP_SETTINGS.PARAMETERS.transforms[value].parameters.map( - () => "" - ) as ""[], - }; - } else if (dndActive.data.current && "sortable" in dndActive.data.current) { - return fields[dndActive.data.current.sortable.index]; - } - }, [dndActive]); - const sensors = useSensors( - useCustomPointerSensor(), - useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) +const TrainTransformsDnd = ({ + control, + errors, +}: { + control: Control; + errors: FieldErrors; +}) => { + const { fields, move, insert, remove } = useFieldArray({ + control: control, + name: "trainTransforms", + }); + const genTransformInvIds = () => + Object.fromEntries( + STEP_SETTINGS.PARAMETERS.transformValues.map((transformValue) => [ + transformValue, + Math.floor(Math.random() * Date.now()), + ]) ); - return ( - { - if (dndActive !== null) return; - setDndActive(active); - }} - onDragOver={({ over }) => { - if (!over || !over.data.current) { - setInvHovering(false); - return; - } - if (!invHovering) { - setInvHovering(true); - } - }} - onDragEnd={({ active, over }) => { - if (dndActive && dndActive.data.current && dndActiveItem) { - if ( - "inventory" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - insert(over.data.current.sortable.index, { - value: dndActiveItem.value, - parameters: dndActiveItem.parameters as number[], - }); - } else if ( - "sortable" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - move( - fields.findIndex((field) => field.id === active.id), - fields.findIndex((field) => field.id === over.id) - ); - } - } - setTransformInvIds(genTransformInvIds()); + const [transformInvIds, setTransformInvIds] = useState<{ + [transformValue: string]: number; + }>(genTransformInvIds()); + const [dndActive, setDndActive] = useState(null); + const [invHovering, setInvHovering] = useState(false); + + const dndActiveItem = useMemo(() => { + if (!dndActive) return; + if (dndActive.data.current && "inventory" in dndActive.data.current) { + const value = dndActive.data.current.inventory + .value as (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; + return { + id: transformInvIds[value], + value: value, + parameters: STEP_SETTINGS.PARAMETERS.transforms[value].parameters.map( + () => "" + ) as ""[], + }; + } else if (dndActive.data.current && "sortable" in dndActive.data.current) { + return fields[dndActive.data.current.sortable.index]; + } + }, [dndActive]); + const sensors = useSensors( + useCustomPointerSensor(), + useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) + ); + return ( + { + if (dndActive !== null) return; + setDndActive(active); + }} + onDragOver={({ over }) => { + if (!over || !over.data.current) { setInvHovering(false); - setDndActive(null); - }} - onDragCancel={({ active }) => { - if (active.data.current && "inventory" in active.data.current) { - const index = fields.findIndex((field) => field.id === active.id); - if (index !== -1) { - remove(fields.findIndex((field) => field.id === active.id)); - } + return; + } + if (!invHovering) { + setInvHovering(true); + } + }} + onDragEnd={({ active, over }) => { + if (dndActive && dndActive.data.current && dndActiveItem) { + if ( + "inventory" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + insert(over.data.current.sortable.index, { + value: dndActiveItem.value, + parameters: dndActiveItem.parameters as number[], + }); + } else if ( + "sortable" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + move( + fields.findIndex((field) => field.id === active.id), + fields.findIndex((field) => field.id === over.id) + ); } - setTransformInvIds(genTransformInvIds()); - setInvHovering(false); - setDndActive(null); - }} - > - - - + } + setTransformInvIds(genTransformInvIds()); + setInvHovering(false); + setDndActive(null); + }} + onDragCancel={({ active }) => { + if (active.data.current && "inventory" in active.data.current) { + const index = fields.findIndex((field) => field.id === active.id); + if (index !== -1) { + remove(fields.findIndex((field) => field.id === active.id)); + } + } + setTransformInvIds(genTransformInvIds()); + setInvHovering(false); + setDndActive(null); + }} + > + + + Train Transforms - - + + {STEP_SETTINGS.PARAMETERS.transformValues.map((value) => ( ))} - - - - - + + + + + {fields.length > 0 ? ( + [ dndActiveItem && dndActive?.data.current && - "inventory" in dndActive.data.current - ? [dndActiveItem, ...fields] - : fields - } - strategy={verticalListSortingStrategy} - > - {fields.length > 0 ? ( - [ - dndActiveItem && - dndActive?.data.current && - "inventory" in dndActive.data.current && - invHovering ? ( - - ) : null, - ...fields.map((field, index) => ( - remove(index), - }} - /> - )), - ] - ) : ( - This is Unimplemented - )} - - - - - - - {dndActiveItem ? ( - dndActive?.data.current && "sortable" in dndActive.data.current ? ( - - ) : ( - - ) - ) : null} - - - - ); + "inventory" in dndActive.data.current && + invHovering ? ( + + ) : null, + ...fields.map((field, index) => ( + remove(index), + }} + /> + )), + ] + ) : ( + This is Unimplemented + )} + + + + + + + {dndActiveItem ? ( + dndActive?.data.current && "sortable" in dndActive.data.current ? ( + + ) : ( + + ) + ) : null} + + + + ); +}; + +const TrainTransformComponent = ({ + id, + data, + formProps, +}: { + id?: string | number; + data: ParameterData["trainTransforms"][number]; + formProps?: { + index: number; + control: Control; + errors: FieldErrors; + remove?: () => void; }; - - const TrainTransformComponent = ({ - id, - data, - formProps, - }: { - id?: string | number; - data: ParameterData["trainTransforms"][number]; - formProps?: { - index: number; - control: Control; - errors: FieldErrors; - remove?: () => void; - }; - }) => { - const { - attributes, - listeners, - setNodeRef, - isDragging, - transform, - transition, - } = id - ? useSortable({ id }) - : { - attributes: undefined, - listeners: undefined, - setNodeRef: undefined, - isDragging: undefined, - transform: undefined, - transition: undefined, - }; - const style = transform - ? { - opacity: isDragging ? 0.4 : undefined, - transform: CSS.Transform.toString(transform), - transition: transition, - } - : undefined; - return ( -
- { + const { + attributes, + listeners, + setNodeRef, + isDragging, + transform, + transition, + } = id + ? useSortable({ id }) + : { + attributes: undefined, + listeners: undefined, + setNodeRef: undefined, + isDragging: undefined, + transform: undefined, + transition: undefined, + }; + const style = transform + ? { + opacity: isDragging ? 0.4 : undefined, + transform: CSS.Transform.toString(transform), + transition: transition, + } + : undefined; + return ( +
+ + - - - {STEP_SETTINGS.PARAMETERS.transforms[data.value].label} - - - } - > - {STEP_SETTINGS.PARAMETERS.transforms[data.value].parameters.map( - (parameter, index) => ( -
- {formProps ? ( - ( - - )} - /> - ) : ( - - )} -
- ) - )} -
-
- - - -
+ + {STEP_SETTINGS.PARAMETERS.transforms[data.value].label} + + + } + > + {STEP_SETTINGS.PARAMETERS.transforms[data.value].parameters.map( + (parameter, index) => ( +
+ {formProps ? ( + ( + + )} + /> + ) : ( + + )} +
+ ) + )}
+
+ + + +
-
-
- ); - }; - - const TrainTransformInventoryComponent = ({ - id, - value, - }: { - id: number; - value: (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; - }) => { - const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ - id: id, - data: { - inventory: { - value, - }, + +
+
+ ); +}; + +const TrainTransformInventoryComponent = ({ + id, + value, +}: { + id: number; + value: (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; +}) => { + const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ + id: id, + data: { + inventory: { + value, }, - }); - - const style = { - opacity: isDragging ? 0.4 : undefined, - }; - return ( -
- - {STEP_SETTINGS.PARAMETERS.transforms[value].label} - -
- ); + }, + }); + + const style = { + opacity: isDragging ? 0.4 : undefined, }; + return ( +
+ + {STEP_SETTINGS.PARAMETERS.transforms[value].label} + +
+ ); +}; - const TestTransformsDnd = ({ - control, - errors, - }: { - control: Control; - errors: FieldErrors; - }) => { - const { fields, move, insert, remove } = useFieldArray({ - control: control, - name: "testTransforms", - }); - const genTransformInvIds = () => - Object.fromEntries( - STEP_SETTINGS.PARAMETERS.transformValues.map((transformValue) => [ - transformValue, - Math.floor(Math.random() * Date.now()), - ]) - ); - const [transformInvIds, setTransformInvIds] = useState<{ - [transformValue: string]: number; - }>(genTransformInvIds()); - const [dndActive, setDndActive] = useState(null); - const [invHovering, setInvHovering] = useState(false); - - const dndActiveItem = useMemo(() => { - if (!dndActive) return; - if (dndActive.data.current && "inventory" in dndActive.data.current) { - const value = dndActive.data.current.inventory - .value as (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; - return { - id: transformInvIds[value], - value: value, - parameters: STEP_SETTINGS.PARAMETERS.transforms[value].parameters.map( - () => "" - ) as ""[], - }; - } else if (dndActive.data.current && "sortable" in dndActive.data.current) { - return fields[dndActive.data.current.sortable.index]; - } - }, [dndActive]); - const sensors = useSensors( - useCustomPointerSensor(), - useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) +const TestTransformsDnd = ({ + control, + errors, +}: { + control: Control; + errors: FieldErrors; +}) => { + const { fields, move, insert, remove } = useFieldArray({ + control: control, + name: "testTransforms", + }); + const genTransformInvIds = () => + Object.fromEntries( + STEP_SETTINGS.PARAMETERS.transformValues.map((transformValue) => [ + transformValue, + Math.floor(Math.random() * Date.now()), + ]) ); - return ( - { - if (dndActive !== null) return; - setDndActive(active); - }} - onDragOver={({ over }) => { - if (!over || !over.data.current) { - setInvHovering(false); - return; - } - if (!invHovering) { - setInvHovering(true); - } - }} - onDragEnd={({ active, over }) => { - if (dndActive && dndActive.data.current && dndActiveItem) { - if ( - "inventory" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - insert(over.data.current.sortable.index, { - value: dndActiveItem.value, - parameters: dndActiveItem.parameters as number[], - }); - } else if ( - "sortable" in dndActive.data.current && - over?.data.current && - "sortable" in over.data.current - ) { - move( - fields.findIndex((field) => field.id === active.id), - fields.findIndex((field) => field.id === over.id) - ); - } - } - setTransformInvIds(genTransformInvIds()); + const [transformInvIds, setTransformInvIds] = useState<{ + [transformValue: string]: number; + }>(genTransformInvIds()); + const [dndActive, setDndActive] = useState(null); + const [invHovering, setInvHovering] = useState(false); + + const dndActiveItem = useMemo(() => { + if (!dndActive) return; + if (dndActive.data.current && "inventory" in dndActive.data.current) { + const value = dndActive.data.current.inventory + .value as (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; + return { + id: transformInvIds[value], + value: value, + parameters: STEP_SETTINGS.PARAMETERS.transforms[value].parameters.map( + () => "" + ) as ""[], + }; + } else if (dndActive.data.current && "sortable" in dndActive.data.current) { + return fields[dndActive.data.current.sortable.index]; + } + }, [dndActive]); + const sensors = useSensors( + useCustomPointerSensor(), + useCustomKeyboardSensor({ coordinateGetter: sortableKeyboardCoordinates }) + ); + return ( + { + if (dndActive !== null) return; + setDndActive(active); + }} + onDragOver={({ over }) => { + if (!over || !over.data.current) { setInvHovering(false); - setDndActive(null); - }} - onDragCancel={({ active }) => { - if (active.data.current && "inventory" in active.data.current) { - const index = fields.findIndex((field) => field.id === active.id); - if (index !== -1) { - remove(fields.findIndex((field) => field.id === active.id)); - } + return; + } + if (!invHovering) { + setInvHovering(true); + } + }} + onDragEnd={({ active, over }) => { + if (dndActive && dndActive.data.current && dndActiveItem) { + if ( + "inventory" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + insert(over.data.current.sortable.index, { + value: dndActiveItem.value, + parameters: dndActiveItem.parameters as number[], + }); + } else if ( + "sortable" in dndActive.data.current && + over?.data.current && + "sortable" in over.data.current + ) { + move( + fields.findIndex((field) => field.id === active.id), + fields.findIndex((field) => field.id === over.id) + ); } - setTransformInvIds(genTransformInvIds()); - setInvHovering(false); - setDndActive(null); - }} - > - - - + } + setTransformInvIds(genTransformInvIds()); + setInvHovering(false); + setDndActive(null); + }} + onDragCancel={({ active }) => { + if (active.data.current && "inventory" in active.data.current) { + const index = fields.findIndex((field) => field.id === active.id); + if (index !== -1) { + remove(fields.findIndex((field) => field.id === active.id)); + } + } + setTransformInvIds(genTransformInvIds()); + setInvHovering(false); + setDndActive(null); + }} + > + + + Test Transforms - - + + {STEP_SETTINGS.PARAMETERS.transformValues.map((value) => ( ))} - - - - - + + + + + {fields.length > 0 ? ( + [ dndActiveItem && dndActive?.data.current && - "inventory" in dndActive.data.current - ? [dndActiveItem, ...fields] - : fields - } - strategy={verticalListSortingStrategy} - > - {fields.length > 0 ? ( - [ - dndActiveItem && - dndActive?.data.current && - "inventory" in dndActive.data.current && - invHovering ? ( - - ) : null, - ...fields.map((field, index) => ( - remove(index), - }} - /> - )), - ] - ) : ( - This is Unimplemented - )} - - - - - - - {dndActiveItem ? ( - dndActive?.data.current && "sortable" in dndActive.data.current ? ( - - ) : ( - - ) - ) : null} - - - - ); + "inventory" in dndActive.data.current && + invHovering ? ( + + ) : null, + ...fields.map((field, index) => ( + remove(index), + }} + /> + )), + ] + ) : ( + This is Unimplemented + )} + + + + + + + {dndActiveItem ? ( + dndActive?.data.current && "sortable" in dndActive.data.current ? ( + + ) : ( + + ) + ) : null} + + + + ); +}; + +const TestTransformComponent = ({ + id, + data, + formProps, +}: { + id?: string | number; + data: ParameterData["testTransforms"][number]; + formProps?: { + index: number; + control: Control; + errors: FieldErrors; + remove?: () => void; }; - - const TestTransformComponent = ({ - id, - data, - formProps, - }: { - id?: string | number; - data: ParameterData["testTransforms"][number]; - formProps?: { - index: number; - control: Control; - errors: FieldErrors; - remove?: () => void; - }; - }) => { - const { - attributes, - listeners, - setNodeRef, - isDragging, - transform, - transition, - } = id - ? useSortable({ id }) - : { - attributes: undefined, - listeners: undefined, - setNodeRef: undefined, - isDragging: undefined, - transform: undefined, - transition: undefined, - }; - const style = transform - ? { - opacity: isDragging ? 0.4 : undefined, - transform: CSS.Transform.toString(transform), - transition: transition, - } - : undefined; - return ( -
- { + const { + attributes, + listeners, + setNodeRef, + isDragging, + transform, + transition, + } = id + ? useSortable({ id }) + : { + attributes: undefined, + listeners: undefined, + setNodeRef: undefined, + isDragging: undefined, + transform: undefined, + transition: undefined, + }; + const style = transform + ? { + opacity: isDragging ? 0.4 : undefined, + transform: CSS.Transform.toString(transform), + transition: transition, + } + : undefined; + return ( +
+ + - - - {STEP_SETTINGS.PARAMETERS.transforms[data.value].label} - - - } - > - {STEP_SETTINGS.PARAMETERS.transforms[data.value].parameters.map( - (parameter, index) => ( -
- {formProps ? ( - ( - - )} - /> - ) : ( - - )} -
- ) - )} -
-
- - - -
+ + {STEP_SETTINGS.PARAMETERS.transforms[data.value].label} + + + } + > + {STEP_SETTINGS.PARAMETERS.transforms[data.value].parameters.map( + (parameter, index) => ( +
+ {formProps ? ( + ( + + )} + /> + ) : ( + + )} +
+ ) + )}
+
+ + + +
-
-
- ); - }; - - const TestTransformInventoryComponent = ({ - id, - value, - }: { - id: number; - value: (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; - }) => { - const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ - id: id, - data: { - inventory: { - value, - }, + +
+
+ ); +}; + +const TestTransformInventoryComponent = ({ + id, + value, +}: { + id: number; + value: (typeof STEP_SETTINGS.PARAMETERS.transformValues)[number]; +}) => { + const { attributes, listeners, isDragging, setNodeRef } = useDraggable({ + id: id, + data: { + inventory: { + value, }, - }); - - const style = { - opacity: isDragging ? 0.4 : undefined, - }; - return ( -
- - {STEP_SETTINGS.PARAMETERS.transforms[value].label} - -
- ); + }, + }); + + const style = { + opacity: isDragging ? 0.4 : undefined, }; + return ( +
+ + {STEP_SETTINGS.PARAMETERS.transforms[value].label} + +
+ ); +}; export default ImageParametersStep; diff --git a/frontend/src/features/Train/features/Image/constants/imageConstants.ts b/frontend/src/features/Train/features/Image/constants/imageConstants.ts index 2676bf504..b4aed9d9c 100644 --- a/frontend/src/features/Train/features/Image/constants/imageConstants.ts +++ b/frontend/src/features/Train/features/Image/constants/imageConstants.ts @@ -82,6 +82,7 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "The `CONV2d` function applies a filter to input data in a sliding window manner. By performing element-wise multiplication and sum of the overlapping portions, it captures local spatial patterns and features, enabling applications such as image recognition, object detection, and semantic segmentation in computer vision tasks.", }, MAXPOOL2D: { label: "MaxPool2d", @@ -102,6 +103,7 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "MaxPool2d function reduces the size of input data by dividing it into non-overlapping rectangular regions and selecting the maximum value from each region. This downsampling operation preserves important features while decreasing spatial dimensions, making it beneficial for tasks like image classification and extracting spatial characteristics." }, FLATTEN: { label: "Flatten", @@ -122,6 +124,7 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "The flatten operation takes a two-dimensional image representation and transforms it into a one-dimensional vector. This process unravels the image structure by concatenating the rows or columns of pixels, creating a linear sequence of values. By doing so, it allows the network to process the image as a simple list of numbers, facilitating tasks such as image classification or object detection in neural networks." }, LINEAR: { label: "Linear", @@ -142,11 +145,13 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "A linear layer in an image dataset takes the flattened image representation and applies a learned linear transformation to map the input values to a new set of values, allowing the network to learn complex relationships for tasks like image classification or object recognition." }, SIGMOID: { label: "Sigmoid", objectName: "nn.Sigmoid", parameters: [], + description: "The sigmoid function takes the input values, typically the output of a linear layer, and applies a mathematical function that compresses them into a range between 0 and 1. This transformation is useful for interpreting the output as probabilities, where values closer to 1 indicate higher confidence in the presence of a particular feature or class in the image." }, }, transformValues: ["GAUSSIAN_BLUR", "GRAYSCALE", "NORMALIZE", "RANDOM_HORIZONTAL_FLIP", "RANDOM_VERTICAL_FLIP", "RESIZE", "TO_TENSOR"], diff --git a/frontend/src/features/Train/features/Tabular/components/TabularParametersStep.tsx b/frontend/src/features/Train/features/Tabular/components/TabularParametersStep.tsx index 2bb3e073b..a98d31d5c 100644 --- a/frontend/src/features/Train/features/Tabular/components/TabularParametersStep.tsx +++ b/frontend/src/features/Train/features/Tabular/components/TabularParametersStep.tsx @@ -21,6 +21,9 @@ import { TextField, Typography, } from "@mui/material"; +import { styled } from "@mui/material/styles"; +import Tooltip, { tooltipClasses } from "@mui/material/Tooltip"; +import InfoIcon from "@mui/icons-material/Info"; import { Control, Controller, @@ -53,6 +56,31 @@ import { import ClientOnlyPortal from "@/common/components/ClientOnlyPortal"; import { updateTabularTrainspaceData } from "../redux/tabularActions"; +const HtmlTooltip = styled( + ({ + className, + title, + children, + ...props + }: { + className?: string; + children: React.ReactElement; + title: React.ReactNode; + }) => ( + + {children} + + ) +)(({ theme }) => ({ + [`& .${tooltipClasses.tooltip}`]: { + backgroundColor: "rgba(255, 255, 255, 0.95)", + color: "rgba(0, 0, 0, 0.87)", + maxWidth: 220, + fontSize: theme.typography.pxToRem(12), + border: "none", + }, +})); + const TabularParametersStep = ({ renderStepperButtons, setIsModified, @@ -77,7 +105,7 @@ const TabularParametersStep = ({ handleSubmit, formState: { errors, isDirty }, control, - watch + watch, } = useForm({ defaultValues: { targetCol: @@ -141,7 +169,7 @@ const TabularParametersStep = ({ error={errors.targetCol ? true : false} /> )} - options={data.filter(col => !features.includes(col))} + options={data.filter((col) => !features.includes(col))} /> )} /> @@ -167,8 +195,7 @@ const TabularParametersStep = ({ error={errors.features ? true : false} /> )} - - options={data.filter(col => col!== targetCol)} + options={data.filter((col) => col !== targetCol)} /> )} /> @@ -535,6 +562,19 @@ const LayerComponent = ({ alignItems={"center"} spacing={3} > + + + {STEP_SETTINGS.PARAMETERS.layers[data.value].label} + + {STEP_SETTINGS.PARAMETERS.layers[data.value].description} + + } + > + Info + + {STEP_SETTINGS.PARAMETERS.layers[data.value].label} diff --git a/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts b/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts index db5c4622e..21d944e2e 100644 --- a/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts +++ b/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts @@ -61,6 +61,7 @@ export const STEP_SETTINGS = { { label: "Adam Optimization", value: "Adam" }, ], layerValues: ["LINEAR", "RELU", "TANH", "SOFTMAX", "SIGMOID", "LOGSOFTMAX"], + layers: { LINEAR: { label: "Linear", @@ -81,16 +82,19 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "A linear layer performs a mathematical operation called linear transformation on a set of input values. It applies a combination of scaling and shifting to the input values, resulting in a new set of transformed values as output." }, RELU: { label: "ReLU", objectName: "nn.ReLU", parameters: [], + description: "ReLU, short for Rectified Linear Unit, is an activation function that acts like a filter that selectively allows positive numbers to pass through unchanged, while converting negative numbers to zero." }, TANH: { label: "Tanh", objectName: "nn.Tanh", parameters: [], + description: "The tanh function maps input numbers to a range between -1 and 1, emphasizing values close to zero while diminishing the impact of extremely large or small numbers, making it useful for capturing complex patterns in data." }, SOFTMAX: { label: "Softmax", @@ -104,11 +108,13 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "The softmax function takes a set of numbers as input and converts them into a probability distribution, assigning higher probabilities to larger numbers and lower probabilities to smaller numbers, making it useful for multi-class classification tasks." }, SIGMOID: { label: "Sigmoid", objectName: "nn.Sigmoid", parameters: [], + description: "The sigmoid function takes any input number and squeezes it to a range between 0 and 1, effectively converting it into a probability-like value, often used for binary classification tasks and as an activation function in neural networks." }, LOGSOFTMAX: { label: "LogSoftmax", @@ -122,6 +128,7 @@ export const STEP_SETTINGS = { type: "number", }, ], + description: "The logsoftmax function converts a set of numbers into a probability distribution using the softmax function, and then applies a logarithm to the resulting probabilities. It is commonly used for multi-class classification tasks as an activation function and to calculate the logarithmic loss during neural network training." } }, },