Skip to content

Commit

Permalink
Merge pull request DIRACGrid#7222 from chaen/diracx_sandbox
Browse files Browse the repository at this point in the history
[8.1] Get tokens and sandboxes from DiracX
  • Loading branch information
chrisburr authored Oct 7, 2023
2 parents 60d00f8 + c43382b commit da42574
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 56 deletions.
9 changes: 9 additions & 0 deletions dirac.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ DIRAC

}

}
# This part contains anything related to DiracX
DiracX
{
# The URL of the DIRAC Server
URL = https://diracx.invalid:8000
# A key used to have priviledged interactions with diracx. see
LegacyExchangeApiKey = diracx:legacy:InsecureChangeMe

}
### Registry section:
# Sections to register VOs, groups, users and hosts
Expand Down
22 changes: 14 additions & 8 deletions integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,22 +348,28 @@ def install_server():
"""Install DIRAC in the server container."""
_check_containers_running()

typer.secho("Running server installation", fg=c.GREEN)
base_cmd = _build_docker_cmd("server", tty=False)
subprocess.run(
base_cmd + ["bash", "/home/dirac/LocalRepo/TestCode/DIRAC/tests/CI/install_server.sh"],
check=True,
)

# This runs a continuous loop that exports the config in yaml
# for the diracx container to use
# It needs to be started and running before the DIRAC server installation
# because after installing the databases, the install server script
# calls dirac-login.
# At this point we need the new CS to have been updated
# already else the token exchange fails.

typer.secho("Starting configuration export loop for diracx", fg=c.GREEN)
base_cmd = _build_docker_cmd("server", tty=False, daemon=True)
base_cmd = _build_docker_cmd("server", tty=False, daemon=True, use_root=True)
subprocess.run(
base_cmd + ["bash", "/home/dirac/LocalRepo/ALTERNATIVE_MODULES/DIRAC/tests/CI/exportCSLoop.sh"],
check=True,
)

typer.secho("Running server installation", fg=c.GREEN)
base_cmd = _build_docker_cmd("server", tty=False)
subprocess.run(
base_cmd + ["bash", "/home/dirac/LocalRepo/TestCode/DIRAC/tests/CI/install_server.sh"],
check=True,
)

typer.secho("Copying credentials and certificates", fg=c.GREEN)
base_cmd = _build_docker_cmd("client", tty=False)
subprocess.run(
Expand Down
10 changes: 10 additions & 0 deletions src/DIRAC/Core/scripts/dirac_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self):
self.outputFile = ""
self.skipVOMSDownload = False
self.extensions = ""
self.legacyExchangeApiKey = ""

def setGateway(self, optionValue):
self.gatewayServer = optionValue
Expand Down Expand Up @@ -174,6 +175,12 @@ def setIssuer(self, optionValue):
DIRAC.gConfig.setOptionValue("/DIRAC/Security/Authorization/issuer", self.issuer)
return DIRAC.S_OK()

def setLegacyExchangeApiKey(self, optionValue):
self.legacyExchangeApiKey = optionValue
Script.localCfg.addDefaultEntry("/DiracX/LegacyExchangeApiKey", self.legacyExchangeApiKey)
DIRAC.gConfig.setOptionValue(cfgInstallPath("LegacyExchangeApiKey"), self.legacyExchangeApiKey)
return DIRAC.S_OK()


def _runConfigurationWizard(setups, defaultSetup):
"""The implementation of the configuration wizard"""
Expand Down Expand Up @@ -361,6 +368,9 @@ def runDiracConfigure(params):
Script.registerSwitch("n:", "SiteName=", "Set <sitename> as DIRAC Site Name", params.setSiteName)
Script.registerSwitch("N:", "CEName=", "Set <cename> as Computing Element name", params.setCEName)
Script.registerSwitch("V:", "VO=", "Set the VO name", params.setVO)
Script.registerSwitch(
"K:", "LegacyExchangeApiKey=", "Set the Api Key to talk to DiracX", params.setLegacyExchangeApiKey
)

Script.registerSwitch("W:", "gateway=", "Configure <gateway> as DIRAC Gateway for the site", params.setGateway)

Expand Down
64 changes: 32 additions & 32 deletions src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
:caption: ProxyManager options
"""

from DIRAC import gLogger, S_OK, S_ERROR
import os
import requests
from DIRAC import gLogger, S_OK, S_ERROR, gConfig
from DIRAC.Core.DISET.RequestHandler import RequestHandler, getServiceOption
from DIRAC.Core.Security import Properties
from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
Expand Down Expand Up @@ -411,39 +413,37 @@ def export_getVOMSProxyWithToken(self, userDN, userGroup, requestPem, requiredLi

def export_exchangeProxyForToken(self):
"""Exchange a proxy for an equivalent token to be used with diracx"""

apiKey = gConfig.getValue("/DiracX/LegacyExchangeApiKey")
if not apiKey:
return S_ERROR("Missing mandatory /DiracX/LegacyExchangeApiKey configuration")

diracxUrl = gConfig.getValue("/DiracX/URL")
if not diracxUrl:
return S_ERROR("Missing mandatory /DiracX/URL configuration")

credDict = self.getRemoteCredentials()
vo = Registry.getVOForGroup(credDict["group"])
dirac_properties = list(set(credDict.get("groupProperties", [])) | set(credDict.get("properties", [])))
group = credDict["group"]
scopes = [f"vo:{vo}", f"group:{group}"] + [f"property:{prop}" for prop in dirac_properties]

try:
from diracx.routers.auth import ( # pylint: disable=import-error
AuthSettings,
create_token,
TokenResponse,
) # pylint: disable=import-error

authSettings = AuthSettings()

from uuid import uuid4

credDict = self.getRemoteCredentials()
vo = Registry.getVOForGroup(credDict["group"])
payload = {
"sub": f"{vo}:{credDict['username']}",
"vo": vo,
"aud": authSettings.token_audience,
"iss": authSettings.token_issuer,
"dirac_properties": list(
set(credDict.get("groupProperties", [])) | set(credDict.get("properties", []))
),
"jti": str(uuid4()),
"preferred_username": credDict["username"],
"dirac_group": credDict["group"],
}
return S_OK(
TokenResponse(
access_token=create_token(payload, authSettings),
expires_in=authSettings.access_token_expire_minutes * 60,
).dict()
r = requests.get(
f"{diracxUrl}/auth/legacy-exchange",
params={
"preferred_username": credDict["username"],
"scope": " ".join(scopes),
},
headers={"Authorization": f"Bearer {apiKey}"},
)
except Exception as e:
return S_ERROR(f"Could not get token: {e!r}")
except requests.exceptions.RequestException as exc:
return S_ERROR(f"Failed to contact DiracX: {exc}")
else:
if not r.ok:
return S_ERROR(f"Failed to contact DiracX: {r.status_code} {r.text}")

return S_OK(r.json())


class ProxyManagerHandler(ProxyManagerHandlerMixin, RequestHandler):
Expand Down
77 changes: 77 additions & 0 deletions src/DIRAC/FrameworkSystem/Utilities/diracx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# pylint: disable=import-error

import requests

from cachetools import TTLCache, cached
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
from DIRAC import gConfig
from DIRAC.ConfigurationSystem.Client.Helpers import Registry


from diracx.core.preferences import DiracxPreferences

from diracx.core.utils import write_credentials

from diracx.core.models import TokenResponse
from diracx.client import DiracClient

# How long tokens are kept
DEFAULT_TOKEN_CACHE_TTL = 5 * 60

# Add a cache not to query the token all the time
_token_cache = TTLCache(maxsize=100, ttl=DEFAULT_TOKEN_CACHE_TTL)


@cached(_token_cache, key=lambda x, y: repr(x))
def _get_token(credDict, diracxUrl, /) -> Path:
"""
Write token to a temporary file and return the path to that file
"""

apiKey = gConfig.getValue("/DiracX/LegacyExchangeApiKey")
if not apiKey:
raise ValueError("Missing mandatory /DiracX/LegacyExchangeApiKey configuration")

vo = Registry.getVOForGroup(credDict["group"])
dirac_properties = list(set(credDict.get("groupProperties", [])) | set(credDict.get("properties", [])))
group = credDict["group"]

scopes = [f"vo:{vo}", f"group:{group}"] + [f"property:{prop}" for prop in dirac_properties]

r = requests.get(
f"{diracxUrl}/auth/legacy-exchange",
params={
"preferred_username": credDict["username"],
"scope": " ".join(scopes),
},
headers={"Authorization": f"Bearer {apiKey}"},
timeout=10,
)

r.raise_for_status()

token_location = Path(NamedTemporaryFile().name)

write_credentials(TokenResponse(**r.json()), location=token_location)

return token_location


def TheImpersonator(credDict: dict[str, Any]) -> DiracClient:
"""
Client to be used by DIRAC server needing to impersonate
a user for diracx.
It queries a token, places it in a file, and returns the `DiracClient`
class
Use as a context manager
"""

diracxUrl = gConfig.getValue("/DiracX/URL")
token_location = _get_token(credDict, diracxUrl)
pref = DiracxPreferences(url=diracxUrl, credentials_path=token_location)

return DiracClient(diracx_preferences=pref)
11 changes: 10 additions & 1 deletion src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,27 @@ def doTheMagic(self):
if os.getenv("DIRAC_ENABLE_DIRACX_LOGIN", "No").lower() in ("yes", "true"):
from diracx.core.utils import write_credentials # pylint: disable=import-error
from diracx.core.models import TokenResponse # pylint: disable=import-error
from diracx.core.preferences import DiracxPreferences # pylint: disable=import-error

res = Client(url="Framework/ProxyManager").exchangeProxyForToken()
if not res["OK"]:
return res
from DIRAC import gConfig

diracxUrl = gConfig.getValue("/DiracX/URL")
if not diracxUrl:
return S_ERROR("Missing mandatory /DiracX/URL configuration")

token_content = res["Value"]
preferences = DiracxPreferences(url=diracxUrl)
write_credentials(
TokenResponse(
access_token=token_content["access_token"],
expires_in=token_content["expires_in"],
token_type=token_content.get("token_type"),
refresh_token=token_content.get("refresh_token"),
)
),
location=preferences.credentials_path,
)

return S_OK()
Expand Down
4 changes: 1 addition & 3 deletions src/DIRAC/Resources/Storage/StorageBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,7 @@ def constructURLFromLFN(self, lfn, withWSUrl=False):
# 2. VO name must not appear as any subdirectory or file name
lfnSplitList = lfn.split("/")
voLFN = lfnSplitList[1]
# TODO comparison to Sandbox below is for backward compatibility, should
# be removed in the next release
if voLFN != self.se.vo and voLFN != "SandBox" and voLFN != "Sandbox":
if voLFN != self.se.vo and voLFN != "SandBox" and voLFN != "S3":
return S_ERROR(f"LFN ({lfn}) path must start with VO name ({self.se.vo})")

urlDict = dict(self.protocolParameters)
Expand Down
2 changes: 2 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/ConfigTemplate.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ Services
SandboxPrefix = Sandbox
BasePath = /opt/dirac/storage/sandboxes
DelayedExternalDeletion = True
# If true, uploads the sandbox via diracx on an S3 storage
UseDiracXBackend = False
Authorization
{
Default = authenticated
Expand Down
68 changes: 68 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/Service/SandboxStoreHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
import hashlib
import os
import requests
import tempfile
import threading
import time
Expand Down Expand Up @@ -49,6 +50,7 @@ def initializeHandler(cls, serviceInfoDict):
def initializeRequest(self):
self.__backend = self.getCSOption("Backend", "local")
self.__localSEName = self.getCSOption("LocalSE", "SandboxSE")
self._useDiracXBackend = self.getCSOption("UseDiracXBackend", False)
self._maxUploadBytes = self.getCSOption("MaxSandboxSizeMiB", 10) * 1048576
if self.__backend.lower() == "local" or self.__backend == self.__localSEName:
self.__useLocalStorage = True
Expand Down Expand Up @@ -106,6 +108,51 @@ def _getFromClient(self, fileId, token, fileSize, fileHelper=None, data=""):
gLogger.info("Upload requested", f"for {aHash} [{extension}]")

credDict = self.getRemoteCredentials()

if self._useDiracXBackend:
from DIRAC.FrameworkSystem.Utilities.diracx import TheImpersonator
from diracx.client.models import SandboxInfo # pylint: disable=import-error

gLogger.info("Forwarding to DiracX")
with tempfile.TemporaryFile(mode="w+b") as tar_fh:
result = fileHelper.networkToDataSink(tar_fh, maxFileSize=self._maxUploadBytes)
if not result["OK"]:
return result
tar_fh.seek(0)

hasher = hashlib.sha256()
while data := tar_fh.read(512 * 1024):
hasher.update(data)
checksum = hasher.hexdigest()
tar_fh.seek(0)
gLogger.debug("Sandbox checksum is", checksum)

sandbox_info = SandboxInfo(
checksum_algorithm="sha256",
checksum=checksum,
size=os.stat(tar_fh.fileno()).st_size,
format=extension,
)

with TheImpersonator(credDict) as client:
res = client.jobs.initiate_sandbox_upload(sandbox_info)

if res.url:
gLogger.debug("Uploading sandbox for", res.pfn)
files = {"file": ("file", tar_fh)}

response = requests.post(res.url, data=res.fields, files=files, timeout=300)

gLogger.debug("Sandbox uploaded", f"for {res.pfn} with status code {response.status_code}")
# TODO: Handle this error better
try:
response.raise_for_status()
except Exception as e:
return S_ERROR("Error uploading sandbox", repr(e))
else:
gLogger.debug("Sandbox already exists in storage backend", res.pfn)
return S_OK(res.pfn)

sbPath = self.__getSandboxPath(f"{aHash}.{extension}")
# Generate the location
result = self.__generateLocation(sbPath)
Expand Down Expand Up @@ -431,6 +478,27 @@ def _sendToClient(self, fileID, token, fileHelper=None, raw=False):
credDict = self.getRemoteCredentials()
serviceURL = self.serviceInfoDict["URL"]
filePath = fileID.replace(serviceURL, "")

# If the PFN starts with S3, we know it has been uploaded to the
# S3 sandbox store, so download it from there before sending it
if filePath.startswith("/S3"):
from DIRAC.FrameworkSystem.Utilities.diracx import TheImpersonator

with TheImpersonator(credDict) as client:
res = client.jobs.get_sandbox_file(pfn=filePath)
r = requests.get(res.url)
r.raise_for_status()
sbData = r.content
if fileHelper:
from io import BytesIO

result = fileHelper.DataSourceToNetwork(BytesIO(sbData))
# fileHelper.oFile.close()
return result
if raw:
return sbData
return S_OK(sbData)

result = self.sandboxDB.getSandboxId(self.__localSEName, filePath, credDict["username"], credDict["group"])
if not result["OK"]:
return result
Expand Down
Loading

0 comments on commit da42574

Please sign in to comment.