diff --git a/snakemake_storage_plugin_xrootd/__init__.py b/snakemake_storage_plugin_xrootd/__init__.py index a122a94..4f49e94 100644 --- a/snakemake_storage_plugin_xrootd/__init__.py +++ b/snakemake_storage_plugin_xrootd/__init__.py @@ -1,13 +1,13 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import os -import re -from typing import Any, Iterable, Optional, List, Type +from typing import Any, Iterable, Optional, List, Type, Callable from reretry import retry from XRootD import client from XRootD.client.flags import MkDirFlags, StatInfoFlags from XRootD.client.responses import XRootDStatus +from XRootD.client import URL from snakemake_interface_common.logging import get_logger from snakemake_interface_storage_plugins.settings import StorageProviderSettingsBase @@ -48,11 +48,55 @@ def _raise_fatal_error(exception: Type[Exception]): @dataclass class StorageProviderSettings(StorageProviderSettingsBase): - pass + host: Optional[str] = field( + default=None, + metadata={ + "help": "The XrootD host to connect to", + "env_var": False, + "required": False, + }, + ) + port: Optional[int] = field( + default=None, + metadata={ + "help": "The XrootD port to connect to", + "env_var": False, + "required": False, + }, + ) + username: Optional[str] = field( + default=None, + metadata={ + "help": "The username to use for authentication", + "env_var": True, + "required": False, + }, + ) + password: Optional[str] = field( + default=None, + metadata={ + "help": "The password to use for authentication", + "env_var": True, + "required": False, + }, + ) + url_decorator: Optional[Callable] = field( + default=lambda x: x, + metadata={ + "help": "A function to decorate the URL e.g. wrapping an auth token.", + "env_var": False, + "required": False, + }, + ) class StorageProvider(StorageProviderBase): def __post_init__(self): + self.username = self.settings.username + self.password = self.settings.password + self.host = self.settings.host + self.port = self.settings.port + self.url_decorator = self.settings.url_decorator # List of error codes that there is no point in retrying self.no_retry_codes = [ 3000, @@ -86,9 +130,15 @@ def example_queries(cls) -> List[ExampleQuery]: return [ ExampleQuery( query="root://eosuser.cern.ch//eos/user/s/someuser/somefile.txt", - description="A file on an XrootD instance.", + description="A file on a XrootD instance not specifying any arguments.", type=QueryType.ANY, - ) + ), + ExampleQuery( + query="root://eos/user/s/someuser/somefile.txt", + description="A file on an XrootD instance where the host has been" + "specified in the storage object.", + type=QueryType.ANY, + ), ] # ? @@ -111,16 +161,60 @@ def use_rate_limiter(self) -> bool: return True @staticmethod - def _parse_url(query: str) -> List[str] | None: - match = re.search( - r"(?P(?:[A-Za-z]+://)[A-Za-z0-9:@\_\-\.\{\}]+\:?/)(?P.+)", - query, - ) - if match is None: - return None + def _no_pass_url(url: str) -> str: + new_url = URL(url) + if new_url.password == "": + return str(new_url) + else: + return str(new_url).replace(new_url.password, "****") - domain = match.group("domain") - dirname, filename = os.path.split(match.group("path")) + @staticmethod + def _no_params_url(url: str) -> str: + new_url = URL(url) + return url.replace(new_url.path_with_params, f"{new_url.path}?****") + + @staticmethod + def _safe_to_print_url(url: str) -> str: + return StorageProvider._no_params_url(StorageProvider._no_pass_url(url)) + + def _parse_url(self, query: str) -> List[str] | None: + url = URL(query) + user = self.username or url.username + password = self.password or url.password + host = self.host or url.hostname + port = self.port or url.port + if user != "": + if password != "": + user_pass = f"{user}:{password}@" + else: + user_pass = f"{user}@" + else: + if password != "": + raise IOError( + "XRootD Error: Cannot specify a password without specifying a user" + ) + user_pass = "" + + # The XRootD parsing does not understand the host not being there + if self.host is not None and self.host != url.hostname: + full_path = f"{url.hostname}/{url.path_with_params}" + else: + full_path = url.path_with_params + new_url = f"{url.protocol}://{user_pass}{host}:{port}//{full_path}" + dec_url = self.url_decorator(new_url) + full_url = URL(dec_url) + if not full_url.is_valid(): + if URL(new_url).is_valid(): + raise IOError( + f"XRootD Error: URL {self._safe_to_print_url(dec_url)} was made" + "invalid when applying the url_decorator" + ) + else: + raise IOError( + f"XRootD Error: URL {self._safe_to_print_url(new_url)} is invalid" + ) + + dirname, filename = os.path.split(full_url.path) # We need a trailing / to keep XRootD happy dirname += "/" @@ -128,7 +222,7 @@ def _parse_url(query: str) -> List[str] | None: if not dirname.startswith("/"): dirname = f"/{dirname}" - return domain, dirname, filename + return full_url, dirname, filename @classmethod def is_valid_query(cls, query: str) -> StorageQueryValidationResult: @@ -136,8 +230,8 @@ def is_valid_query(cls, query: str) -> StorageQueryValidationResult: # Ensure that also queries containing wildcards (e.g. {sample}) are accepted # and considered valid. The wildcards will be resolved before the storage # object is actually used. - parsed_query = cls._parse_url(query) - if parsed_query is None: + url = URL(query) + if not url.is_valid(): return StorageQueryValidationResult( valid=False, reason="Malformed XRootD url", @@ -160,9 +254,9 @@ class StorageObject(StorageObjectRead, StorageObjectWrite): def __post_init__(self): # Does is_valid_query happen before this or we need to verify here too? - self.domain, self.dirname, self.filename = self.provider._parse_url(self.query) - self.path = os.path.join(self.dirname, self.filename) - self.file_system = client.FileSystem(self.domain) + self.url, self.dirname, self.filename = self.provider._parse_url(self.query) + self.path = self.url.path + self.file_system = client.FileSystem(self.url.hostid) self.keep_local = self.provider.settings.keep_local self.retrieve = self.provider.settings.retrieve @@ -183,12 +277,12 @@ async def inventory(self, cache: IOCacheStorageInterface): def get_inventory_parent(self) -> Optional[str]: """Return the parent directory of this object.""" # this is optional and can be left as is - return self.domain + self.dirname + return str(self.url).replace(self.filename, "") def local_suffix(self) -> str: """Return a unique suffix for the local path, determined from self.query.""" # path always has a '/' at the end which we do not want here - return str(self.path)[1:] + return str(self.path)[2:] # Check but should be nothing? def cleanup(self): @@ -204,14 +298,16 @@ def exists(self) -> bool: def _exists(self, query) -> bool: # we split up the query again so that this can be re-used to check the # existence of other files e.g. the parent directory. - domain, dirname, filename = self.provider._parse_url(query) - status, stat_info = self.file_system.stat(os.path.join(dirname, filename)) + url, dirname, filename = self.provider._parse_url(query) + status, stat_info = self.file_system.stat(url.path) # a bit special, 3011 == file not found if not status.ok: if status.errno == 3011: return False self.provider._check_status( - status, f"Error checking existence of {query} on XRootD" + status, + f"Error checking existence of {self.provider._safe_to_print_url(query)}" + "on XRootD", ) return True @@ -219,46 +315,42 @@ def _exists(self, query) -> bool: def mtime(self) -> float: # return the modification time status, stat = self.file_system.stat(self.path) - self.provider._check_status(status, f"Error checking info of {self.query}") + self.provider._check_status( + status, + f"Error checking info of {self.provider._safe_to_print_url(self.query)}", + ) return stat.modtime @xrootd_retry def size(self) -> int: # return the size in bytes status, stat = self.file_system.stat(self.path) - self.provider._check_status(status, f"Error checking info of {self.query}") + self.provider._check_status( + status, + f"Error checking info of {self.provider._safe_to_print_url(self.query)}", + ) return stat.size - @xrootd_retry + # @xrootd_retry def retrieve_object(self): # Ensure that the object is accessible locally under self.local_path() # check if dir process = client.CopyProcess() - # TODO is special handling for directories needed? - # if stat.flags & client.flags.StatInfoFlags.IS_DIR: - # self.local_path().mkdir(parents=True, exist_ok=True) - # _, listing = self.provider.filesystem_client.dirlist( - # self.path, flags=client.flags.DirListFlags.STAT) - # for item in listing: - # item_path = f"{self.path}/{item.name}" - # if item.statinfo.flags & client.flags.StatInfoFlags.IS_DIR: - - # else: - # process.add_job(self._get_url(item_path), - # str(self.local_path() / item_path), force=True) - # else: - # local path must be an absoulte path as well local_path = os.path.abspath(self.local_path()) - process.add_job(self.query, local_path, force=True) + process.add_job(str(self.url), local_path, force=True) process.prepare() status, returns = process.run() - self.provider._check_status(status, f"Error downloading from {self.query}") self.provider._check_status( - returns[0]["status"], f"Error downloading from {self.query}" + status, + f"Error downloading from {self.provider._safe_to_print_url(self.query)}", + ) + self.provider._check_status( + returns[0]["status"], + f"Error downloading from {self.provider._safe_to_print_url(self.query)}", ) # The following to methods are only required if the class inherits from @@ -267,9 +359,12 @@ def retrieve_object(self): @xrootd_retry def _makedirs(self): if not self._exists(self.get_inventory_parent()): + raise XRootDFatalException() status, _ = self.file_system.mkdir(self.dirname, MkDirFlags.MAKEPATH) self.provider._check_status( - status, f"Error creating directory {self.query}" + status, + "Error creating directory " + f"{self.provider._safe_to_print_url(self.query)}", ) @xrootd_retry @@ -279,12 +374,15 @@ def store_object(self): process = client.CopyProcess() self._makedirs() local_path = os.path.abspath(self.local_path()) - process.add_job(local_path, self.query, force=True) + process.add_job(local_path, str(self.url), force=True) process.prepare() status, returns = process.run() - self.provider._check_status(status, f"Error uploading to {self.query}") self.provider._check_status( - returns[0]["status"], f"Error uploading to {self.query}" + status, f"Error uploading to {self.provider._safe_to_print_url(self.query)}" + ) + self.provider._check_status( + returns[0]["status"], + f"Error uploading to {self.provider._safe_to_print_url(self.query)}", ) @xrootd_retry