From c4f12d0feca65f97bcb6f39b35e8530414ca8196 Mon Sep 17 00:00:00 2001 From: parth-kulkarni1 Date: Fri, 26 Apr 2024 06:31:10 +0000 Subject: [PATCH] Initial structure of everything so far. I have been working on fleshing out the Auth interface as suggested by Peter, and made some progress. --- poetry.lock | 188 ++++- pyproject.toml | 3 + src/provenaclient/clients/README.md | 1 + src/provenaclient/clients/RegistryClient.py | 54 ++ src/provenaclient/modules/ProvenaClient.py | 10 + src/provenaclient/modules/README.md | 1 + src/provenaclient/provena_client.py | 16 +- src/provenaclient/utils/Auth.py | 261 ++++++ src/provenaclient/utils/AuthManager.py | 41 + src/provenaclient/utils/Config.py | 7 + src/provenaclient/utils/TokenManager.py | 893 ++++++++++++++++++++ src/provenaclient/utils/httpClient.py | 16 + src/provenaclient/utils/test.py | 22 + 13 files changed, 1499 insertions(+), 14 deletions(-) create mode 100644 src/provenaclient/clients/RegistryClient.py create mode 100644 src/provenaclient/modules/ProvenaClient.py create mode 100644 src/provenaclient/utils/Auth.py create mode 100644 src/provenaclient/utils/AuthManager.py create mode 100644 src/provenaclient/utils/Config.py create mode 100644 src/provenaclient/utils/TokenManager.py create mode 100644 src/provenaclient/utils/httpClient.py create mode 100644 src/provenaclient/utils/test.py diff --git a/poetry.lock b/poetry.lock index 1cd6690..6d5767f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -33,6 +33,28 @@ files = [ {file = "anyascii-0.3.2.tar.gz", hash = "sha256:9d5d32ef844fe225b8bc7cba7f950534fae4da27a9bf3a6bea2cb0ea46ce4730"}, ] +[[package]] +name = "anyio" +version = "4.3.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "appnope" version = "0.1.4" @@ -489,6 +511,24 @@ files = [ {file = "dotty_dict-1.3.1.tar.gz", hash = "sha256:4b016e03b8ae265539757a53eba24b9bfda506fb94fbce0bee843c6f05541a15"}, ] +[[package]] +name = "ecdsa" +version = "0.19.0" +description = "ECDSA cryptographic signature library (pure python)" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.6" +files = [ + {file = "ecdsa-0.19.0-py2.py3-none-any.whl", hash = "sha256:2cea9b88407fdac7bbeca0833b189e4c9c53f2ef1e1eaa29f6224dbc809b707a"}, + {file = "ecdsa-0.19.0.tar.gz", hash = "sha256:60eaad1199659900dd0af521ed462b793bbdf867432b3948e87416ae4caf6bf8"}, +] + +[package.dependencies] +six = ">=1.9.0" + +[package.extras] +gmpy = ["gmpy"] +gmpy2 = ["gmpy2"] + [[package]] name = "exceptiongroup" version = "1.2.0" @@ -634,6 +674,62 @@ files = [ docs = ["Sphinx", "furo"] test = ["objgraph", "psutil"] +[[package]] +name = "h11" +version = 
"0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.6" @@ -1355,6 +1451,17 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyasn1" +version = "0.6.0" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, + {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -1476,6 +1583,25 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.2.1" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_settings-2.2.1-py3-none-any.whl", hash = "sha256:0235391d26db4d2190cb9b31051c4b46882d28a51533f97440867f012d4da091"}, + {file = "pydantic_settings-2.2.1.tar.gz", hash = "sha256:00b9f6a5e95553590434c0fa01ead0b216c3e10bc54ae02e37f359948643c5ed"}, +] + +[package.dependencies] +pydantic = ">=2.3.0" +python-dotenv = ">=0.21.0" + +[package.extras] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pygments" version = "2.17.2" @@ -1545,6 +1671,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, 
+] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "python-gitlab" version = "4.4.0" @@ -1564,6 +1704,27 @@ requests-toolbelt = ">=0.10.1" autocompletion = ["argcomplete (>=1.10.0,<3)"] yaml = ["PyYaml (>=6.0.1)"] +[[package]] +name = "python-jose" +version = "3.3.0" +description = "JOSE implementation in Python" +optional = false +python-versions = "*" +files = [ + {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"}, + {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"}, +] + +[package.dependencies] +ecdsa = "!=0.15" +pyasn1 = "*" +rsa = "*" + +[package.extras] +cryptography = ["cryptography (>=3.4.0)"] +pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] +pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] + [[package]] name = "python-semantic-release" version = "9.4.0" @@ -1959,6 +2120,20 @@ files = [ {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"}, ] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "shellingham" version = "1.5.4" @@ -1992,6 +2167,17 @@ files = [ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -2423,4 +2609,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2109bf269fba741dcd88ce4e509d31f3404bb6d1b515507f5c4a6e1fb32733e5" +content-hash = "352b9b21b0f51dee822299eaacfc7ea7b6c59c45525095107a0e67a9c2dd601c" diff --git a/pyproject.toml b/pyproject.toml index ebcd310..e14f3cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,9 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.9" +httpx = "^0.27.0" +pydantic-settings = "^2.2.1" +python-jose = "^3.3.0" [tool.poetry.dev-dependencies] mypy = "^1.9.0" diff --git a/src/provenaclient/clients/README.md b/src/provenaclient/clients/README.md index e69de29..4fbe7fc 100644 --- a/src/provenaclient/clients/README.md +++ b/src/provenaclient/clients/README.md @@ -0,0 +1 @@ +registry client, provena client. 
\ No newline at end of file
diff --git a/src/provenaclient/clients/RegistryClient.py b/src/provenaclient/clients/RegistryClient.py
new file mode 100644
index 0000000..d029b50
--- /dev/null
+++ b/src/provenaclient/clients/RegistryClient.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict, Tuple
+
+from ..utils.AuthManager import AuthManager
+from ..utils.Config import Config
+from ..utils.httpClient import HttpClient
+
+class RegistryClient:
+
+    # Maps a (registry action, item subtype) pair to its endpoint path.
+    # TODO: populate with the real registry API routes.
+    URL_MAP: Dict[Tuple[str, str], str] = {}
+
+    auth: AuthManager
+    config: Config
+
+    def __init__(self, auth: AuthManager, config: Config) -> None:
+        self.auth = auth
+        self.config = config
+
+    async def fetch_item(self, item_id: str, entity_subtype: str) -> Dict[str, Any]:
+        # TODO: replace the str/dict placeholders with the shared registry
+        # models (typed request/response classes) once they are wired in.
+
+        # Resolve the endpoint for this action/subtype combination.
+        try:
+            url = self.config.registry_api_endpoint + \
+                self.URL_MAP[("FETCH_ITEM", entity_subtype)]
+        except KeyError:
+            raise ValueError(
+                f"No fetch endpoint registered for subtype {entity_subtype}.")
+
+        # Build bearer auth headers from the configured auth flow.
+        headers = {"Authorization": f"Bearer {self.auth.get_token()}"}
+
+        try:
+            # Retrieve the response object.
+            # TODO: confirm the query parameters the registry API expects.
+            response = await HttpClient.make_get_request(
+                url=url, params={"id": item_id}, headers=headers)
+        except Exception as e:
+            raise ValueError(f"Failed to fetch item: {e}")
+
+        # Handle the HTTP/application level errors here.
+        if response.status_code == 401:
+            # TODO: define custom exceptions for common scenarios e.g. auth.
+            raise Exception("Authorisation exception while fetching item.")
+        if response.status_code >= 500:
+            raise Exception(f"Unexpected server error: {response.status_code}.")
+        if response.status_code != 200:
+            raise Exception(f"Fetch item failed with status {response.status_code}.")
+
+        # Parse as JSON.
+        json_response = response.json()
+
+        # TODO: validate against the typed response model and check its status
+        # field, raising a dedicated exception on failure.
+        return json_response
diff --git a/src/provenaclient/modules/ProvenaClient.py b/src/provenaclient/modules/ProvenaClient.py
new file mode 100644
index 0000000..55c2a29
--- /dev/null
+++ b/src/provenaclient/modules/ProvenaClient.py
@@ -0,0 +1,10 @@
+from ..utils.AuthManager import AuthManager
+from ..utils.Config import Config
+from ..clients.RegistryClient import RegistryClient
+
+class ProvenaClient:
+    auth: AuthManager
+    config: Config
+
+    _registry_client: RegistryClient
+
diff --git a/src/provenaclient/modules/README.md b/src/provenaclient/modules/README.md
index e69de29..adc528d 100644
--- a/src/provenaclient/modules/README.md
+++ b/src/provenaclient/modules/README.md
@@ -0,0 +1 @@
+provenaclient, registry.
\ No newline at end of file
diff --git a/src/provenaclient/provena_client.py b/src/provenaclient/provena_client.py
index 89ac0b6..50dc211 100644
--- a/src/provenaclient/provena_client.py
+++ b/src/provenaclient/provena_client.py
@@ -1,15 +1,5 @@
-from dataclasses import dataclass
+from .utils.Config import Config
-@dataclass
-class Settings:
-    domain: str
-
-class Auth:
-    def __init__(self, settings: Settings):
-        self.settings = settings
-
-class Client:
-    def __init__(self, auth: Auth, settings: Settings):
-        self.auth = auth
-        self.settings = settings
+
+# TODO: fill in the endpoints for a real Provena deployment (empty strings are placeholders only).
+config = Config(registry_api_endpoint="", datastore_api_endpoint="")
\ No newline at end of file
diff --git a/src/provenaclient/utils/Auth.py b/src/provenaclient/utils/Auth.py
new file mode 100644
index 0000000..04fa0d9
--- /dev/null
+++ b/src/provenaclient/utils/Auth.py
@@ -0,0 +1,261 @@
+# Take a look at migrating the mdis-client-tools into here and see what's possible.
+
+from typing import Any, Dict, List, Optional
+from .AuthManager import AuthManager
+import requests
+import webbrowser
+import time
+from pydantic import BaseModel
+import os
+import json
+from jose import jwt, JWTError  # type: ignore
+from jose.constants import ALGORITHMS  # type: ignore
+'''
+This file will contain the various authentication flows (Device, Offline, etc.) as separate classes.
+
+Currently working on developing the device flow.
+
+In this implementation I have assumed that the user will not worry about where tokens are placed, so the cache location is not configurable.
+
+Furthermore, users probably won't have stages, as they will just be interacting with an instance of Provena that is on PROD.
+'''
+
+
+class Tokens(BaseModel):
+    access_token: str
+    # refresh tokens are marked as optional because offline tokens should not be cached
+    refresh_token: Optional[str]
+
+
+class DeviceFlow(AuthManager):
+    def __init__(self, keycloak_endpoint: str, client_id: str) -> None:
+        self.keycloak_endpoint = keycloak_endpoint
+        self.client_id = client_id
+        self.scopes: List[str] = []
+        self.device_endpoint = f"{self.keycloak_endpoint}/protocol/openid-connect/auth/device"
+        self.token_endpoint = f"{self.keycloak_endpoint}/protocol/openid-connect/token"
+
+    def init(self) -> None:
+
+        # First, obtain the keycloak public key.
+        self.retrieve_keycloak_public_key()
+
+        # Next, check whether a tokens.json file is already present.
+        # If it is, validate the cached tokens and refresh them if validation
+        # fails; otherwise start a fresh device flow.
+
+        if os.path.exists('tokens.json'):
+            self.tokens = self.load_tokens()
+
+            if self.tokens and self.validate_token(self.tokens):
+                print("Using cached tokens...")
+            else:
+                self.refresh()
+
+        else:
+            print("No cached tokens found, starting device flow.")
+            self.start_device_flow()
+
+    def retrieve_keycloak_public_key(self) -> None:
+        """Given the keycloak endpoint, retrieves the advertised
+        public key.
+        Based on https://github.com/nurgasemetey/fastapi-keycloak-oidc/blob/main/main.py
+        """
+        error_message = f"Error finding public key from keycloak endpoint {self.keycloak_endpoint}."
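+        # The realm endpoint advertises its public key in the realm metadata;
+        # it is cached on the instance and later used by validate_token() to
+        # verify token signatures locally.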
+        try:
+            r = requests.get(self.keycloak_endpoint,
+                             timeout=3)
+            r.raise_for_status()
+            response_json = r.json()
+            self.public_key = f"-----BEGIN PUBLIC KEY-----\r\n{response_json['public_key']}\r\n-----END PUBLIC KEY-----"
+        except requests.exceptions.HTTPError as errh:
+            print(error_message)
+            print("HTTP Error: " + str(errh))
+            raise errh
+        except requests.exceptions.ConnectionError as errc:
+            print(error_message)
+            print("Error Connecting: " + str(errc))
+            raise errc
+        except requests.exceptions.Timeout as errt:
+            print(error_message)
+            print("Timeout Error: " + str(errt))
+            raise errt
+        except requests.exceptions.RequestException as err:
+            print(error_message)
+            print("An unknown error occurred: " + str(err))
+            raise err
+
+    def load_tokens(self) -> Optional[Tokens]:
+
+        try:
+            with open('tokens.json', 'r') as file:
+                token_data = json.load(file)
+                return Tokens(**token_data)
+        except Exception as e:
+            print(f"Failed to load tokens: {e}")
+            return None
+
+    def save_tokens(self, tokens: Tokens) -> None:
+
+        with open('tokens.json', 'w') as file:
+            json.dump(tokens.model_dump(), file)
+            print("Tokens saved to file successfully.")
+
+    def validate_token(self, tokens: Tokens) -> bool:
+        """Uses the python-jose library to validate current creds.
+
+        In this context, it is basically just checking signature
+        and expiry. The tokens are enforced at the API side
+        as well.
+
+        Parameters
+        ----------
+        tokens : Tokens
+            The tokens object to validate
+        """
+
+        try:
+            jwt.decode(
+                tokens.access_token,
+                self.public_key,
+                algorithms=[ALGORITHMS.RS256],
+                options={
+                    "verify_signature": True,
+                    "verify_aud": False,
+                    "exp": True
+                }
+            )
+
+            print("Token is valid.")
+            return True
+
+        except JWTError as e:
+            print(f"Token validation error: {e}")
+            return False
+
+    def start_device_flow(self) -> None:
+
+        data = {
+            "client_id": self.client_id,
+            "scope": ' '.join(self.scopes)
+        }
+
+        response = requests.post(self.device_endpoint, data=data)
+
+        if response.status_code == 200:
+            response_data = response.json()
+            self.device_code = response_data.get('device_code')
+            self.interval = response_data.get('interval')
+            verification_url = response_data.get('verification_uri_complete')
+            user_code = response_data.get('user_code')
+            print(f"User code: {user_code}")
+            webbrowser.open(verification_url)
+            self.handle_auth_flow()
+
+        else:
+            raise Exception("Failed to initiate device flow auth.")
+
+    def handle_auth_flow(self) -> None:
+
+        device_grant_type = "urn:ietf:params:oauth:grant-type:device_code"
+
+        data = {
+            "grant_type": device_grant_type,
+            "device_code": self.device_code,
+            "client_id": self.client_id,
+            "scope": " ".join(self.scopes)
+        }
+
+        # Setup success criteria
+        succeeded = False
+        timed_out = False
+        misc_fail = False
+
+        response_data: Optional[Dict[str, Any]] = None
+
+        # get requests session for repeated queries
+        session = requests.session()
+
+        # Poll for success
+        while not succeeded and not timed_out and not misc_fail:
+            response = session.post(self.token_endpoint, data=data)
+            response_data = response.json()
+            assert response_data
+            if response_data.get('error'):
+                error = response_data['error']
+                if error != 'authorization_pending':
+                    misc_fail = True
+                # Wait appropriate OAuth poll interval
+                time.sleep(self.interval)
+            else:
+                # Successful as there was no error at the endpoint.
+                # Produce a token object here and cache it.
+                self.tokens = Tokens(
+                    access_token=response_data.get('access_token'),
+                    refresh_token=response_data.get('refresh_token')
+                )
+
+                # Save the tokens into 'tokens.json'
+                self.save_tokens(self.tokens)
+                succeeded = True
+
+        if not succeeded:
+            try:
+                assert response_data
+                print(f"Failed due to {response_data['error']}")
+            except Exception as e:
+                print(
+                    f"Failed with unknown error, failed to find error message. Error {e}")
+
+    def refresh(self):
+        print("In refresh method.")
+
+    def get_token(self):
+        pass
+
+    def get_auth(self):
+        pass
+
+    def force_refresh(self):
+        pass
+
+
+class OfflineFlow(AuthManager):
+    def __init__(self):
+        pass
+
+    def init(self):
+        pass
+
+    def refresh(self):
+        pass
+
+    def get_token(self):
+        pass
+
+    def force_refresh(self):
+        pass
diff --git a/src/provenaclient/utils/AuthManager.py b/src/provenaclient/utils/AuthManager.py
new file mode 100644
index 0000000..e844580
--- /dev/null
+++ b/src/provenaclient/utils/AuthManager.py
@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+
+class AuthManager(ABC):
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def refresh(self):
+        """Refresh the current token."""
+        pass
+
+    @abstractmethod
+    def force_refresh(self):
+        """Force refresh the current token."""
+        pass
+
+    @abstractmethod
+    def get_token(self):
+        """Get token information and other metadata."""
+        pass
+
+    @abstractmethod
+    def get_auth(self):
+        """Get the auth object."""
+        pass
+
+    @abstractmethod
+    def handle_auth_flow(self):
+        """Handle any user interactions required for the auth flow."""
+        pass
+
+    @abstractmethod
+    def validate_token(self):
+        """Validate the token, checking expiry and credentials."""
+        pass
diff --git a/src/provenaclient/utils/Config.py b/src/provenaclient/utils/Config.py
new file mode 100644
index 0000000..40bf737
--- /dev/null
+++ b/src/provenaclient/utils/Config.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+class Config(BaseModel):
+
+    # This will contain information about your Provena domain.
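+    # Absolute base URLs for the target deployment's registry and datastore
+    # APIs; both values are supplied by the user when constructing the client.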
+ registry_api_endpoint: str + datastore_api_endpoint: str diff --git a/src/provenaclient/utils/TokenManager.py b/src/provenaclient/utils/TokenManager.py new file mode 100644 index 0000000..df7bf50 --- /dev/null +++ b/src/provenaclient/utils/TokenManager.py @@ -0,0 +1,893 @@ +import webbrowser +import requests +import time +from jose import jwt # type: ignore +from jose.constants import ALGORITHMS # type: ignore +from pydantic import BaseModel +from enum import Enum +from typing import Dict, Optional, List, Any +import json + + +class StorageType(str, Enum): + FILE = "FILE" + OBJECT = "OBJECT" + +# For usage in requests library + + +class BearerAuth(requests.auth.AuthBase): + def __init__(self, token: str): + self.token = token + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + r.headers["authorization"] = "Bearer " + self.token + return r + + +# Model for storing and serialising tokens +class Stage(str, Enum): + TEST = "TEST" + DEV = "DEV" + STAGE = "STAGE" + PROD = "PROD" + + +class Tokens(BaseModel): + access_token: str + # refresh tokens are marked as optional because offline tokens should not be cached + refresh_token: Optional[str] + + +class StageTokens(BaseModel): + stages: Dict[Stage, Optional[Tokens]] + + +LOCAL_STORAGE_DEFAULT = ".tokens.json" +DEFAULT_CLIENT_ID = "client-tools" + + +class AuthFlow(str, Enum): + DEVICE = "DEVICE" + OFFLINE = "OFFLINE" + + +class DeviceFlowManager: + def __init__( + self, + stage: str, + keycloak_endpoint: str, + auth_flow: AuthFlow = AuthFlow.DEVICE, + offline_token: Optional[str] = None, + client_id: str = DEFAULT_CLIENT_ID, + local_storage_location: Optional[str] = None, + local_storage_object: Optional[Dict[str, Any]] = None, + scopes: List[str] = [], + force_token_refresh: bool = False, + silent: bool = False + ) -> None: + """Generates a manager class. This manager class uses the + OAuth device authorisation flow to generate credentials on a per + application stage basis. The tokens are automatically refreshed when + accessed through the get_auth() function. + + Tokens are cached in local storage with a configurable file name and are + only reproduced if the refresh token expires. + + Parameters + ---------- + stage : str + The application stage to use. Choose from {list(Stage)}. + keycloak_endpoint : str + The keycloak endpoint to use. + client_id : str, optional + The client id for the keycloak authorisation, by default + DEFAULT_CLIENT_ID + local_storage_object: Optional[Dict[str, Any]] = None + provide a storage object rather than a location - this will cache + the tokens in the provided dictionary useful for local session + states + local_storage_location : str, optional + The storage location for caching creds, by default + LOCAL_STORAGE_DEFAULT + scopes : List[str], optional + The scopes you want to request against client, by default [] + force_token_refresh : bool, optional + If you want to force the manager to dump current creds, by default + False + silent : bool + Force silence in the stdout outputs for use in context where + printing would be irritating. By default False (helpful messages are + printed). + + Raises + ------ + ValueError + If the stage provided is invalid. 
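+
+        For example (the stage and endpoint values here are illustrative
+        placeholders only):
+
+        manager = DeviceFlowManager(
+            stage="DEV",
+            keycloak_endpoint="https://keycloak.example.com/auth/realms/example",
+        )
+        requests.get(url, auth=manager.get_auth())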
+ """ + self.silent = silent + + if local_storage_location is not None and local_storage_object is not None: + raise ValueError( + "Can't specify both local storage file and object.") + + self.storage_type: StorageType = StorageType.FILE + + if local_storage_object is None: + self.storage_type = StorageType.FILE + # use file storage + if local_storage_location is None: + self.optional_print( + f"No storage or object provided, using default location: {LOCAL_STORAGE_DEFAULT}.") + self.token_storage_location = LOCAL_STORAGE_DEFAULT + else: + self.token_storage_location = local_storage_location + else: + # use object storage + self.storage_type = StorageType.OBJECT + self.object_storage = local_storage_object + + self.optional_print(f"Using storage type: {self.storage_type}.") + + # check and validate auth flow preferences + if auth_flow == AuthFlow.OFFLINE: + if not offline_token: + raise ValueError( + "You are using an offline auth flow but did not provide an offline token!") + if auth_flow == AuthFlow.DEVICE: + if offline_token: + raise ValueError( + "You provided an offline token but specified the DEVICE auth flow. The DEVICE auth flow does not require an offline token.") + + self.optional_print(f"Using {auth_flow} auth flow.") + self.auth_flow = auth_flow + self.keycloak_endpoint = keycloak_endpoint + self.client_id = client_id + + self.offline_token = offline_token + + # initialise empty stage tokens + self.stage_tokens: Optional[Tokens] = None + self.public_key: Optional[str] = None + self.scopes: List[str] = scopes + + # pull out stage + try: + self.stage: Stage = Stage[stage] + except: + raise ValueError(f"Stage {stage} is not one of {list(Stage)}.") + + # set endpoints + self.token_endpoint = self.keycloak_endpoint + "/protocol/openid-connect/token" + self.device_endpoint = self.keycloak_endpoint + \ + "/protocol/openid-connect/auth/device" + + if force_token_refresh: + self.reset_storage() + + self.retrieve_keycloak_public_key() + self.get_tokens() + + def optional_print(self, message: Optional[str] = None) -> None: + """Prints only if the silent value is not + flagged. + + Parameters + ---------- + message : str + The message to print. + """ + if not self.silent: + if message: + print(message) + else: + print() + + def retrieve_local_tokens(self, stage: Stage) -> Optional[Tokens]: + """Retrieves credentials from a local cache file, if present. + Credentials are on a per stage basis. If the creds are valid + but expired, they will be refreshed. If this fails, then + a failure is indicated by None. + + Parameters + ---------- + stage : Stage + The stage to fetch creds for. + + Returns + ------- + Optional[Tokens] + Tokens object if successful or None. 
+ """ + + if self.storage_type == StorageType.FILE: + self.optional_print( + "Looking for existing tokens in local storage.") + self.optional_print("") + # Try to read file + try: + stage_tokens = StageTokens.parse_file( + self.token_storage_location) + tokens = stage_tokens.stages.get(stage) + assert tokens + except: + self.optional_print( + f"No local storage tokens for stage {stage} found.") + self.optional_print("") + return None + elif self.storage_type == StorageType.OBJECT: + self.optional_print( + "Looking for existing tokens in provided object.") + self.optional_print("") + # Try to read object + try: + stage_tokens = StageTokens.parse_obj(self.object_storage) + tokens = stage_tokens.stages.get(stage) + assert tokens + except: + self.optional_print( + f"No local storage tokens in provided storage for {stage}.") + self.optional_print("") + return None + + # Validate + self.optional_print("Validating found tokens") + self.optional_print() + valid = True + try: + self.validate_token(tokens=tokens) + except: + valid = False + + # Return the tokens found if valid + if valid: + self.optional_print("Found tokens valid, using.") + self.optional_print() + return tokens + + elif self.auth_flow == AuthFlow.OFFLINE: + # no refresh from storage is available using the offline workflow as + # they are not cached + self.optional_print( + "Refresh not cached for offline workflow - regenerating using offline token." + ) + self.optional_print() + return None + + # Tokens found but were invalid, try refreshing + refresh_succeeded = True + try: + self.optional_print( + "Trying to use found tokens to refresh the access token.") + self.optional_print() + refreshed = self.perform_refresh(tokens=tokens) + + # unpack response and return access token + access_token = refreshed.get('access_token') + refresh_token = refreshed.get('refresh_token') + + # Make sure they are preset + assert access_token + assert refresh_token + + tokens = Tokens( + access_token=access_token, + refresh_token=refresh_token + ) + self.validate_token(tokens) + except: + refresh_succeeded = False + + # If refresh fails for some reason then return None + # otherwise return the tokens + if refresh_succeeded: + self.optional_print("Token refresh successful.") + self.optional_print() + return tokens + else: + self.optional_print( + "Tokens found in storage but they are not valid.") + self.optional_print() + return None + + def reset_storage(self) -> None: + """Resets the local storage by setting all + values to None. + """ + + if self.storage_type == StorageType.FILE: + self.optional_print("Flushing tokens from local storage.") + cleared_tokens = StageTokens( + stages={ + Stage.TEST: None, + Stage.DEV: None, + Stage.STAGE: None, + Stage.PROD: None + } + ) + + # Dump the cleared file into storage + with open(self.token_storage_location, 'w') as f: + f.write(cleared_tokens.json()) + elif self.storage_type == StorageType.OBJECT: + self.object_storage.clear() + + def update_local_storage(self, stage: Stage) -> None: + """Pulls the current StageTokens object from cache + storage, if present, then either updates the current + stage token value in existing or new StageTokens + object. Writes back to file. 
+ + Parameters + ---------- + stage : Stage + The stage to update + """ + # Check current tokens + assert self.tokens + existing: Optional[bool] = None + existing_tokens: Optional[StageTokens] = None + + if self.storage_type == StorageType.FILE: + try: + existing_tokens = StageTokens.parse_file( + self.token_storage_location) + existing = True + except: + existing = False + + assert existing is not None + if existing: + # We have existing - update current stage + assert existing_tokens + + existing_tokens.stages[stage] = self.tokens + else: + existing_tokens = StageTokens( + stages={ + Stage.TEST: None, + Stage.DEV: None, + Stage.STAGE: None, + Stage.PROD: None + } + ) + existing_tokens.stages[stage] = self.tokens + + # if OFFLINE mode then remove all refresh tokens from the object so + # that we never cache refresh tokens + + if self.auth_flow == AuthFlow.OFFLINE: + for stage, tokens in existing_tokens.stages.items(): + if tokens: + tokens.refresh_token = None + + # Dump the file into storage + with open(self.token_storage_location, 'w') as f: + f.write(existing_tokens.json(exclude_none=True)) + elif self.storage_type == StorageType.OBJECT: + try: + existing_tokens = StageTokens.parse_obj( + self.object_storage) + existing = True + except: + existing = False + + assert existing is not None + if existing: + # We have existing - update current stage + assert existing_tokens + + existing_tokens.stages[stage] = self.tokens + else: + existing_tokens = StageTokens( + stages={ + Stage.TEST: None, + Stage.DEV: None, + Stage.STAGE: None, + Stage.PROD: None + } + ) + existing_tokens.stages[stage] = self.tokens + + if self.auth_flow == AuthFlow.OFFLINE: + for stage, tokens in existing_tokens.stages.items(): + if tokens: + tokens.refresh_token = None + + # update local storage object + self.object_storage.clear() + new = json.loads( + existing_tokens.json(exclude_none=True)) + for k, v in new.items(): + self.object_storage[k] = v + + def perform_offline_refresh(self) -> Dict[str, Any]: + """ + perform_offline_refresh + + Uses the current offline token to perform a token refresh. + + Returns + ------- + Dict[str, Any] + The response from the token endpoint iff status code == 200 + + Raises + ------ + Exception + Exception if non 200 status code + """ + # Perform a refresh grant + refresh_grant_type = "refresh_token" + + # Required openid connect fields + data = { + "grant_type": refresh_grant_type, + "client_id": self.client_id, + "refresh_token": self.offline_token, + "scope": " ".join(self.scopes) + } + + # Send API request + response = requests.post(self.token_endpoint, data=data) + + if (not response.status_code == 200): + raise Exception( + f"Something went wrong during offline token refresh. Status code: {response.status_code}.") + + return response.json() + + def get_tokens(self) -> None: + """Tries to get tokens. + First attempts to pull from the local storage. + Otherwise initiates a device auth flow then uses the + token endpoint to generate the creds. 
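+        For the OFFLINE auth flow, the supplied offline token is exchanged
+        directly for fresh tokens instead of prompting the user.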
+ + Raises + ------ + Exception + OAuth tokens not present in device auth flow + Exception + Tokens not present in keycloak token endpoint response + """ + # Try getting from local storage first + # These are always validated + self.optional_print("Attempting to generate authorisation tokens.") + self.optional_print() + + # try to get from local storage and attempt auto refresh + retrieved_tokens = self.retrieve_local_tokens(self.stage) + if retrieved_tokens: + self.tokens = retrieved_tokens + self.update_local_storage(self.stage) + return + + # Otherwise do a normal authorisation flow + if self.auth_flow == AuthFlow.DEVICE: + # device auth flow + + # grant type + device_grant_type = "urn:ietf:params:oauth:grant-type:device_code" + + self.optional_print( + "Initiating device auth flow to generate access and refresh tokens.") + self.optional_print() + device_auth_response = self.initiate_device_auth_flow() + + self.optional_print("Decoding response") + self.optional_print() + device_code = device_auth_response['device_code'] + user_code = device_auth_response['user_code'] + verification_uri = device_auth_response['verification_uri_complete'] + interval = device_auth_response['interval'] + + self.optional_print( + "Please authorise using the following endpoint.") + self.optional_print() + self.display_device_auth_flow(user_code, verification_uri) + self.optional_print() + + self.optional_print("Awaiting completion") + self.optional_print() + oauth_tokens = self.await_device_auth_flow_completion( + device_code=device_code, + interval=interval, + grant_type=device_grant_type, + ) + self.optional_print() + + if oauth_tokens is None: + raise Exception( + "Failed to retrieve tokens from device authorisation flow!") + + # pull out the refresh and access token + # this refresh token is standard (not offline access) + access_token = oauth_tokens.get('access_token') + refresh_token = oauth_tokens.get('refresh_token') + + # Check that they are present + try: + assert access_token is not None + assert refresh_token is not None + except Exception as e: + raise Exception( + f"Token payload did not include access or refresh token: Error: {e}") + # Set tokens + self.tokens = Tokens( + access_token=access_token, + refresh_token=refresh_token + ) + self.update_local_storage(self.stage) + + self.optional_print( + "Token generation complete. Authorisation successful.") + self.optional_print() + + elif self.auth_flow == AuthFlow.OFFLINE: + # offline auth flow + + # perform offline refresh + oauth_tokens = self.perform_offline_refresh() + + # pull out the refresh and access token + # this refresh token is standard (not offline access) + access_token = oauth_tokens.get('access_token') + refresh_token = oauth_tokens.get('refresh_token') + + # Check that they are present + try: + assert access_token is not None + assert refresh_token is not None + except Exception as e: + raise Exception( + f"Offline refresh token payload did not include access or refresh token: Error: {e}") + + # Set tokens + self.tokens = Tokens( + access_token=access_token, + refresh_token=refresh_token + ) + self.update_local_storage(self.stage) + + self.optional_print( + "Offline token generation complete. Authorisation successful.") + self.optional_print() + + def perform_token_refresh(self) -> None: + """Updates the current tokens by using the refresh token. 
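+
+        The refreshed tokens replace self.tokens and are written back to
+        local storage.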
+ """ + assert self.tokens is not None + + self.optional_print("Refreshing using refresh token") + self.optional_print() + + refreshed = self.perform_refresh() + + # unpack response and return access token + access_token = refreshed.get('access_token') + refresh_token = refreshed.get('refresh_token') + + # Make sure they are preset + assert access_token + assert refresh_token + + self.tokens = Tokens( + access_token=access_token, + refresh_token=refresh_token + ) + self.update_local_storage(self.stage) + + def perform_refresh(self, tokens: Optional[Tokens] = None) -> Dict[str, Any]: + """Helper function to perform refresh. This accepts tokens + and other information from the class, calls the refresh endpoint, + and responds with the keycloak token endpoint response. + + Parameters + ---------- + tokens : Optional[Tokens], optional + The tokens object, by default None + + Returns + ------- + Dict[str, Any] + The response from the keycloak endpoint as json dict. + + Raises + ------ + Exception + Non 200 response code. + """ + # Perform a refresh grant + refresh_grant_type = "refresh_token" + + # make sure we have tokens to use + desired_tokens: Optional[Tokens] + if tokens: + desired_tokens = tokens + else: + desired_tokens = self.tokens + + assert desired_tokens + assert desired_tokens.refresh_token + + # Required openid connect fields + data = { + "grant_type": refresh_grant_type, + "client_id": self.client_id, + "refresh_token": desired_tokens.refresh_token, + "scope": " ".join(self.scopes) + } + + # Send API request + response = requests.post(self.token_endpoint, data=data) + + if (not response.status_code == 200): + raise Exception( + f"Something went wrong during token refresh. Status code: {response.status_code}.") + + return response.json() + + def initiate_device_auth_flow(self) -> Dict[str, Any]: + """Initiates OAuth device flow. + This is triggered by a post request to the device endpoint + of the keycloak server. The specified client (by id) must be + public and have the device auth flow enabled. + + Returns + ------- + Dict[str, Any] + The json response info from the device auth flow endpoint + """ + data = { + "client_id": self.client_id, + "scope": ' '.join(self.scopes) + } + response = requests.post(self.device_endpoint, data=data).json() + return response + + def get_token(self) -> str: + """Uses the current token - validates it, + refreshes if necessary, and returns the valid token + ready to be used. + + Returns + ------- + str + The access token + + Raises + ------ + Exception + Raises exception if tokens/public_key are not setup - make sure + that the object is instantiated properly before calling this function + Exception + If the token is invalid and cannot be refreshed + """ + # make auth object using access_token + if (self.tokens is None or self.public_key is None): + raise Exception( + "cannot generate token without access token or public key") + + assert self.tokens + assert self.public_key + + # are tokens valid? + try: + self.validate_token() + except Exception as e: + # tokens are invalid + self.optional_print(f"Token validation failed due to error: {e}") + # does token refresh work? + try: + self.perform_token_refresh() + self.validate_token() + except Exception as e: + try: + # Does new token generation work? + self.get_tokens() + self.validate_token() + except Exception as e: + raise Exception( + f"Device log in failed, access token expired/invalid, and refresh failed. 
Error: {e}") + return self.tokens.access_token + + def get_auth(self) -> BearerAuth: + """A helper function which produces a BearerAuth object for use + in the requests.xxx objects. For example: + + manager = DeviceAuthFlowManager(...) + auth = manager.get_auth + requests.post(..., auth=auth) + + Returns + ------- + BearerAuth + The requests auth object. + + Raises + ------ + Exception + Tokens are not present + Exception + Token validation failed and refresh or device auth failed + """ + # make auth object using access_token + if (self.tokens is None or self.public_key is None): + raise Exception( + "cannot generate bearer auth object without access token or public key") + + assert self.tokens + assert self.public_key + + # are tokens valid? + try: + self.validate_token() + except Exception as e: + # tokens are invalid + self.optional_print(f"Token validation failed due to error: {e}") + # does token refresh work? + try: + self.perform_token_refresh() + self.validate_token() + except Exception as e: + try: + # Does new token generation work? + self.get_tokens() + self.validate_token() + except Exception as e: + raise Exception( + f"Device log in failed, access token expired/invalid, and refresh failed. Error: {e}") + return BearerAuth(token=self.tokens.access_token) + + def retrieve_keycloak_public_key(self) -> None: + """Given the keycloak endpoint, retrieves the advertised + public key. + Based on https://github.com/nurgasemetey/fastapi-keycloak-oidc/blob/main/main.py + """ + error_message = f"Error finding public key from keycloak endpoint {self.keycloak_endpoint}." + try: + r = requests.get(self.keycloak_endpoint, + timeout=3) + r.raise_for_status() + response_json = r.json() + self.public_key = f"-----BEGIN PUBLIC KEY-----\r\n{response_json['public_key']}\r\n-----END PUBLIC KEY-----" + except requests.exceptions.HTTPError as errh: + self.optional_print(error_message) + self.optional_print("Http Error:" + str(errh)) + raise errh + except requests.exceptions.ConnectionError as errc: + self.optional_print(error_message) + self.optional_print("Error Connecting:" + str(errc)) + raise errc + except requests.exceptions.Timeout as errt: + self.optional_print(error_message) + self.optional_print("Timeout Error:" + str(errt)) + raise errt + except requests.exceptions.RequestException as err: + self.optional_print(error_message) + self.optional_print("An unknown error occured: " + str(err)) + raise err + + def display_device_auth_flow(self, user_code: str, verification_url: str) -> None: + """Displays the current device auth flow challenge - first by trying to + open a browser window - if this fails then prints suggestion to stdout + to try using the URL manually. + + Parameters + ---------- + user_code : str + The user code + verification_url : str + The url which embeds challenge code + """ + print(f"Verification URL: {verification_url}") + print(f"User Code: {user_code}") + try: + webbrowser.open(verification_url) + except Exception: + print("Tried to open web-browser but failed. Please visit URL above.") + + def await_device_auth_flow_completion( + self, + device_code: str, + interval: int, + grant_type: str, + ) -> Optional[Dict[str, Any]]: + """Ping the token endpoint as specified in the OAuth standard + at the advertised polling rate until response is positive + or failure. 
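+        Polling continues only while the endpoint reports
+        'authorization_pending'; any other error aborts the wait.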
+ + Parameters + ---------- + device_code : str + The device code + interval : int + The polling interval (ms) + grant_type : str + The OAuth grant type + + Returns + ------- + Optional[Dict[str, Any]] + If successful, the keycloak response + """ + # set up request + data = { + "grant_type": grant_type, + "device_code": device_code, + "client_id": self.client_id, + "scope": " ".join(self.scopes) + } + + # Setup success criteria + succeeded = False + timed_out = False + misc_fail = False + + # start time + response_data: Optional[Dict[str, Any]] = None + + # get requests session for repeated queries + session = requests.session() + + # Poll for success + while not succeeded and not timed_out and not misc_fail: + response = session.post(self.token_endpoint, data=data) + response_data = response.json() + assert response_data + if response_data.get('error'): + error = response_data['error'] + if error != 'authorization_pending': + misc_fail = True + # Wait appropriate OAuth poll interval + time.sleep(interval) + else: + # Successful as there was no error at the endpoint + return response_data + + try: + assert response_data + self.optional_print(f"Failed due to {response_data['error']}") + return None + except Exception as e: + self.optional_print( + f"Failed with unknown error, failed to find error message. Error {e}") + return None + + def validate_token(self, tokens: Optional[Tokens] = None) -> None: + """Uses the python-jose library to validate current creds. + + In this context, it is basically just checking signature + and expiry. The tokens are enforced at the API side + as well. + + Parameters + ---------- + tokens : Optional[Tokens], optional + The tokens object to validate, by default None + """ + # Validate either self.tokens or supply tokens optionally + test_tokens: Optional[Tokens] + if tokens: + test_tokens = tokens + else: + test_tokens = self.tokens + + # Check tokens are present + assert test_tokens + assert self.public_key + + # this is currently locally validating the token + # It is our responsibility to choose whether to honour the expiration date + # etc + # this will throw an exception if invalid + jwt_payload = jwt.decode( + test_tokens.access_token, + self.public_key, + algorithms=[ALGORITHMS.RS256], + options={ + "verify_signature": True, + "verify_aud": False, + "exp": True + } + ) \ No newline at end of file diff --git a/src/provenaclient/utils/httpClient.py b/src/provenaclient/utils/httpClient.py new file mode 100644 index 0000000..2c8ce5d --- /dev/null +++ b/src/provenaclient/utils/httpClient.py @@ -0,0 +1,16 @@ +from typing import Dict, Any, Optional +import httpx + +class HttpClient: + + @staticmethod + async def make_get_request(url: str, params:Optional[dict[str, Any]] = None, headers:Optional[dict[str,Any]] = None) -> httpx.Response: + async with httpx.AsyncClient() as client: + response = await client.get(url,params = params, headers= headers) + return response + + @staticmethod + async def make_post_request(url: str, data: Optional[dict[str, Any]] = None, headers: Optional[dict[str, Any]] = None) -> httpx.Response: + async with httpx.AsyncClient() as client: + response = await client.post(url, json = data, headers = headers) + return response \ No newline at end of file diff --git a/src/provenaclient/utils/test.py b/src/provenaclient/utils/test.py new file mode 100644 index 0000000..290547c --- /dev/null +++ b/src/provenaclient/utils/test.py @@ -0,0 +1,22 @@ +from .Auth import DeviceFlow + +def main(): + # Define the Keycloak endpoint and client ID + 
keycloak_url = "https://auth.dev.rrap-is.com/auth/realms/rrap" + client_id = "client-tools" + + try: + # Create the DeviceFlow object + device_auth = DeviceFlow(keycloak_endpoint=keycloak_url, client_id=client_id) + + # Initialize the device flow which will open a web browser for user code input + device_auth.init() + + + print("Initialization complete. Check your browser to authenticate.") + + except Exception as e: + print(f"An error occurred: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file