Skip to content

Commit

Permalink
refactored job submission to reduce the number of statements executed where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuwd committed Dec 13, 2024
1 parent b3249a3 commit 2239bd6
Showing 1 changed file with 151 additions and 98 deletions.
249 changes: 151 additions & 98 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
from sqlalchemy import bindparam, delete, func, insert, select, update
from sqlalchemy.exc import IntegrityError, NoResultFound

if TYPE_CHECKING:
from sqlalchemy.sql.elements import BindParameter

from diracx.core.exceptions import InvalidQueryError, JobNotFound
from diracx.core.models import (
JobStatus,
LimitedJobStatusReturn,
SearchSpec,
SortSpec,
Expand All @@ -27,6 +26,15 @@
)


class JobSubmissionSpec(BaseModel):
    """Everything needed to submit one job through ``insert_bulk``.

    One instance describes a single job; ``insert_bulk`` takes a list of
    them so that several jobs can be submitted in one call.
    """

    # Job description (JDL) text exactly as submitted.
    jdl: str
    # Owner and owner group; passed to the JDL owner check and stored with the job.
    owner: str
    owner_group: str
    # Status / minor status handed to the JDL-creation helper for the new job.
    initial_status: str
    initial_minor_status: str
    # Stored in the job's "VO" attribute.
    vo: str


def _get_columns(table, parameters):
columns = [x for x in table.columns]
if parameters:
Expand Down Expand Up @@ -218,108 +226,155 @@ async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, s
if jdl
}

async def insert_bulk(
    self,
    jobs: list[JobSubmissionSpec],
):
    """Submit several jobs, batching the SQL statements where possible.

    For every job the original JDL is stored first to obtain a job ID, the
    DIRAC JDL is then prepared, and finally the prepared JDLs, the job rows
    and the input-data rows are written with one statement each.

    Args:
        jobs: Specifications of the jobs to submit.

    Returns:
        The list of generated job IDs, in the same order as ``jobs``.
        An empty ``jobs`` list yields an empty list without touching the DB.

    Raises:
        ValueError: If the prepared JDL of a job is syntactically invalid.
        NotImplementedError: If a JDL contains a ``Parameters`` attribute.
    """
    if not jobs:
        # Nothing to do; also avoids executing statements with an empty
        # parameter list further down.
        return []

    from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
    from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import (
        checkAndAddOwner,
        compressJDL,
        createJDLWithInitialStatus,
    )

    jobs_to_insert = []
    jdls_to_update = []
    inputdata_to_insert = []
    original_jdls = []

    # Build the manifests first so that an invalid owner/JDL fails before
    # anything is written.
    for job in jobs:
        original_jdl = job.jdl
        jobManifest = returnValueOrRaise(
            checkAndAddOwner(original_jdl, job.owner, job.owner_group)
        )

        # Fix possible lack of brackets (startswith is safe on a blank JDL)
        if not original_jdl.strip().startswith("["):
            original_jdl = f"[{original_jdl}]"

        original_jdls.append((original_jdl, jobManifest))

    # Generate the job IDs. ``lastrowid`` is an attribute of the result of a
    # single INSERT, not of the rows of a multi-row execute, so the JDL
    # placeholder rows are inserted one at a time.
    job_ids = []
    for original_jdl, _ in original_jdls:
        result = await self.conn.execute(
            insert(JobJDLs).values(
                JDL="",
                JobRequirements="",
                OriginalJDL=compressJDL(original_jdl),
            )
        )
        job_ids.append(result.lastrowid)

    for job_id, job, (original_jdl, jobManifest) in zip(
        job_ids, jobs, original_jdls
    ):
        now = datetime.now(tz=timezone.utc)
        job_attrs = {
            "LastUpdateTime": now,
            "SubmissionTime": now,
            "Owner": job.owner,
            "OwnerGroup": job.owner_group,
            "VO": job.vo,
            "JobID": job_id,
        }

        jobManifest.setOption("JobID", job_id)

        # 2.- Check JDL and Prepare DIRAC JDL
        # Replace the JobID placeholder if any
        jobJDL = jobManifest.dumpAsJDL().replace("%j", str(job_id))

        class_ad_job = ClassAd(jobJDL)
        class_ad_req = ClassAd("[]")
        if not class_ad_job.isOK():
            # Raising aborts the whole submission; presumably the enclosing
            # transaction is rolled back by the caller - TODO confirm.
            raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
        # TODO: check if that is actually true
        if class_ad_job.lookupAttribute("Parameters"):
            raise NotImplementedError("Parameters in the JDL are not supported")

        # TODO is this even needed?
        class_ad_job.insertAttributeInt("JobID", job_id)

        await self.checkAndPrepareJob(
            job_id,
            class_ad_job,
            class_ad_req,
            job.owner,
            job.owner_group,
            job_attrs,
            job.vo,
        )

        jobJDL = createJDLWithInitialStatus(
            class_ad_job,
            class_ad_req,
            self.jdl2DBParameters,
            job_attrs,
            job.initial_status,
            job.initial_minor_status,
            modern=True,
        )

        jobs_to_insert.append(job_attrs)
        # The key is prefixed because a bind parameter used in the WHERE
        # clause must not share its name with a column of the SET clause.
        jdls_to_update.append(
            {
                "b_JobID": job_id,
                "JDL": compressJDL(jobJDL),
            }
        )

        if class_ad_job.lookupAttribute("InputData"):
            inputData = class_ad_job.getListFromExpression("InputData")
            inputdata_to_insert += [
                {"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn
            ]

    # Without the WHERE clause every parameter set would update all rows.
    # NOTE(review): assumes JobJDLs exposes its JobID column as an attribute.
    await self.conn.execute(
        update(JobJDLs).where(JobJDLs.JobID == bindparam("b_JobID")),
        jdls_to_update,
    )

    # NOTE(review): a multi-row execute takes its column set from the first
    # dict; confirm all job_attrs dicts end up with the same keys.
    await self.conn.execute(
        insert(Jobs),
        jobs_to_insert,
    )

    # Jobs without input data must not trigger a parameterless INSERT.
    if inputdata_to_insert:
        await self.conn.execute(
            insert(InputData),
            inputdata_to_insert,
        )

    return job_ids
async def insert(
    self,
    jdl,
    owner,
    owner_group,
    initial_status,
    initial_minor_status,
    vo,
):
    """Submit a single job by delegating to ``insert_bulk``.

    Args:
        jdl: Job description (JDL) text.
        owner: Owner of the job.
        owner_group: Group of the owner.
        initial_status: Status passed on for the new job.
        initial_minor_status: Minor status passed on for the new job.
        vo: Value stored as the job's VO.

    Returns:
        Whatever ``insert_bulk`` returns for a one-element batch.
    """
    single_job = JobSubmissionSpec(
        jdl=jdl,
        owner=owner,
        owner_group=owner_group,
        initial_status=initial_status,
        initial_minor_status=initial_minor_status,
        vo=vo,
    )
    # insert_bulk is a coroutine function: without ``await`` the caller would
    # receive an un-run coroutine object instead of the submission result.
    return await self.insert_bulk([single_job])

async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
try:
Expand Down Expand Up @@ -347,21 +402,19 @@ async def set_job_command(self, job_id: int, command: str, arguments: str = ""):

async def set_job_command_bulk(self, commands):
    """Store a command to be passed to the job together with the next heart beat.

    Args:
        commands: Iterable of ``(job_id, command, arguments)`` tuples.

    All rows of one call share the same reception timestamp. An empty
    ``commands`` iterable is a no-op.
    """
    reception_time = datetime.now(tz=timezone.utc)
    rows = [
        {
            "JobID": job_id,
            "Command": command,
            "Arguments": arguments,
            "ReceptionTime": reception_time,
        }
        for job_id, command, arguments in commands
    ]
    if not rows:
        # An empty parameter list would not be treated as "insert nothing".
        return

    # The execute call must be awaited, otherwise the INSERT never runs.
    # FIXME handle IntegrityError (the offending job ID is not known here)
    await self.conn.execute(insert(JobCommands), rows)

async def delete_jobs(self, job_ids: list[int]):
"""Delete jobs from the database."""
Expand Down

0 comments on commit 2239bd6

Please sign in to comment.