Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/oauth2-v2' into oauth2-v2
Browse files Browse the repository at this point in the history
# Conflicts:
#	packages/opal-client/opal_client/client.py
#	packages/opal-client/opal_client/data/updater.py
#	packages/opal-client/opal_client/policy/updater.py
  • Loading branch information
Ondrej Scecina committed Nov 5, 2024
2 parents 55079e3 + 9f46644 commit 6ea4078
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 76 deletions.
41 changes: 11 additions & 30 deletions packages/opal-client/opal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import websockets
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback
from opal_client.callbacks.api import init_callbacks_api
from opal_client.callbacks.register import CallbacksRegister
from opal_client.config import PolicyStoreTypes, opal_client_config
Expand All @@ -29,8 +27,8 @@
from opal_client.policy_store.policy_store_client_factory import (
PolicyStoreClientFactory,
)
from opal_common.authentication.deps import JWTAuthenticator
from opal_common.authentication.verifier import JWTVerifier
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.authenticator_factory import AuthenticatorFactory
from opal_common.config import opal_common_config
from opal_common.logger import configure_logs, logger
from opal_common.middleware import configure_middleware
Expand All @@ -49,15 +47,11 @@ def __init__(
inline_opa_options: OpaServerOptions = None,
inline_cedar_enabled: bool = None,
inline_cedar_options: CedarServerOptions = None,
verifier: Optional[JWTVerifier] = None,
authenticator: Optional[Authenticator] = None,
store_backup_path: Optional[str] = None,
store_backup_interval: Optional[int] = None,
offline_mode_enabled: bool = False,
shard_id: Optional[str] = None,
on_data_updater_connect: List[PubSubOnConnectCallback] = None,
on_data_updater_disconnect: List[OnDisconnectCallback] = None,
on_policy_updater_connect: List[PubSubOnConnectCallback] = None,
on_policy_updater_disconnect: List[OnDisconnectCallback] = None,
) -> None:
"""
Args:
Expand All @@ -68,6 +62,10 @@ def __init__(
data_updater (DataUpdater, optional): Defaults to None.
policy_updater (PolicyUpdater, optional): Defaults to None.
"""
if authenticator is not None:
self.authenticator = authenticator
else:
self.authenticator = AuthenticatorFactory.create()
self._shard_id = shard_id
# defaults
policy_store_type: PolicyStoreTypes = (
Expand Down Expand Up @@ -123,8 +121,7 @@ def __init__(
policy_store=self.policy_store,
callbacks_register=self._callbacks_register,
opal_client_id=opal_client_identifier,
on_connect=on_policy_updater_connect,
on_disconnect=on_policy_updater_disconnect,
authenticator=self.authenticator,
)
else:
self.policy_updater = None
Expand All @@ -146,8 +143,7 @@ def __init__(
callbacks_register=self._callbacks_register,
opal_client_id=opal_client_identifier,
shard_id=self._shard_id,
on_connect=on_data_updater_connect,
on_disconnect=on_data_updater_disconnect,
authenticator=self.authenticator,
)
else:
self.data_updater = None
Expand All @@ -170,19 +166,6 @@ def __init__(
"OPAL client is configured to trust self-signed certificates"
)

if verifier is not None:
self.verifier = verifier
else:
self.verifier = JWTVerifier(
public_key=opal_common_config.AUTH_PUBLIC_KEY,
algorithm=opal_common_config.AUTH_JWT_ALGORITHM,
audience=opal_common_config.AUTH_JWT_AUDIENCE,
issuer=opal_common_config.AUTH_JWT_ISSUER,
)
if not self.verifier.enabled:
logger.info(
"API authentication disabled (public encryption key was not provided)"
)
self.store_backup_path = (
store_backup_path or opal_client_config.STORE_BACKUP_PATH
)
Expand Down Expand Up @@ -258,13 +241,11 @@ def _init_fast_api_app(self):
def _configure_api_routes(self, app: FastAPI):
"""mounts the api routes on the app object."""

authenticator = JWTAuthenticator(self.verifier)

# Init api routers with required dependencies
policy_router = init_policy_router(policy_updater=self.policy_updater)
data_router = init_data_router(data_updater=self.data_updater)
policy_store_router = init_policy_store_router(authenticator)
callbacks_router = init_callbacks_api(authenticator, self._callbacks_register)
policy_store_router = init_policy_store_router(self.authenticator)
callbacks_router = init_callbacks_api(self.authenticator, self._callbacks_register)

# mount the api routes on the app object
app.include_router(policy_router, tags=["Policy Updater"])
Expand Down
120 changes: 84 additions & 36 deletions packages/opal-client/opal_client/data/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import aiohttp
from aiohttp.client import ClientError, ClientSession
from fastapi_websocket_pubsub import PubSubClient
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback, RpcChannel
from fastapi_websocket_rpc.rpc_channel import RpcChannel
from opal_client.callbacks.register import CallbacksRegister
from opal_client.callbacks.reporter import CallbacksReporter
from opal_client.config import opal_client_config
Expand All @@ -25,6 +24,8 @@
DEFAULT_POLICY_STORE_GETTER,
)
from opal_common.async_utils import TakeANumberQueue, TasksPool, repeated_call
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.authenticator_factory import AuthenticatorFactory
from opal_common.config import opal_common_config
from opal_common.fetcher.events import FetcherConfig
from opal_common.http_utils import is_http_error_response
Expand All @@ -42,6 +43,54 @@


class DataUpdater:
async def trigger_data_update(self, update: DataUpdate):
raise NotImplementedError()

async def get_policy_data_config(self, url: str = None) -> DataSourceConfig:
raise NotImplementedError()

async def get_base_policy_data(
self, config_url: str = None, data_fetch_reason="Initial load"
):
raise NotImplementedError()

async def on_connect(self, client: PubSubClient, channel: RpcChannel):
raise NotImplementedError()

async def on_disconnect(self, channel: RpcChannel):
raise NotImplementedError()

async def start(self):
raise NotImplementedError()

async def stop(self):
raise NotImplementedError()

async def wait_until_done(self):
raise NotImplementedError()

@staticmethod
def calc_hash(data):
"""Calculate an hash (sah256) on the given data, if data isn't a
string, it will be converted to JSON.
String are encoded as 'utf-8' prior to hash calculation.
Returns:
the hash of the given data (as a a hexdigit string) or '' on failure to process.
"""
try:
if not isinstance(data, str):
data = json.dumps(data, default=pydantic_encoder)
return hashlib.sha256(data.encode("utf-8")).hexdigest()
except:
logger.exception("Failed to calculate hash for data {data}", data=data)
return ""

@property
def callbacks_reporter(self) -> CallbacksReporter:
raise NotImplementedError()

class DefaultDataUpdater(DataUpdater):
def __init__(
self,
token: str = None,
Expand All @@ -55,8 +104,7 @@ def __init__(
callbacks_register: Optional[CallbacksRegister] = None,
opal_client_id: str = None,
shard_id: Optional[str] = None,
on_connect: List[PubSubOnConnectCallback] = None,
on_disconnect: List[OnDisconnectCallback] = None,
authenticator: Optional[Authenticator] = None,
):
"""Keeps policy-stores (e.g. OPA) up to date with relevant data Obtains
data configuration on startup from OPAL-server Uses Pub/Sub to
Expand Down Expand Up @@ -135,8 +183,10 @@ def __init__(
self._updates_storing_queue = TakeANumberQueue(logger)
self._tasks = TasksPool()
self._polling_update_tasks = []
self._on_connect_callbacks = on_connect or []
self._on_disconnect_callbacks = on_disconnect or []
if authenticator is not None:
self._authenticator = authenticator
else:
self._authenticator = AuthenticatorFactory.create()

async def __aenter__(self):
await self.start()
Expand Down Expand Up @@ -182,20 +232,30 @@ async def get_policy_data_config(self, url: str = None) -> DataSourceConfig:
if url is None:
url = self._data_sources_config_url
logger.info("Getting data-sources configuration from '{source}'", source=url)

headers = {}
if self._extra_headers is not None:
headers = self._extra_headers.copy()
headers['Accept'] = "application/json"

try:
async with ClientSession(headers=self._extra_headers) as session:
response = await session.get(url, **self._ssl_context_kwargs)
if response.status == 200:
return DataSourceConfig.parse_obj(await response.json())
else:
error_details = await response.json()
raise ClientError(
f"Fetch data sources failed with status code {response.status}, error: {error_details}"
)
response = await self._load_policy_data_config(url, headers)

if response.status == 200:
return DataSourceConfig.parse_obj(await response.json())
else:
error_details = await response.text()
raise ClientError(
f"Fetch data sources failed with status code {response.status}, error: {error_details}"
)
except:
logger.exception(f"Failed to load data sources config")
raise

async def _load_policy_data_config(self, url: str, headers) -> aiohttp.ClientResponse:
async with ClientSession(headers=headers) as session:
return await session.get(url, **self._ssl_context_kwargs)

async def get_base_policy_data(
self, config_url: str = None, data_fetch_reason="Initial load"
):
Expand Down Expand Up @@ -279,13 +339,18 @@ async def _subscriber(self):
"""Coroutine meant to be spunoff with create_task to listen in the
background for data events and pass them to the data_fetcher."""
logger.info("Subscribing to topics: {topics}", topics=self._data_topics)

headers = {}
if self._extra_headers is not None:
headers = self._extra_headers.copy()
await self._authenticator.authenticate(headers)

self._client = PubSubClient(
self._data_topics,
self._update_policy_data_callback,
methods_class=TenantAwareRpcEventClientMethods,
on_connect=[self.on_connect, *self._on_connect_callbacks],
on_disconnect=[self.on_disconnect, *self._on_disconnect_callbacks],
extra_headers=self._extra_headers,
on_connect=[self.on_connect],
extra_headers=headers,
keep_alive=opal_client_config.KEEP_ALIVE_INTERVAL,
server_uri=self._server_url,
**self._ssl_context_kwargs,
Expand Down Expand Up @@ -344,23 +409,6 @@ async def wait_until_done(self):
if self._subscriber_task is not None:
await self._subscriber_task

@staticmethod
def calc_hash(data):
"""Calculate an hash (sah256) on the given data, if data isn't a
string, it will be converted to JSON.
String are encoded as 'utf-8' prior to hash calculation.
Returns:
the hash of the given data (as a a hexdigit string) or '' on failure to process.
"""
try:
if not isinstance(data, str):
data = json.dumps(data, default=pydantic_encoder)
return hashlib.sha256(data.encode("utf-8")).hexdigest()
except:
logger.exception("Failed to calculate hash for data {data}", data=data)
return ""

async def _update_policy_data(
self,
update: DataUpdate,
Expand Down Expand Up @@ -473,7 +521,7 @@ async def _store_fetched_update(self, update_item):
policy_data = result
# Create a report on the data-fetching
report = DataEntryReport(
entry=entry, hash=self.calc_hash(policy_data), fetched=True
entry=entry, hash=DataUpdater.calc_hash(policy_data), fetched=True
)

try:
Expand Down
28 changes: 18 additions & 10 deletions packages/opal-client/opal_client/policy/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import pydantic
from fastapi_websocket_pubsub import PubSubClient
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback, RpcChannel
from fastapi_websocket_rpc.rpc_channel import RpcChannel
from opal_client.callbacks.register import CallbacksRegister
from opal_client.callbacks.reporter import CallbacksReporter
from opal_client.config import opal_client_config
Expand All @@ -17,6 +16,8 @@
DEFAULT_POLICY_STORE_GETTER,
)
from opal_common.async_utils import TakeANumberQueue, TasksPool
from opal_common.authentication.authenticator import Authenticator
from opal_common.authentication.authenticator_factory import AuthenticatorFactory
from opal_common.config import opal_common_config
from opal_common.schemas.data import DataUpdateReport
from opal_common.schemas.policy import PolicyBundle, PolicyUpdateMessage
Expand Down Expand Up @@ -44,8 +45,7 @@ def __init__(
data_fetcher: Optional[DataFetcher] = None,
callbacks_register: Optional[CallbacksRegister] = None,
opal_client_id: str = None,
on_connect: List[PubSubOnConnectCallback] = None,
on_disconnect: List[OnDisconnectCallback] = None,
authenticator: Optional[Authenticator] = None,
):
"""inits the policy updater.
Expand All @@ -67,6 +67,10 @@ def __init__(
self._opal_client_id = opal_client_id
self._scope_id = opal_client_config.SCOPE_ID

if authenticator is not None:
self._authenticator = authenticator
else:
self._authenticator = AuthenticatorFactory.create()
# The policy store we'll save policy modules into (i.e: OPA)
self._policy_store = policy_store or DEFAULT_POLICY_STORE_GETTER()
# pub/sub server url and authentication data
Expand All @@ -90,7 +94,7 @@ def __init__(
self._policy_update_task = None
self._stopping = False
# policy fetcher - fetches policy bundles
self._policy_fetcher = PolicyFetcher()
self._policy_fetcher = PolicyFetcher(authenticator=self._authenticator)
# callbacks on policy changes
self._data_fetcher = data_fetcher or DataFetcher()
self._callbacks_register = callbacks_register or CallbacksRegister()
Expand All @@ -107,8 +111,6 @@ def __init__(
)
self._policy_update_queue = asyncio.Queue()
self._tasks = TasksPool()
self._on_connect_callbacks = on_connect or []
self._on_disconnect_callbacks = on_disconnect or []

async def __aenter__(self):
await self.start()
Expand Down Expand Up @@ -245,12 +247,18 @@ async def _subscriber(self):
update_policy() callback (which will fetch the relevant policy bundle
from the server and update the policy store)."""
logger.info("Subscribing to topics: {topics}", topics=self._topics)

headers = {}
if self._extra_headers is not None:
headers = self._extra_headers.copy()
await self._authenticator.authenticate(headers)

self._client = PubSubClient(
topics=self._topics,
callback=self._update_policy_callback,
on_connect=[self._on_connect, *self._on_connect_callbacks],
on_disconnect=[self._on_disconnect, *self._on_disconnect_callbacks],
extra_headers=self._extra_headers,
on_connect=[self._on_connect],
on_disconnect=[self._on_disconnect],
extra_headers=headers,
keep_alive=opal_client_config.KEEP_ALIVE_INTERVAL,
server_uri=self._server_url,
**self._ssl_context_kwargs,
Expand Down

0 comments on commit 6ea4078

Please sign in to comment.