Skip to content

Commit

Permalink
Put back some additional settings as optional
Browse files Browse the repository at this point in the history
  • Loading branch information
MattMonk committed Aug 14, 2024
1 parent 2f3eb63 commit 21156a4
Showing 1 changed file with 148 additions and 50 deletions.
198 changes: 148 additions & 50 deletions snakemake_storage_plugin_xrootd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
import os
import re
from typing import Any, Iterable, Optional, List, Type
from typing import Any, Iterable, Optional, List, Type, Callable

from reretry import retry

from XRootD import client
from XRootD.client.flags import MkDirFlags, StatInfoFlags
from XRootD.client.responses import XRootDStatus
from XRootD.client import URL

from snakemake_interface_common.logging import get_logger
from snakemake_interface_storage_plugins.settings import StorageProviderSettingsBase
Expand Down Expand Up @@ -48,11 +48,55 @@ def _raise_fatal_error(exception: Type[Exception]):

@dataclass
class StorageProviderSettings(StorageProviderSettingsBase):
    """Optional connection settings for the XRootD storage provider.

    Every field has a default and is flagged ``"required": False``, so the
    provider can be used without configuring any of them. The values are
    copied onto the provider in ``StorageProvider.__post_init__``.
    """

    # Host to connect to; None means no host override is configured.
    host: Optional[str] = field(
        default=None,
        metadata={
            "help": "The XrootD host to connect to",
            "env_var": False,
            "required": False,
        },
    )
    # Port to connect to; None means no port override is configured.
    port: Optional[int] = field(
        default=None,
        metadata={
            "help": "The XrootD port to connect to",
            "env_var": False,
            "required": False,
        },
    )
    # Authentication user name (flagged "env_var": True).
    username: Optional[str] = field(
        default=None,
        metadata={
            "help": "The username to use for authentication",
            "env_var": True,
            "required": False,
        },
    )
    # Authentication password (flagged "env_var": True).
    password: Optional[str] = field(
        default=None,
        metadata={
            "help": "The password to use for authentication",
            "env_var": True,
            "required": False,
        },
    )
    # Hook applied to the assembled URL; the default is the identity function.
    url_decorator: Optional[Callable] = field(
        default=lambda x: x,
        metadata={
            "help": "A function to decorate the URL e.g. wrapping an auth token.",
            "env_var": False,
            "required": False,
        },
    )


class StorageProvider(StorageProviderBase):
def __post_init__(self):
self.username = self.settings.username
self.password = self.settings.password
self.host = self.settings.host
self.port = self.settings.port
self.url_decorator = self.settings.url_decorator
# List of error codes that there is no point in retrying
self.no_retry_codes = [
3000,
Expand Down Expand Up @@ -86,9 +130,15 @@ def example_queries(cls) -> List[ExampleQuery]:
return [
ExampleQuery(
query="root://eosuser.cern.ch//eos/user/s/someuser/somefile.txt",
description="A file on an XrootD instance.",
description="A file on a XrootD instance not specifying any arguments.",
type=QueryType.ANY,
)
),
ExampleQuery(
query="root://eos/user/s/someuser/somefile.txt",
description="A file on an XrootD instance where the host has been"
"specified in the storage object.",
type=QueryType.ANY,
),
]

# ?
Expand All @@ -111,33 +161,77 @@ def use_rate_limiter(self) -> bool:
return True

@staticmethod
def _parse_url(query: str) -> List[str] | None:
match = re.search(
r"(?P<domain>(?:[A-Za-z]+://)[A-Za-z0-9:@\_\-\.\{\}]+\:?/)(?P<path>.+)",
query,
)
if match is None:
return None
def _no_pass_url(url: str) -> str:
    """Return ``url`` with an embedded password replaced by ``****``.

    Args:
        url: The URL string to sanitise.

    Returns:
        The string form of the parsed URL, with the password masked when
        one is present.
    """
    parsed = URL(url)
    text = str(parsed)
    password = parsed.password
    # Falsy covers both "" (no password) and a missing value.
    if not password:
        return text
    # Mask only the ":password@" credential segment. A blanket replace would
    # also rewrite any host or path text that happens to contain the password.
    credential = f":{password}@"
    if credential in text:
        return text.replace(credential, ":****@", 1)
    # Fallback when the password is not laid out as ":password@".
    return text.replace(password, "****")

domain = match.group("domain")
dirname, filename = os.path.split(match.group("path"))
@staticmethod
def _no_params_url(url: str) -> str:
    """Return ``url`` with its path parameters replaced by ``?****``.

    Args:
        url: The URL string to sanitise.

    Returns:
        ``url`` with the ``path_with_params`` portion swapped for
        ``"<path>?****"``. A URL with no path at all is returned unchanged.
    """
    parsed = URL(url)
    path_with_params = parsed.path_with_params
    # With an empty search string, str.replace would insert the mask between
    # every character of the URL, so there is nothing sensible to mask.
    if not path_with_params:
        return url
    masked = f"{parsed.path}?****"
    # The path and parameters normally form the tail of the URL. Replace them
    # there so an identical substring earlier in the URL is left alone.
    if url.endswith(path_with_params):
        return url[: len(url) - len(path_with_params)] + masked
    return url.replace(path_with_params, masked)

@staticmethod
def _safe_to_print_url(url: str) -> str:
    """Return a version of ``url`` suitable for log and error messages.

    The password is masked first, then the path parameters.
    """
    without_password = StorageProvider._no_pass_url(url)
    return StorageProvider._no_params_url(without_password)

def _parse_url(self, query: str) -> tuple[URL, str, str]:
    """Build the full XRootD URL for ``query`` and split its path.

    User, password, host and port configured on the provider take precedence
    over the corresponding components of ``query``. The assembled URL is
    passed through ``self.url_decorator`` before it is validated.

    Args:
        query: The query string given by the user.

    Returns:
        A ``(url, dirname, filename)`` tuple. ``url`` is the decorated
        ``URL`` object. ``dirname`` starts and ends with ``/``. ``filename``
        is the last path component.

    Raises:
        IOError: If a password is given without a user, or if the URL is
            invalid before or after decoration.
    """
    url = URL(query)
    user = self.username or url.username
    password = self.password or url.password
    host = self.host or url.hostname
    port = self.port or url.port

    if password and not user:
        raise IOError(
            "XRootD Error: Cannot specify a password without specifying a user"
        )
    if user and password:
        user_pass = f"{user}:{password}@"
    elif user:
        user_pass = f"{user}@"
    else:
        user_pass = ""

    # The XRootD parsing does not understand the host not being there: when
    # the host comes from the settings, what was parsed as the hostname is
    # really the first path component.
    if self.host is not None and self.host != url.hostname:
        full_path = f"{url.hostname}/{url.path_with_params}"
    else:
        full_path = url.path_with_params
    new_url = f"{url.protocol}://{user_pass}{host}:{port}//{full_path}"

    # The setting is typed Optional, so treat None as "no decoration".
    decorate = self.url_decorator
    dec_url = new_url if decorate is None else decorate(new_url)
    full_url = URL(dec_url)
    if not full_url.is_valid():
        if URL(new_url).is_valid():
            raise IOError(
                f"XRootD Error: URL {self._safe_to_print_url(dec_url)} was made "
                "invalid when applying the url_decorator"
            )
        raise IOError(
            f"XRootD Error: URL {self._safe_to_print_url(new_url)} is invalid"
        )

    dirname, filename = os.path.split(full_url.path)
    # We need a trailing / to keep XRootD happy
    dirname += "/"

    # XRootD also needs absolute paths
    if not dirname.startswith("/"):
        dirname = f"/{dirname}"

    return full_url, dirname, filename

@classmethod
def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
"""Return whether the given query is valid for this storage provider."""
# Ensure that also queries containing wildcards (e.g. {sample}) are accepted
# and considered valid. The wildcards will be resolved before the storage
# object is actually used.
parsed_query = cls._parse_url(query)
if parsed_query is None:
url = URL(query)
if not url.is_valid():
return StorageQueryValidationResult(
valid=False,
reason="Malformed XRootD url",
Expand All @@ -160,9 +254,9 @@ class StorageObject(StorageObjectRead, StorageObjectWrite):

def __post_init__(self):
# Does is_valid_query happen before this or we need to verify here too?
self.domain, self.dirname, self.filename = self.provider._parse_url(self.query)
self.path = os.path.join(self.dirname, self.filename)
self.file_system = client.FileSystem(self.domain)
self.url, self.dirname, self.filename = self.provider._parse_url(self.query)
self.path = self.url.path
self.file_system = client.FileSystem(self.url.hostid)
self.keep_local = self.provider.settings.keep_local
self.retrieve = self.provider.settings.retrieve

Expand All @@ -183,12 +277,12 @@ async def inventory(self, cache: IOCacheStorageInterface):
def get_inventory_parent(self) -> Optional[str]:
    """Return the parent directory of this object.

    The result is the string form of ``self.url`` with the file name taken
    out. Anything following the file name (such as ``?params``) is kept.
    """
    # this is optional and can be left as is
    url_text = str(self.url)
    if not self.filename:
        return url_text
    # Only look at the part before any "?params", and drop just the last
    # occurrence of the file name. A blanket replace would also strip
    # identical text from the host, a directory name or the parameters.
    base, sep, params = url_text.partition("?")
    head, found, tail = base.rpartition(self.filename)
    if not found:
        return url_text
    return head + tail + sep + params

def local_suffix(self) -> str:
    """Return a unique suffix for the local path, determined from self.query."""
    # The path starts with one or more '/' which we do not want here. Strip
    # them all rather than slicing a fixed number of characters, which would
    # eat a real character whenever the count of leading slashes differs.
    return str(self.path).lstrip("/")

# Check but should be nothing?
def cleanup(self):
Expand All @@ -204,61 +298,59 @@ def exists(self) -> bool:
def _exists(self, query) -> bool:
    """Return whether ``query`` exists on the remote file system.

    Args:
        query: The query or URL to check. It is parsed again here so that
            this can be re-used to check the existence of other files,
            e.g. the parent directory.

    Returns:
        ``True`` when the stat call succeeds, ``False`` when it fails with
        error 3011 (file not found). Any other failure is handed to
        ``self.provider._check_status``.
    """
    url, _, _ = self.provider._parse_url(query)
    status, _ = self.file_system.stat(url.path)
    if status.ok:
        return True
    # a bit special, 3011 == file not found
    if status.errno == 3011:
        return False
    self.provider._check_status(
        status,
        f"Error checking existence of {self.provider._safe_to_print_url(query)} "
        "on XRootD",
    )
    # Reached only if _check_status returned instead of raising
    # (presumably it raises for a failed status; confirm).
    return True

@xrootd_retry
def mtime(self) -> float:
    """Return the modification time (``modtime``) reported by a remote stat."""
    status, stat = self.file_system.stat(self.path)
    # Report failures with the sanitised URL only. The raw query may carry
    # credentials or tokens that must not end up in error messages.
    self.provider._check_status(
        status,
        f"Error checking info of {self.provider._safe_to_print_url(self.query)}",
    )
    return stat.modtime

@xrootd_retry
def size(self) -> int:
    """Return the size (``size``) reported by a remote stat."""
    status, stat = self.file_system.stat(self.path)
    # Report failures with the sanitised URL only. The raw query may carry
    # credentials or tokens that must not end up in error messages.
    self.provider._check_status(
        status,
        f"Error checking info of {self.provider._safe_to_print_url(self.query)}",
    )
    return stat.size

@xrootd_retry
# @xrootd_retry
def retrieve_object(self):
    """Copy the remote object to ``self.local_path()``.

    A single copy job from ``self.url`` to the absolute local path is run.
    Both the overall status and the status of that job are passed to
    ``self.provider._check_status``.
    """
    # Ensure that the object is accessible locally under self.local_path()
    process = client.CopyProcess()

    # TODO is special handling for directories needed?
    # if stat.flags & client.flags.StatInfoFlags.IS_DIR:
    #     self.local_path().mkdir(parents=True, exist_ok=True)
    #     _, listing = self.provider.filesystem_client.dirlist(
    #         self.path, flags=client.flags.DirListFlags.STAT)
    #     for item in listing:
    #         item_path = f"{self.path}/{item.name}"
    #         if item.statinfo.flags & client.flags.StatInfoFlags.IS_DIR:
    #
    #         else:
    #             process.add_job(self._get_url(item_path),
    #                             str(self.local_path() / item_path), force=True)
    # else:

    # local path must be an absolute path as well
    local_path = os.path.abspath(self.local_path())
    # Copy from the fully assembled URL (settings and decorator applied),
    # not from the raw query.
    process.add_job(str(self.url), local_path, force=True)

    process.prepare()
    status, returns = process.run()
    # Only the sanitised URL may appear in error messages.
    message = f"Error downloading from {self.provider._safe_to_print_url(self.query)}"
    self.provider._check_status(status, message)
    self.provider._check_status(returns[0]["status"], message)

# The following two methods are only required if the class inherits from
Expand All @@ -267,9 +359,12 @@ def retrieve_object(self):
@xrootd_retry
def _makedirs(self):
    """Create the remote parent directory of this object if it is missing.

    The directory is created with ``MkDirFlags.MAKEPATH`` and the resulting
    status is passed to ``self.provider._check_status``.
    """
    if self._exists(self.get_inventory_parent()):
        return
    # NOTE(review): an unconditional `raise XRootDFatalException()` used to
    # sit here, so the mkdir below never did its job when the parent was
    # missing. It looked like a debugging leftover; confirm.
    status, _ = self.file_system.mkdir(self.dirname, MkDirFlags.MAKEPATH)
    self.provider._check_status(
        status,
        "Error creating directory "
        f"{self.provider._safe_to_print_url(self.query)}",
    )

@xrootd_retry
Expand All @@ -279,12 +374,15 @@ def store_object(self):
process = client.CopyProcess()
self._makedirs()
local_path = os.path.abspath(self.local_path())
process.add_job(local_path, self.query, force=True)
process.add_job(local_path, str(self.url), force=True)
process.prepare()
status, returns = process.run()
self.provider._check_status(status, f"Error uploading to {self.query}")
self.provider._check_status(
returns[0]["status"], f"Error uploading to {self.query}"
status, f"Error uploading to {self.provider._safe_to_print_url(self.query)}"
)
self.provider._check_status(
returns[0]["status"],
f"Error uploading to {self.provider._safe_to_print_url(self.query)}",
)

@xrootd_retry
Expand Down

0 comments on commit 21156a4

Please sign in to comment.