Skip to content

Commit

Permalink
Move auth logic to Bfabric class (#79)
Browse files Browse the repository at this point in the history
- Make managing the `BfabricAuth` object the responsibility of the `Bfabric` class. This allows us to implement the contextmanager, which will be useful for the REST-proxy server.
  - Notably this shouldn't change anything about code that already uses the new class.
- Rename `BfabricConfig.with_overrides` to `BfabricConfig.copy_with` for clarity.
- Some unit tests for the Bfabric class, I will add more as I go.
  • Loading branch information
leoschwarz authored May 7, 2024
1 parent 76daee1 commit 2afcd0c
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 71 deletions.
37 changes: 25 additions & 12 deletions bfabric/bfabric2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
The python3 library first appeared in 2014.
"""
import os
import sys
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from enum import Enum
Expand All @@ -45,13 +47,12 @@ class BfabricAPIEngineType(Enum):
ZEEP = 2


def get_system_auth(login: str = None, password: str = None, base_url: str = None, externaljobid=None,
def get_system_auth(login: str = None, password: str = None, base_url: str = None,
config_path: str = None, config_env: str = None, optional_auth: bool = False, verbose: bool = False):
"""
:param login: Login string for overriding config file
:param password: Password for overriding config file
:param base_url: Base server url for overriding config file
:param externaljobid: ?
:param config_path: Path to the config file, in case it is different from default
:param config_env: Which config environment to use. Can also specify via environment variable or use
default in the config file (at your own risk)
Expand All @@ -78,7 +79,7 @@ def get_system_auth(login: str = None, password: str = None, base_url: str = Non
# Load config from file, override some of the fields with the provided ones
else:
config, auth = read_config(config_path, config_env=config_env, optional_auth=optional_auth)
config = config.with_overrides(base_url=base_url)
config = config.copy_with(base_url=base_url)
if (login is not None) and (password is not None):
auth = BfabricAuth(login=login, password=password)
elif (login is None) and (password is None):
Expand Down Expand Up @@ -117,10 +118,10 @@ def __init__(
self._auth = auth

if engine == BfabricAPIEngineType.SUDS:
self.engine = EngineSUDS(auth.login, auth.password, config.base_url)
self.engine = EngineSUDS(base_url=config.base_url)
self.result_type = BfabricResultType.LISTSUDS
elif engine == BfabricAPIEngineType.ZEEP:
self.engine = EngineZeep(auth.login, auth.password, config.base_url)
self.engine = EngineZeep(base_url=config.base_url)
self.result_type = BfabricResultType.LISTZEEP
else:
raise ValueError(f"Unexpected engine: {engine}")
Expand All @@ -142,11 +143,23 @@ def auth(self) -> BfabricAuth:
raise ValueError("Authentication not available")
return self._auth

@contextmanager
def with_auth(self, auth: BfabricAuth):
"""Context manager that temporarily (within the scope of the context) sets the authentication for
the Bfabric object to the provided value. This is useful when authenticating multiple users, to avoid accidental
use of the wrong credentials.
"""
old_auth = self._auth
self._auth = auth
try:
yield
finally:
self._auth = old_auth

def read(self, endpoint: str, obj: dict, max_results: Optional[int] = 100, readid: bool = False, check: bool = True,
**kwargs) -> ResultContainer:
"""
Make a read query to the engine. Determine the number of pages. Make calls for every page, concatenate
results.
"""Reads objects from the specified endpoint that match all specified attributes in `obj`.
By setting `max_results` it is possible to change the number of results that are returned.
:param endpoint: endpoint
:param obj: query dictionary
:param max_results: cap on the number of results to query. The code will keep reading pages until all pages
Expand Down Expand Up @@ -198,14 +211,14 @@ def read(self, endpoint: str, obj: dict, max_results: Optional[int] = 100, readi
return result

def save(self, endpoint: str, obj: dict, check: bool = True, **kwargs) -> ResultContainer:
results = self.engine.save(endpoint, obj, **kwargs)
results = self.engine.save(endpoint, obj, auth=self.auth, **kwargs)
result = ResultContainer(results[endpoint], self.result_type, errors=get_response_errors(results, endpoint))
if check:
result.assert_success()
return result

def delete(self, endpoint: str, id: Union[List, int], check: bool = True) -> ResultContainer:
results = self.engine.delete(endpoint, id)
results = self.engine.delete(endpoint, id, auth=self.auth)
result = ResultContainer(results[endpoint], self.result_type, errors=get_response_errors(results, endpoint))
if check:
result.assert_success()
Expand All @@ -214,9 +227,9 @@ def delete(self, endpoint: str, id: Union[List, int], check: bool = True) -> Res
def _read_method(self, readid: bool, endpoint: str, obj: dict, page: int = 1, **kwargs):
if readid:
# https://fgcz-bfabric.uzh.ch/wiki/tiki-index.php?page=endpoint.workunit#Web_Method_readid_
return self.engine.readid(endpoint, obj, page=page, **kwargs)
return self.engine.readid(endpoint, obj, auth=self.auth, page=page, **kwargs)
else:
return self.engine.read(endpoint, obj, page=page, **kwargs)
return self.engine.read(endpoint, obj, auth=self.auth, page=page, **kwargs)

############################
# Multi-query functionality
Expand Down
4 changes: 2 additions & 2 deletions bfabric/bfabric_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __str__(self):
class BfabricConfig:
"""Holds the configuration for the B-Fabric client for connecting to particular instance of B-Fabric.
Attributes:
Parameters:
base_url (optional): The API base url
application_ids (optional): Map of application names to ids.
job_notification_emails (optional): Space-separated list of email addresses to notify when a job finishes.
Expand Down Expand Up @@ -56,7 +56,7 @@ def job_notification_emails(self) -> str:
"""Space-separated list of email addresses to notify when a job finishes."""
return self._job_notification_emails

def with_overrides(
def copy_with(
self,
base_url: Optional[str] = None,
application_ids: Optional[Dict[str, int]] = None,
Expand Down
89 changes: 46 additions & 43 deletions bfabric/src/engine_suds.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,74 @@
from typing import Union, List
from __future__ import annotations

import copy
from typing import Any

from suds.client import Client
from suds import MethodNotFound
from suds.client import Client
from suds.serviceproxy import ServiceProxy

from bfabric.bfabric_config import BfabricAuth
from bfabric.src.errors import BfabricRequestError


class EngineSUDS:
"""B-Fabric API SUDS Engine"""

def __init__(self, login: str, password: str, base_url: str):
def __init__(self, base_url: str) -> None:
self.cl = {}
self.login = login
self.password = password
self.base_url = base_url

def _get_client(self, endpoint: str):
try:
if endpoint not in self.cl:
wsdl = "".join((self.base_url, '/', endpoint, "?wsdl"))
self.cl[endpoint] = Client(wsdl, cache=None)
return self.cl[endpoint]
except Exception as e:
print(e)
raise

def read(self, endpoint: str, obj: dict, page: int = 1, idonly: bool = False,
includedeletableupdateable: bool = False):
"""
A generic method which can connect to any endpoint, e.g., workunit, project, order,
externaljob, etc, and returns the object with the requested id.
obj is a python dictionary which contains all the attributes of the endpoint
for the "query".
def read(
self,
endpoint: str,
obj: dict[str, Any],
auth: BfabricAuth,
page: int = 1,
idonly: bool = False,
includedeletableupdateable: bool = False,
):
"""Reads the requested `obj` from `endpoint`.
:param endpoint: the endpoint to read, e.g. `workunit`, `project`, `order`, `externaljob`, etc.
:param obj: a python dictionary which contains all the attribute values that have to match
:param auth: the authentication handle of the user performing the request
:param page: the page number to read
:param idonly: whether to return only the ids of the objects
:param includedeletableupdateable: TODO
"""
query = copy.deepcopy(obj)
query['includedeletableupdateable'] = includedeletableupdateable

full_query = dict(login=self.login, page=page, password=self.password, query=query,
idonly=idonly)
query["includedeletableupdateable"] = includedeletableupdateable

client = self._get_client(endpoint)
return client.service.read(full_query)
full_query = dict(login=auth.login, page=page, password=auth.password, query=query, idonly=idonly)
service = self._get_suds_service(endpoint)
return service.read(full_query)

# TODO: How is client.service.readid different from client.service.read. Do we need this method?
def readid(self, endpoint: str, obj: dict, page: int = 1):
query = dict(login=self.login, page=page, password=self.password, query=obj)

client = self._get_client(endpoint)
return client.service.readid(query)

def save(self, endpoint: str, obj: dict):
query = {'login': self.login, 'password': self.password, endpoint: obj}

client = self._get_client(endpoint)
def readid(self, endpoint: str, obj: dict, auth: BfabricAuth, page: int = 1):
query = dict(login=auth.login, page=page, password=auth.password, query=obj)
service = self._get_suds_service(endpoint)
return service.readid(query)

def save(self, endpoint: str, obj: dict, auth: BfabricAuth):
query = {"login": auth.login, "password": auth.password, endpoint: obj}
service = self._get_suds_service(endpoint)
try:
res = client.service.save(query)
res = service.save(query)
except MethodNotFound as e:
raise BfabricRequestError(f"SUDS failed to find save method for the {endpoint} endpoint.") from e
return res

def delete(self, endpoint: str, id: Union[int, List]):
def delete(self, endpoint: str, id: int | list[int], auth: BfabricAuth):
if isinstance(id, list) and len(id) == 0:
print("Warning, attempted to delete an empty list, ignoring")
return []

query = {'login': self.login, 'password': self.password, 'id': id}
query = {"login": auth.login, "password": auth.password, "id": id}
service = self._get_suds_service(endpoint)
return service.delete(query)

client = self._get_client(endpoint)
return client.service.delete(query)
def _get_suds_service(self, endpoint: str) -> ServiceProxy:
"""Returns a SUDS service for the given endpoint. Reuses existing instances when possible."""
if endpoint not in self.cl:
wsdl = "".join((self.base_url, "/", endpoint, "?wsdl"))
self.cl[endpoint] = Client(wsdl, cache=None)
return self.cl[endpoint].service
19 changes: 9 additions & 10 deletions bfabric/src/engine_zeep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import zeep
import copy

from bfabric.bfabric_config import BfabricAuth
from bfabric.src.errors import BfabricRequestError


Expand All @@ -28,10 +29,8 @@ def _zeep_query_append_skipped(query: dict, skipped_keys: list, inplace: bool =
class EngineZeep:
"""B-Fabric API Zeep Engine"""

def __init__(self, login: str, password: str, base_url: str):
def __init__(self, base_url: str):
self.cl = {}
self.login = login
self.password = password
self.base_url = base_url

def _get_client(self, endpoint: str):
Expand All @@ -44,7 +43,7 @@ def _get_client(self, endpoint: str):
print(e)
raise

def read(self, endpoint: str, obj: dict, page: int = 1, idonly: bool = False,
def read(self, endpoint: str, obj: dict, auth: BfabricAuth, page: int = 1, idonly: bool = False,
includedeletableupdateable: bool = False):
query = copy.deepcopy(obj)
query['includedeletableupdateable'] = includedeletableupdateable
Expand All @@ -55,17 +54,17 @@ def read(self, endpoint: str, obj: dict, page: int = 1, idonly: bool = False,
'includechildren', 'includeparents', 'includereplacements']
_zeep_query_append_skipped(query, excl_keys, inplace=True, overwrite=False)

full_query = dict(login=self.login, page=page, password=self.password, query=query, idonly=idonly)
full_query = dict(login=auth.login, page=page, password=auth.password, query=query, idonly=idonly)

client = self._get_client(endpoint)
with client.settings(strict=False, xml_huge_tree=True, xsd_ignore_sequence_order=True):
return client.service.read(full_query)

def readid(self, endpoint: str, obj: dict, page: int = 1, includedeletableupdateable: bool = True):
def readid(self, endpoint: str, obj: dict, auth: BfabricAuth, page: int = 1, includedeletableupdateable: bool = True):
raise NotImplementedError("Attempted to use a method `readid` of Zeep, which does not exist")

def save(self, endpoint: str, obj: dict, skipped_keys: list = None):
query = {'login': self.login, 'password': self.password, endpoint: obj}
def save(self, endpoint: str, obj: dict, auth: BfabricAuth, skipped_keys: list = None):
query = {'login': auth.login, 'password': auth.password, endpoint: obj}

# If necessary, add skipped keys to the query
if skipped_keys is not None:
Expand All @@ -82,12 +81,12 @@ def save(self, endpoint: str, obj: dict, skipped_keys: list = None):
raise e
return res

def delete(self, endpoint: str, id: Union[int, List]):
def delete(self, endpoint: str, id: Union[int, List], auth: BfabricAuth):
if isinstance(id, list) and len(id) == 0:
print("Warning, attempted to delete an empty list, ignoring")
return []

query = {'login': self.login, 'password': self.password, 'id': id}
query = {'login': auth.login, 'password': auth.password, 'id': id}

client = self._get_client(endpoint)
return client.service.delete(query)
59 changes: 59 additions & 0 deletions bfabric/tests/unit/test_bfabric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import unittest
from functools import cached_property
from unittest.mock import MagicMock

from bfabric import BfabricConfig
from bfabric.bfabric2 import BfabricAPIEngineType, Bfabric
from bfabric.src.engine_suds import EngineSUDS


class TestBfabric(unittest.TestCase):
def setUp(self):
self.mock_config = MagicMock(name="mock_config", spec=BfabricConfig)
self.mock_auth = None
self.mock_engine_type = BfabricAPIEngineType.SUDS
self.mock_engine = MagicMock(name="mock_engine", spec=EngineSUDS)

@cached_property
def mock_bfabric(self) -> Bfabric:
return Bfabric(config=self.mock_config, auth=self.mock_auth, engine=self.mock_engine_type)

def test_query_counter(self):
self.assertEqual(0, self.mock_bfabric.query_counter)

def test_config(self):
self.assertEqual(self.mock_config, self.mock_bfabric.config)

def test_auth_when_missing(self):
with self.assertRaises(ValueError) as error:
_ = self.mock_bfabric.auth
self.assertIn("Authentication not available", str(error.exception))

def test_auth_when_provided(self):
self.mock_auth = MagicMock(name="mock_auth")
self.assertEqual(self.mock_auth, self.mock_bfabric.auth)

def test_with_auth(self):
mock_old_auth = MagicMock(name="mock_old_auth")
mock_new_auth = MagicMock(name="mock_new_auth")
self.mock_auth = mock_old_auth
with self.mock_bfabric.with_auth(mock_new_auth):
self.assertEqual(mock_new_auth, self.mock_bfabric.auth)
self.assertEqual(mock_old_auth, self.mock_bfabric.auth)

def test_with_auth_when_exception(self):
mock_old_auth = MagicMock(name="mock_old_auth")
mock_new_auth = MagicMock(name="mock_new_auth")
self.mock_auth = mock_old_auth
try:
with self.mock_bfabric.with_auth(mock_new_auth):
raise ValueError("Test exception")
except ValueError:
pass
self.assertEqual(mock_old_auth, self.mock_bfabric.auth)

# TODO further unit tests


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions bfabric/tests/unit/test_bfabric_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_default_params_when_specified(self):
self.assertEqual({}, config.application_ids)
self.assertEqual("", config.job_notification_emails)

def test_with_overrides(self):
new_config = self.config.with_overrides(
def test_copy_with_overrides(self):
new_config = self.config.copy_with(
base_url="new_url",
application_ids={"new": 2},
)
Expand All @@ -47,8 +47,8 @@ def test_with_overrides(self):
self.assertEqual("url", self.config.base_url)
self.assertEqual({"app": 1}, self.config.application_ids)

def test_with_replaced_when_none(self):
new_config = self.config.with_overrides(base_url=None, application_ids=None)
def test_copy_with_replaced_when_none(self):
new_config = self.config.copy_with(base_url=None, application_ids=None)
self.assertEqual("url", new_config.base_url)
self.assertEqual({"app": 1}, new_config.application_ids)
self.assertEqual("url", self.config.base_url)
Expand Down

0 comments on commit 2afcd0c

Please sign in to comment.