Skip to content

Commit

Permalink
Merge pull request #397 from ryuwd/roneil-jobstateupdatehandler
Browse files Browse the repository at this point in the history
feat: patch job metadata endpoint
  • Loading branch information
fstagni authored Mar 7, 2025
2 parents ebb9299 + 3419a0d commit a2e3fb0
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 2 deletions.
5 changes: 4 additions & 1 deletion diracx-logic/src/diracx/logic/jobs/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ async def search(
if query_logging_info := ("LoggingInfo" in (body.parameters or [])):
if body.parameters:
body.parameters.remove("LoggingInfo")
body.parameters = ["JobID"] + (body.parameters or [])
if not body.parameters:
body.parameters = None
else:
body.parameters = ["JobID"] + (body.parameters or [])

# TODO: Apply all the job policy stuff properly using user_info
if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo:
Expand Down
49 changes: 48 additions & 1 deletion diracx-logic/src/diracx/logic/jobs/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
VectorSearchOperator,
VectorSearchSpec,
)
from diracx.db.sql.job.db import JobDB
from diracx.db.os.job_parameters import JobParametersDB
from diracx.db.sql.job.db import JobDB, _get_columns
from diracx.db.sql.job.schema import Jobs
from diracx.db.sql.job_logging.db import JobLoggingDB
from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
from diracx.db.sql.task_queue.db import TaskQueueDB
Expand Down Expand Up @@ -474,3 +476,48 @@ async def remove_jobs_from_task_queue(
await recalculate_tq_shares_for_entity(
owner, owner_group, vo, config, task_queue_db
)


async def set_job_parameters_or_attributes(
updates: dict[int, dict[str, Any]],
job_db: JobDB,
job_parameters_db: JobParametersDB,
):
"""Set job parameters or attributes for a list of jobs."""
attribute_columns: list[str] = [
col.name for col in _get_columns(Jobs.__table__, None)
]
attribute_columns_lower: list[str] = [col.lower() for col in attribute_columns]

attr_updates: dict[int, dict[str, Any]] = {}
param_updates: dict[int, dict[str, Any]] = {}

for job_id, metadata in updates.items():
attr_updates[job_id] = {}
param_updates[job_id] = {}
for pname, pvalue in metadata.items():
# If the attribute exactly matches one of the allowed columns, treat it as an attribute.
if pname in attribute_columns:
attr_updates[job_id][pname] = pvalue
# Otherwise, if the lower-case version is valid, the user likely mis-cased the key.
elif pname.lower() in attribute_columns_lower:
correct_name = attribute_columns[
attribute_columns_lower.index(pname.lower())
]
raise ValueError(
f"Attribute column '{pname}' is mis-cased. Did you mean '{correct_name}'?"
)
# Otherwise, assume it should be routed to the parameters DB.
else:
param_updates[job_id][pname] = pvalue

# bulk set job attributes
await job_db.set_job_attributes(attr_updates)

# TODO: can we upsert to multiple documents?
for job_id, p_updates_ in param_updates.items():
if p_updates_:
await job_parameters_db.upsert(
int(job_id),
p_updates_,
)
21 changes: 21 additions & 0 deletions diracx-routers/src/diracx/routers/jobs/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
)
from diracx.logic.jobs.status import remove_jobs as remove_jobs_bl
from diracx.logic.jobs.status import reschedule_jobs as reschedule_jobs_bl
from diracx.logic.jobs.status import (
set_job_parameters_or_attributes as set_job_parameters_or_attributes_bl,
)
from diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl

from ..dependencies import (
Config,
JobDB,
JobLoggingDB,
JobParametersDB,
SandboxMetadataDB,
TaskQueueDB,
)
Expand Down Expand Up @@ -124,3 +128,20 @@ async def reschedule_jobs(
# self.__sendJobsToOptimizationMind(validJobList)

return resched_jobs


@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT)
async def patch_metadata(
updates: dict[int, dict[str, Any]],
job_db: JobDB,
job_parameters_db: JobParametersDB,
check_permissions: CheckWMSPolicyCallable,
):
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=updates)
try:
await set_job_parameters_or_attributes_bl(updates, job_db, job_parameters_db)
except ValueError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=str(e),
) from e
107 changes: 107 additions & 0 deletions diracx-routers/tests/jobs/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,110 @@ def test_remove_jobs_valid_job_ids(
# for job_id in valid_job_ids:
# r = normal_user_client.get(f"/api/jobs/{job_id}/status")
# assert r.status_code == HTTPStatus.NOT_FOUND, r.json()


def test_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
# Arrange
r = normal_user_client.post(
"/api/jobs/search",
json={
"search": [
{
"parameter": "JobID",
"operator": "eq",
"value": valid_job_id,
}
],
"parameters": ["LoggingInfo"],
},
)

assert r.status_code == 200, r.json()
for j in r.json():
assert j["JobID"] == valid_job_id
assert j["Status"] == JobStatus.RECEIVED.value
assert j["MinorStatus"] == "Job accepted"
assert j["ApplicationStatus"] == "Unknown"

# Act
hbt = str(datetime.now(timezone.utc))
r = normal_user_client.patch(
"/api/jobs/metadata",
json={
valid_job_id: {
"UserPriority": 2,
"HeartBeatTime": hbt,
# set a parameter
"JobType": "VerySpecialIndeed",
}
},
)

# Assert
assert (
r.status_code == 204
), "PATCH metadata should return 204 No Content on success"
r = normal_user_client.post(
"/api/jobs/search",
json={
"search": [
{
"parameter": "JobID",
"operator": "eq",
"value": valid_job_id,
}
],
"parameters": ["LoggingInfo"],
},
)
assert r.status_code == 200, r.json()

assert r.json()[0]["JobID"] == valid_job_id
assert r.json()[0]["JobType"] == "VerySpecialIndeed"
assert datetime.fromisoformat(
r.json()[0]["HeartBeatTime"]
) == datetime.fromisoformat(hbt)
assert r.json()[0]["UserPriority"] == 2


def test_bad_patch_metadata(normal_user_client: TestClient, valid_job_id: int):
# Arrange
r = normal_user_client.post(
"/api/jobs/search",
json={
"search": [
{
"parameter": "JobID",
"operator": "eq",
"value": valid_job_id,
}
],
"parameters": ["LoggingInfo"],
},
)

assert r.status_code == 200, r.json()
for j in r.json():
assert j["JobID"] == valid_job_id
assert j["Status"] == JobStatus.RECEIVED.value
assert j["MinorStatus"] == "Job accepted"
assert j["ApplicationStatus"] == "Unknown"

# Act
hbt = str(datetime.now(timezone.utc))
r = normal_user_client.patch(
"/api/jobs/metadata",
json={
valid_job_id: {
"UserPriority": 2,
"Heartbeattime": hbt,
# set a parameter
"JobType": "VerySpecialIndeed",
}
},
)

# Assert
assert (
r.status_code == 400
), "PATCH metadata should 400 Bad Request if an attribute column's case is incorrect"

0 comments on commit a2e3fb0

Please sign in to comment.