Skip to content

Commit

Permalink
feat: Support pydantic 2
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisburr committed Jun 5, 2024
1 parent 60e2d82 commit ce73016
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 31 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies:
- psutil >=4.2.0
- pyasn1 >0.4.1
- pyasn1-modules
- pydantic <2
- pydantic >=2
- python-json-logger >=0.1.8
- pytz >=2015.7
- pyyaml
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ install_requires =
psutil
pyasn1
pyasn1-modules
pydantic
pyparsing
python-dateutil
pytz
Expand Down
31 changes: 16 additions & 15 deletions src/DIRAC/Core/Utilities/test/Test_JDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@
from DIRAC.WorkloadManagementSystem.Utilities.JobModel import JobDescriptionModel


def test_jdlToBaseJobDescriptionModel_valid():
@pytest.fixture()
def jdl_monkey_business(monkeypatch):
monkeypatch.setattr("DIRAC.Core.Base.API.getSites", lambda: S_OK(["LCG.IN2P3.fr"]))
monkeypatch.setattr("DIRAC.WorkloadManagementSystem.Utilities.JobModel.getSites", lambda: S_OK(["LCG.IN2P3.fr"]))
monkeypatch.setattr("DIRAC.Interfaces.API.Job.getDIRACPlatforms", lambda: S_OK("x86_64-slc6-gcc49-opt"))
monkeypatch.setattr(
"DIRAC.WorkloadManagementSystem.Utilities.JobModel.getDIRACPlatforms", lambda: S_OK("x86_64-slc6-gcc49-opt")
)
yield


def test_jdlToBaseJobDescriptionModel_valid(jdl_monkey_business):
"""This test makes sure that a job object can be parsed by the jdlToBaseJobDescriptionModel method"""
# Arrange
with patch("DIRAC.Core.Base.API.getSites", return_value=S_OK(["LCG.IN2P3.fr"])):
job = Job()
job = Job()
job.setConfigArgs("configArgs")
job.setCPUTime(3600)
job.setExecutable("/bin/echo", arguments="arguments", logFile="logFile")
Expand All @@ -36,8 +46,7 @@ def test_jdlToBaseJobDescriptionModel_valid():
job.setParameterSequence("FloatSequence", [1.0, 2.0, 3.0])

job.setOutputData(["outputfile.root"], outputSE="IN2P3-disk", outputPath="/myjobs/1234")
with patch("DIRAC.Interfaces.API.Job.getDIRACPlatforms", return_value=S_OK("x86_64-slc6-gcc49-opt")):
job.setPlatform("x86_64-slc6-gcc49-opt")
job.setPlatform("x86_64-slc6-gcc49-opt")
job.setPriority(10)

job.setDestination("LCG.IN2P3.fr")
Expand Down Expand Up @@ -71,15 +80,7 @@ def test_jdlToBaseJobDescriptionModel_valid():
assert res["OK"], res["Message"]

data = res["Value"].dict()
with patch(
"DIRAC.WorkloadManagementSystem.Utilities.JobModel.getDIRACPlatforms",
return_value=S_OK(["x86_64-slc6-gcc49-opt"]),
):
with patch(
"DIRAC.WorkloadManagementSystem.Utilities.JobModel.getSites",
return_value=S_OK(["LCG.IN2P3.fr"]),
):
assert JobDescriptionModel(owner="owner", ownerGroup="ownerGroup", vo="lhcb", **data)
assert JobDescriptionModel(owner="owner", ownerGroup="ownerGroup", vo="lhcb", **data)


@pytest.mark.parametrize(
Expand All @@ -90,7 +91,7 @@ def test_jdlToBaseJobDescriptionModel_valid():
"""Executable="executable";""", # Missing brackets
],
)
def test_jdlToBaseJobDescriptionModel_invalid(jdl):
def test_jdlToBaseJobDescriptionModel_invalid(jdl, jdl_monkey_business):
"""This test makes sure that a job object without an executable raises an error"""
# Arrange

Expand Down
4 changes: 2 additions & 2 deletions src/DIRAC/Workflow/Utilities/test/Test_Utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test__getStepDefinition(self):

stepDef = getStepDefinition("App_Step", ["Script", "FailoverRequest"])

self.assertTrue(str(appDefn) == str(stepDef))
assert str(appDefn) == str(stepDef)

self.job._addParameter(appDefn, "name", "type", "value", "desc")
self.job._addParameter(appDefn, "name1", "type1", "value1", "desc1")
Expand All @@ -59,7 +59,7 @@ def test__getStepDefinition(self):
parametersList=[["name", "type", "value", "desc"], ["name1", "type1", "value1", "desc1"]],
)

self.assertTrue(str(appDefn) == str(stepDef))
assert str(appDefn) == str(stepDef)

def test_getStepCPUTimes(self):
execT, cpuT = getStepCPUTimes({})
Expand Down
51 changes: 38 additions & 13 deletions src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,50 @@

# pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring

from typing import Any
from collections.abc import Iterable
from typing import Any, Annotated

import pydantic
from packaging.version import Version
from pydantic import BaseModel, root_validator, validator

from DIRAC import gLogger
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getDIRACPlatforms, getSites


# HACK: Convert appropriate iterables into sets
def default_set_validator(value):
if not isinstance(value, Iterable):
return value
elif isinstance(value, (str, bytes, bytearray)):
return value
else:
return set(value)


if Version(pydantic.__version__) > Version("2.0.0a0"):
CoercibleSetStr = Annotated[
set[str] | None, pydantic.BeforeValidator(default_set_validator) # pylint: disable=no-member
]
else:
CoercibleSetStr = set[str]


class BaseJobDescriptionModel(BaseModel):
"""Base model for the job description (not parametric)"""

class Config:
validate_assignment = True

arguments: str = None
bannedSites: set[str] = None
bannedSites: CoercibleSetStr = None
cpuTime: int = Operations().getValue("JobDescription/DefaultCPUTime", 86400)
executable: str
executionEnvironment: dict = None
gridCE: str = None
inputSandbox: set[str] = None
inputData: set[str] = None
inputSandbox: CoercibleSetStr = None
inputData: CoercibleSetStr = None
inputDataPolicy: str = None
jobConfigArgs: str = None
jobGroup: str = None
Expand All @@ -29,16 +54,16 @@ class BaseJobDescriptionModel(BaseModel):
logLevel: str = "INFO"
maxNumberOfProcessors: int = None
minNumberOfProcessors: int = 1
outputData: set[str] = None
outputData: CoercibleSetStr = None
outputPath: str = None
outputSandbox: set[str] = None
outputSandbox: CoercibleSetStr = None
outputSE: str = None
platform: str = None
priority: int = Operations().getValue("JobDescription/DefaultPriority", 1)
sites: set[str] = None
sites: CoercibleSetStr = None
stderr: str = "std.err"
stdout: str = "std.out"
tags: set[str] = None
tags: CoercibleSetStr = None
extraFields: dict[str, Any] = None

@validator("cpuTime")
Expand Down Expand Up @@ -83,7 +108,7 @@ def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
raise ValueError("Input data files must start with LFN:/")
return v

@root_validator
@root_validator(skip_on_failure=True)
def checkNumberOfInputDataFiles(cls, values):
if "inputData" in values and values["inputData"]:
maxInputDataFiles = Operations().getValue("JobDescription/MaxInputData", 500)
Expand Down Expand Up @@ -126,14 +151,14 @@ def checkMaxNumberOfProcessorsBounds(cls, v):
)
return v

@root_validator
@root_validator(skip_on_failure=True)
def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["maxNumberOfProcessors"]:
if values["maxNumberOfProcessors"] < values["minNumberOfProcessors"]:
raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors")
return values

@root_validator
@root_validator(skip_on_failure=True)
def addTagsDependingOnNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["minNumberOfProcessors"] == values["maxNumberOfProcessors"]:
if values["tags"] is None:
Expand All @@ -157,7 +182,7 @@ def checkSites(cls, v: set[str]):
raise ValueError(f"Invalid sites: {' '.join(invalidSites)}")
return v

@root_validator
@root_validator(skip_on_failure=True)
def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(cls, values):
if "sites" in values and values["sites"] and "bannedSites" in values and values["bannedSites"]:
values["sites"] -= values["bannedSites"]
Expand Down Expand Up @@ -192,7 +217,7 @@ class JobDescriptionModel(BaseJobDescriptionModel):
ownerGroup: str
vo: str

@root_validator
@root_validator(skip_on_failure=True)
def checkLFNMatchesREGEX(cls, values):
if "inputData" in values and values["inputData"]:
for lfn in values["inputData"]:
Expand Down

0 comments on commit ce73016

Please sign in to comment.