Skip to content

Commit

Permalink
some types
Browse files Browse the repository at this point in the history
  • Loading branch information
rcarpa committed Jan 18, 2024
1 parent f21e8a1 commit d506518
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 19 deletions.
5 changes: 4 additions & 1 deletion lib/rucio/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, TypedDict, Union
from typing import Any, Callable, Optional, TypedDict, Union


class InternalType(object):
Expand Down Expand Up @@ -104,6 +104,9 @@ def __init__(self, scope, vo='def', fromExternal=True):
super(InternalScope, self).__init__(value=scope, vo=vo, fromExternal=fromExternal)


LoggerFunction = Callable[..., Any]


class RSEDomainLANDict(TypedDict):
read: Optional[int]
write: Optional[int]
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def merkle_sha256(file) -> str:
CHECKSUM_ALGO_DICT['merkle_sha256'] = merkle_sha256


def bencode(obj):
def bencode(obj) -> bytes:
"""
Copied from the reference implementation of v2 bittorrent:
http://bittorrent.org/beps/bep_0052_torrent_creator.py
Expand Down
14 changes: 8 additions & 6 deletions lib/rucio/transfertool/bittorrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from os import path
from typing import TYPE_CHECKING, Optional, Sequence, Type

from rucio.common import types
from rucio.common.config import config_get
from rucio.common.extra import import_extras
from rucio.common.utils import construct_torrent
Expand All @@ -27,6 +28,7 @@

if TYPE_CHECKING:
from rucio.core.rse import RseData
from rucio.core import transfer as transfer_core

DRIVER_NAME_RSE_ATTRIBUTE = 'bittorrent_driver'
DRIVER_CLASSES_BY_NAME: dict[str, Type[BittorrentDriver]] = {}
Expand All @@ -46,7 +48,7 @@ class BittorrentTransfertool(Transfertool):

required_rse_attrs = (DRIVER_NAME_RSE_ATTRIBUTE, )

def __init__(self, external_host, logger=logging.log):
def __init__(self, external_host: str, logger: types.LoggerFunction = logging.log):
super().__init__(external_host=external_host, logger=logger)

self._drivers_by_rse_id = {}
Expand All @@ -55,7 +57,7 @@ def __init__(self, external_host, logger=logging.log):
self.tracker = config_get('transfers', 'bittorrent-tracker-addr', raise_exception=False, default=None)

@classmethod
def _pick_management_api_driver_cls(cls, rse) -> Optional[Type[BittorrentDriver]]:
def _pick_management_api_driver_cls(cls, rse: "RseData") -> Optional[Type[BittorrentDriver]]:
driver_cls = DRIVER_CLASSES_BY_NAME.get(rse.attributes.get(DRIVER_NAME_RSE_ATTRIBUTE))
if driver_cls is None:
return None
Expand All @@ -77,15 +79,15 @@ def _driver_for_rse(self, rse: "RseData") -> Optional[BittorrentDriver]:
return driver

@staticmethod
def _get_torrent_meta(scope, name):
def _get_torrent_meta(scope: "types.InternalScope", name: str):
meta = get_metadata(scope=scope, name=name, plugin='all')
pieces_root = base64.b64decode(meta.get('bittorrent_pieces_root', ''))
pieces_layers = base64.b64decode(meta.get('bittorrent_pieces_layers', ''))
piece_length = meta.get('bittorrent_piece_length', 0)
return pieces_root, pieces_layers, piece_length

@classmethod
def submission_builder_for_path(cls, transfer_path, logger=logging.log):
def submission_builder_for_path(cls, transfer_path: "list[transfer_core.DirectTransferDefinitions]", logger=logging.log):
hop = transfer_path[0]
if hop.rws.byte_count == 0:
logger(logging.INFO, f"Bittorrent cannot transfer fully empty torrents. Skipping {hop}")
Expand All @@ -108,7 +110,7 @@ def submission_builder_for_path(cls, transfer_path, logger=logging.log):

return [hop], TransferToolBuilder(cls, external_host='Bittorrent Transfertool')

def group_into_submit_jobs(self, transfer_paths):
def group_into_submit_jobs(self, transfer_paths: "Sequence[list[transfer_core.DirectTransferDefinitions]]"):
return [{'transfers': transfer_path, 'job_params': {}} for transfer_path in transfer_paths]

@staticmethod
Expand All @@ -120,7 +122,7 @@ def _connect_without_tracker(torrent_id, peers_drivers: Sequence[BittorrentDrive
for driver in peers_drivers:
driver.add_peers(torrent_id=torrent_id, peers=peer_addr)

def submit(self, transfers, job_params, timeout=None):
def submit(self, transfers: "Sequence[transfer_core.DirectTransferDefinitions]", job_params, timeout=None):
[transfer] = transfers
rws = transfer.rws

Expand Down
9 changes: 6 additions & 3 deletions lib/rucio/transfertool/bittorrent_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
from typing import Optional

from rucio.core.rse import RseData
from rucio.transfertool.transfertool import TransferStatusReport


class BittorrentDriver(metaclass=ABCMeta):
Expand All @@ -27,7 +30,7 @@ class BittorrentDriver(metaclass=ABCMeta):

@classmethod
@abstractmethod
def make_driver(cls, rse: "RseData", logger=logging.log):
def make_driver(cls, rse: "RseData", logger=logging.log) -> "Optional[BittorrentDriver]":
pass

@abstractmethod
Expand All @@ -39,13 +42,13 @@ def management_addr(self) -> tuple[str, int]:
pass

@abstractmethod
def add_torrent(self, file_name: str, file_content: bytes, download_location: str, seed_mode: bool = False):
def add_torrent(self, file_name: str, file_content: bytes, download_location: str, seed_mode: bool = False) -> None:
pass

@abstractmethod
def add_peers(self, torrent_id: str, peers: Sequence[tuple[str, int]]):
pass

@abstractmethod
def get_status(self, request_id: str, torrent_id: str):
def get_status(self, request_id: str, torrent_id: str) -> "TransferStatusReport":
pass
18 changes: 10 additions & 8 deletions lib/rucio/transfertool/bittorrent_driver_qbittorrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import qbittorrentapi

from rucio.common import types
from rucio.common.config import get_rse_credentials
from rucio.common.utils import resolve_ip
from rucio.core.oidc import request_token
Expand All @@ -27,6 +28,7 @@
from .bittorrent_driver import BittorrentDriver

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from rucio.core.rse import RseData


Expand All @@ -37,7 +39,7 @@ class QBittorrentTransferStatusReport(TransferStatusReport):
'external_id',
]

def __init__(self, request_id, external_id, qbittorrent_response: Optional[qbittorrentapi.TorrentDictionary]):
def __init__(self, request_id: str, external_id: str, qbittorrent_response: Optional[qbittorrentapi.TorrentDictionary]):
super().__init__(request_id)

if qbittorrent_response and qbittorrent_response.state_enum.is_complete == 1:
Expand All @@ -50,10 +52,10 @@ def __init__(self, request_id, external_id, qbittorrent_response: Optional[qbitt
if new_state in [RequestState.FAILED, RequestState.DONE]:
self.external_id = external_id

def initialize(self, session, logger=logging.log):
def initialize(self, session: "Session", logger: types.LoggerFunction = logging.log) -> None:
pass

def get_monitor_msg_fields(self, session, logger=logging.log):
def get_monitor_msg_fields(self, session: "Session", logger: types.LoggerFunction = logging.log):
return {'protocol': 'qbittorrent'}


Expand All @@ -63,7 +65,7 @@ class QBittorrentDriver(BittorrentDriver):
required_rse_attrs = ('qbittorrent_management_address', )

@classmethod
def make_driver(cls, rse: "RseData", logger=logging.log):
def make_driver(cls, rse: "RseData", logger=logging.log) -> "Optional[BittorrentDriver]":

address = rse.attributes.get('qbittorrent_management_address')
if not address:
Expand Down Expand Up @@ -91,7 +93,7 @@ def make_driver(cls, rse: "RseData", logger=logging.log):
logger=logger,
)

def __init__(self, address, username, password, token=None, logger=logging.log):
def __init__(self, address: str, username: str, password: str, token: Optional[str] = None, logger: types.LoggerFunction = logging.log):
extra_headers = None
if token:
extra_headers = {'Authorization': 'Bearer ' + token}
Expand All @@ -114,7 +116,7 @@ def listen_addr(self) -> tuple[str, int]:
def management_addr(self) -> tuple[str, int]:
return self.client.host, self.client.port

def add_torrent(self, file_name: str, file_content: bytes, download_location: str, seed_mode: bool = False):
def add_torrent(self, file_name: str, file_content: bytes, download_location: str, seed_mode: bool = False) -> None:
self.client.torrents_add(
rename=file_name,
torrent_files=file_content,
Expand All @@ -123,9 +125,9 @@ def add_torrent(self, file_name: str, file_content: bytes, download_location: st
is_sequential_download=True,
)

def add_peers(self, torrent_id: str, peers: Sequence[tuple[str, int]]):
def add_peers(self, torrent_id: str, peers: Sequence[tuple[str, int]]) -> None:
self.client.torrents_add_peers(torrent_hashes=[torrent_id], peers=[f'{ip}:{port}' for ip, port in peers])

def get_status(self, request_id: str, torrent_id: str):
def get_status(self, request_id: str, torrent_id: str) -> TransferStatusReport:
info = self.client.torrents_info(torrent_hashes=[torrent_id])
return QBittorrentTransferStatusReport(request_id, external_id=torrent_id, qbittorrent_response=info[0] if info else None)

0 comments on commit d506518

Please sign in to comment.