diff --git a/src/Shared/Shared/ValidationConstants.cs b/src/Shared/Shared/ValidationConstants.cs index 460d1e0a4..a02d54344 100644 --- a/src/Shared/Shared/ValidationConstants.cs +++ b/src/Shared/Shared/ValidationConstants.cs @@ -53,6 +53,21 @@ public static class ValidationConstants /// public static readonly string Notifications = "notifications"; + /// + /// Key for the CPU. + /// + public static readonly string Cpu = "cpu"; + + /// + /// Key for the memory. + /// + public static readonly string Memory = "memory_gb"; + + /// + /// Key for the GPU. + /// + public static readonly string Gpu = "gpu"; + public enum ModeValues { QA, diff --git a/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs b/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs index 5e78c2f15..ead9718b8 100644 --- a/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs +++ b/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs @@ -88,6 +88,21 @@ internal static class Keys /// public static readonly string TaskPriorityClassName = "priority"; + /// + /// Key for CPU + /// + public static readonly string Cpu = "cpu"; + + /// + /// Key for memory allocation + /// + public static readonly string Memory = "memory_gb"; + + /// + /// Key for GPU + /// + public static readonly string Gpu = "number_gpu"; + /// /// Required arguments to run the Argo workflow. /// diff --git a/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs b/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs index acec57652..1e3d7418f 100644 --- a/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs +++ b/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs @@ -152,9 +152,9 @@ private void ValidateTasks(Workflow workflow, string firstTaskId) { // duplicate destinations var duplicates = destinations - .GroupBy(i => i) - .Where(g => g.Count() > 1) - .Select(g => g.Key); + .GroupBy(i => i) + .Where(g => g.Count() > 1) + .Select(g => g.Key); foreach (var dupe in duplicates) { @@ -348,6 +348,27 @@ private void ValidateArgoTask(TaskObject currentTask) break; } } + + new List { Cpu, Memory }.ForEach(key => + { + if ( + currentTask.Args.TryGetValue(key, out var val) && + !string.IsNullOrEmpty(val) && + double.TryParse(val, out double parsedVal) && + (parsedVal < 1 || Math.Truncate(parsedVal) != parsedVal)) + { + Errors.Add($"Task: '{currentTask.Id}' value '{val}' provided for argument '{key}' is not valid. The value needs to be a whole number greater than 0."); + } + }); + + if ( + currentTask.Args.TryGetValue(Gpu, out var gpu) && + !string.IsNullOrEmpty(gpu) && + double.TryParse(gpu, out double parsedGpu) && + (parsedGpu != 0 || parsedGpu != 1)) + { + Errors.Add($"Task: '{currentTask.Id}' value '{gpu}' provided for argument '{Gpu}' is not valid. The value needs to be 0 or 1."); + } } private void ValidateClinicalReviewTask(TaskObject[] tasks, TaskObject currentTask) diff --git a/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs b/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs index cb9138d53..4e788c15a 100644 --- a/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs +++ b/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs @@ -15,9 +15,7 @@ */ using System; -using System.Security.Cryptography.Xml; using System.Threading.Tasks; -using Amazon.Runtime.Internal.Transform; using Microsoft.Extensions.Logging; using Monai.Deploy.WorkflowManager.Common.Interfaces; using Monai.Deploy.WorkflowManager.Contracts.Models; @@ -206,7 +204,10 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect Type = "argo", Description = "Test Argo Task", Args = { - { "example", "value" } + { "example", "value" }, + { "cpu", "0.1" }, + { "memory_gb", "0.1" }, + { "gpu", "2" } }, TaskDestinations = new TaskDestination[] { @@ -347,7 +348,7 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect Assert.True(errors.Count > 0); - Assert.Equal(40, errors.Count); + Assert.Equal(43, errors.Count); var convergingTasksDestinations = "Converging Tasks Destinations in tasks: (test-clinical-review-2, example-task) on task: example-task"; Assert.Contains(convergingTasksDestinations, errors); @@ -385,6 +386,15 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect var missingArgoArgs = "Task: 'test-argo-task' workflow_template_name must be specified, this corresponds to an Argo template name."; Assert.Contains(missingArgoArgs, errors); + var invalidArgoArg1 = "Task: 'test-argo-task' value '0.1' provided for argument 'cpu' is not valid. The value needs to be a whole number greater than 0."; + Assert.Contains(invalidArgoArg1, errors); + + var invalidArgoArg2 = "Task: 'test-argo-task' value '0.1' provided for argument 'memory_gb' is not valid. The value needs to be a whole number greater than 0."; + Assert.Contains(invalidArgoArg2, errors); + + var invalidArgoArg3 = "Task: 'test-argo-task' value '2' provided for argument 'gpu' is not valid. The value needs to be 0 or 1."; + Assert.Contains(invalidArgoArg3, errors); + var incorrectClinicalReviewValueFormat = $"Invalid Value property on input artifact 'Invalid Value Format' in task: 'test-clinical-review'. Incorrect format."; Assert.Contains(incorrectClinicalReviewValueFormat, errors);