diff --git a/src/Shared/Shared/ValidationConstants.cs b/src/Shared/Shared/ValidationConstants.cs
index 460d1e0a4..a02d54344 100644
--- a/src/Shared/Shared/ValidationConstants.cs
+++ b/src/Shared/Shared/ValidationConstants.cs
@@ -53,6 +53,21 @@ public static class ValidationConstants
///
public static readonly string Notifications = "notifications";
+ ///
+ /// Key for the CPU.
+ ///
+ public static readonly string Cpu = "cpu";
+
+ ///
+ /// Key for the memory.
+ ///
+ public static readonly string Memory = "memory_gb";
+
+ ///
+ /// Key for the GPU.
+ ///
+ public static readonly string Gpu = "gpu";
+
public enum ModeValues
{
QA,
diff --git a/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs b/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs
index 5e78c2f15..ead9718b8 100644
--- a/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs
+++ b/src/TaskManager/Plug-ins/Argo/StaticValues/Keys.cs
@@ -88,6 +88,21 @@ internal static class Keys
///
public static readonly string TaskPriorityClassName = "priority";
+ ///
+ /// Key for CPU
+ ///
+ public static readonly string Cpu = "cpu";
+
+ ///
+ /// Key for memory allocation
+ ///
+ public static readonly string Memory = "memory_gb";
+
+ ///
+ /// Key for GPU
+ ///
+ public static readonly string Gpu = "number_gpu";
+
///
/// Required arguments to run the Argo workflow.
///
diff --git a/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs b/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs
index acec57652..1e3d7418f 100644
--- a/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs
+++ b/src/WorkflowManager/WorkflowManager/Validators/WorkflowValidator.cs
@@ -152,9 +152,9 @@ private void ValidateTasks(Workflow workflow, string firstTaskId)
{
// duplicate destinations
var duplicates = destinations
- .GroupBy(i => i)
- .Where(g => g.Count() > 1)
- .Select(g => g.Key);
+ .GroupBy(i => i)
+ .Where(g => g.Count() > 1)
+ .Select(g => g.Key);
foreach (var dupe in duplicates)
{
@@ -348,6 +348,27 @@ private void ValidateArgoTask(TaskObject currentTask)
break;
}
}
+
+ new List { Cpu, Memory }.ForEach(key =>
+ {
+ if (
+ currentTask.Args.TryGetValue(key, out var val) &&
+ !string.IsNullOrEmpty(val) &&
+ double.TryParse(val, out double parsedVal) &&
+ (parsedVal < 1 || Math.Truncate(parsedVal) != parsedVal))
+ {
+ Errors.Add($"Task: '{currentTask.Id}' value '{val}' provided for argument '{key}' is not valid. The value needs to be a whole number greater than 0.");
+ }
+ });
+
+ if (
+ currentTask.Args.TryGetValue(Gpu, out var gpu) &&
+ !string.IsNullOrEmpty(gpu) &&
+ double.TryParse(gpu, out double parsedGpu) &&
+ (parsedGpu != 0 || parsedGpu != 1))
+ {
+ Errors.Add($"Task: '{currentTask.Id}' value '{gpu}' provided for argument '{Gpu}' is not valid. The value needs to be 0 or 1.");
+ }
}
private void ValidateClinicalReviewTask(TaskObject[] tasks, TaskObject currentTask)
diff --git a/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs b/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs
index cb9138d53..4e788c15a 100644
--- a/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs
+++ b/tests/UnitTests/WorkflowManager.Tests/Validators/WorkflowValidatorTests.cs
@@ -15,9 +15,7 @@
*/
using System;
-using System.Security.Cryptography.Xml;
using System.Threading.Tasks;
-using Amazon.Runtime.Internal.Transform;
using Microsoft.Extensions.Logging;
using Monai.Deploy.WorkflowManager.Common.Interfaces;
using Monai.Deploy.WorkflowManager.Contracts.Models;
@@ -206,7 +204,10 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect
Type = "argo",
Description = "Test Argo Task",
Args = {
- { "example", "value" }
+ { "example", "value" },
+ { "cpu", "0.1" },
+ { "memory_gb", "0.1" },
+ { "gpu", "2" }
},
TaskDestinations = new TaskDestination[]
{
@@ -347,7 +348,7 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect
Assert.True(errors.Count > 0);
- Assert.Equal(40, errors.Count);
+ Assert.Equal(43, errors.Count);
var convergingTasksDestinations = "Converging Tasks Destinations in tasks: (test-clinical-review-2, example-task) on task: example-task";
Assert.Contains(convergingTasksDestinations, errors);
@@ -385,6 +386,15 @@ public async Task ValidateWorkflow_ValidatesAWorkflow_ReturnsErrorsAndHasCorrect
var missingArgoArgs = "Task: 'test-argo-task' workflow_template_name must be specified, this corresponds to an Argo template name.";
Assert.Contains(missingArgoArgs, errors);
+ var invalidArgoArg1 = "Task: 'test-argo-task' value '0.1' provided for argument 'cpu' is not valid. The value needs to be a whole number greater than 0.";
+ Assert.Contains(invalidArgoArg1, errors);
+
+ var invalidArgoArg2 = "Task: 'test-argo-task' value '0.1' provided for argument 'memory_gb' is not valid. The value needs to be a whole number greater than 0.";
+ Assert.Contains(invalidArgoArg2, errors);
+
+ var invalidArgoArg3 = "Task: 'test-argo-task' value '2' provided for argument 'gpu' is not valid. The value needs to be 0 or 1.";
+ Assert.Contains(invalidArgoArg3, errors);
+
var incorrectClinicalReviewValueFormat = $"Invalid Value property on input artifact 'Invalid Value Format' in task: 'test-clinical-review'. Incorrect format.";
Assert.Contains(incorrectClinicalReviewValueFormat, errors);