From 7535b4036ce980c9a05bc33a9e61a7938ea1303e Mon Sep 17 00:00:00 2001 From: Theodor Mihalache <84387487+tmihalac@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:21:40 -0400 Subject: [PATCH] fix: Added Offline Store Arrow client errors handler (#4524) * fix: Added Offline Store Arrow client errors handler Signed-off-by: Theodor Mihalache * Added more tests Signed-off-by: Theodor Mihalache --------- Signed-off-by: Theodor Mihalache --- sdk/python/feast/arrow_error_handler.py | 49 +++++++++++++ .../feast/infra/offline_stores/remote.py | 68 ++++++++++++++++--- sdk/python/feast/offline_server.py | 54 ++++++++++----- .../client/arrow_flight_auth_interceptor.py | 9 --- sdk/python/feast/permissions/server/arrow.py | 31 ++------- .../tests/unit/test_arrow_error_decorator.py | 33 +++++++++ sdk/python/tests/unit/test_offline_server.py | 32 ++++++++- 7 files changed, 215 insertions(+), 61 deletions(-) create mode 100644 sdk/python/feast/arrow_error_handler.py create mode 100644 sdk/python/tests/unit/test_arrow_error_decorator.py diff --git a/sdk/python/feast/arrow_error_handler.py b/sdk/python/feast/arrow_error_handler.py new file mode 100644 index 0000000000..e873592bd5 --- /dev/null +++ b/sdk/python/feast/arrow_error_handler.py @@ -0,0 +1,49 @@ +import logging +from functools import wraps + +import pyarrow.flight as fl + +from feast.errors import FeastError + +logger = logging.getLogger(__name__) + + +def arrow_client_error_handling_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + mapped_error = FeastError.from_error_detail(_get_exception_data(e.args[0])) + if mapped_error is not None: + raise mapped_error + raise e + + return wrapper + + +def arrow_server_error_handling_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if isinstance(e, FeastError): + raise fl.FlightError(e.to_error_detail()) + + return wrapper + + +def _get_exception_data(except_str) -> str: + substring = "Flight error: " + + # Find the starting index of the substring + position = except_str.find(substring) + end_json_index = except_str.find("}") + + if position != -1 and end_json_index != -1: + # Extract the part of the string after the substring + result = except_str[position + len(substring) : end_json_index + 1] + return result + + return "" diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index 40239c8950..8154f75f87 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -10,9 +10,12 @@ import pyarrow as pa import pyarrow.flight as fl import pyarrow.parquet +from pyarrow import Schema +from pyarrow._flight import FlightCallOptions, FlightDescriptor, Ticket from pydantic import StrictInt, StrictStr from feast import OnDemandFeatureView +from feast.arrow_error_handler import arrow_client_error_handling_decorator from feast.data_source import DataSource from feast.feature_logging import ( FeatureServiceLoggingSource, @@ -27,8 +30,10 @@ RetrievalMetadata, ) from feast.infra.registry.base_registry import BaseRegistry +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import AuthConfig from feast.permissions.client.arrow_flight_auth_interceptor import ( - build_arrow_flight_client, + FlightAuthInterceptorFactory, ) from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage @@ -36,6 +41,43 @@ logger = logging.getLogger(__name__) +class FeastFlightClient(fl.FlightClient): + @arrow_client_error_handling_decorator + def get_flight_info( + self, descriptor: FlightDescriptor, options: FlightCallOptions = None + ): + return super().get_flight_info(descriptor, options) + + @arrow_client_error_handling_decorator + def do_get(self, ticket: Ticket, options: FlightCallOptions = None): + return super().do_get(ticket, options) + + @arrow_client_error_handling_decorator + def do_put( + self, + descriptor: FlightDescriptor, + schema: Schema, + options: FlightCallOptions = None, + ): + return super().do_put(descriptor, schema, options) + + @arrow_client_error_handling_decorator + def list_flights(self, criteria: bytes = b"", options: FlightCallOptions = None): + return super().list_flights(criteria, options) + + @arrow_client_error_handling_decorator + def list_actions(self, options: FlightCallOptions = None): + return super().list_actions(options) + + +def build_arrow_flight_client(host: str, port, auth_config: AuthConfig): + if auth_config.type != AuthType.NONE.value: + middlewares = [FlightAuthInterceptorFactory(auth_config)] + return FeastFlightClient(f"grpc://{host}:{port}", middleware=middlewares) + + return FeastFlightClient(f"grpc://{host}:{port}") + + class RemoteOfflineStoreConfig(FeastConfigBaseModel): type: Literal["remote"] = "remote" host: StrictStr @@ -48,7 +90,7 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel): class RemoteRetrievalJob(RetrievalJob): def __init__( self, - client: fl.FlightClient, + client: FeastFlightClient, api: str, api_parameters: Dict[str, Any], entity_df: Union[pd.DataFrame, str] = None, @@ -338,7 +380,7 @@ def _send_retrieve_remote( api_parameters: Dict[str, Any], entity_df: Union[pd.DataFrame, str], table: pa.Table, - client: fl.FlightClient, + client: FeastFlightClient, ): command_descriptor = _call_put( api, @@ -351,19 +393,19 @@ def _send_retrieve_remote( def _call_get( - client: fl.FlightClient, + client: FeastFlightClient, command_descriptor: fl.FlightDescriptor, ): flight = client.get_flight_info(command_descriptor) ticket = flight.endpoints[0].ticket reader = client.do_get(ticket) - return reader.read_all() + return read_all(reader) def _call_put( api: str, api_parameters: Dict[str, Any], - client: fl.FlightClient, + client: FeastFlightClient, entity_df: Union[pd.DataFrame, str], table: pa.Table, ): @@ -391,7 +433,7 @@ def _put_parameters( command_descriptor: fl.FlightDescriptor, entity_df: Union[pd.DataFrame, str], table: pa.Table, - client: fl.FlightClient, + client: FeastFlightClient, ): updatedTable: pa.Table @@ -404,10 +446,20 @@ def _put_parameters( writer, _ = client.do_put(command_descriptor, updatedTable.schema) - writer.write_table(updatedTable) + write_table(writer, updatedTable) + + +@arrow_client_error_handling_decorator +def write_table(writer, updated_table: pa.Table): + writer.write_table(updated_table) writer.close() +@arrow_client_error_handling_decorator +def read_all(reader): + return reader.read_all() + + def _create_empty_table(): schema = pa.schema( { diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index 839acada93..ff3db579d0 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -9,16 +9,18 @@ import pyarrow.flight as fl from feast import FeatureStore, FeatureView, utils +from feast.arrow_error_handler import arrow_server_error_handling_decorator from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_view import DUMMY_ENTITY_NAME from feast.infra.offline_stores.offline_utils import get_offline_store_from_config from feast.permissions.action import AuthzedAction from feast.permissions.security_manager import assert_permissions from feast.permissions.server.arrow import ( - arrowflight_middleware, + AuthorizationMiddlewareFactory, inject_user_details_decorator, ) from feast.permissions.server.utils import ( + AuthManagerType, ServerType, init_auth_manager, init_security_manager, @@ -34,7 +36,7 @@ class OfflineServer(fl.FlightServerBase): def __init__(self, store: FeatureStore, location: str, **kwargs): super(OfflineServer, self).__init__( location, - middleware=arrowflight_middleware( + middleware=self.arrow_flight_auth_middleware( str_to_auth_manager_type(store.config.auth_config.type) ), **kwargs, @@ -45,6 +47,25 @@ def __init__(self, store: FeatureStore, location: str, **kwargs): self.store = store self.offline_store = get_offline_store_from_config(store.config.offline_store) + def arrow_flight_auth_middleware( + self, + auth_type: AuthManagerType, + ) -> dict[str, fl.ServerMiddlewareFactory]: + """ + A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined. + The authorization middleware key is `auth`. + + Returns: + dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns an empty dict. + """ + + if auth_type == AuthManagerType.NONE: + return {} + + return { + "auth": AuthorizationMiddlewareFactory(), + } + @classmethod def descriptor_to_key(self, descriptor: fl.FlightDescriptor): return ( @@ -61,15 +82,7 @@ def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor): return fl.FlightInfo(schema, descriptor, endpoints, -1, -1) @inject_user_details_decorator - def get_flight_info( - self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor - ): - key = OfflineServer.descriptor_to_key(descriptor) - if key in self.flights: - return self._make_flight_info(key, descriptor) - raise KeyError("Flight not found.") - - @inject_user_details_decorator + @arrow_server_error_handling_decorator def list_flights(self, context: fl.ServerCallContext, criteria: bytes): for key, table in self.flights.items(): if key[1] is not None: @@ -79,9 +92,20 @@ def list_flights(self, context: fl.ServerCallContext, criteria: bytes): yield self._make_flight_info(key, descriptor) + @inject_user_details_decorator + @arrow_server_error_handling_decorator + def get_flight_info( + self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor + ): + key = OfflineServer.descriptor_to_key(descriptor) + if key in self.flights: + return self._make_flight_info(key, descriptor) + raise KeyError("Flight not found.") + # Expects to receive request parameters and stores them in the flights dictionary # Indexed by the unique command @inject_user_details_decorator + @arrow_server_error_handling_decorator def do_put( self, context: fl.ServerCallContext, @@ -179,6 +203,7 @@ def _validate_do_get_parameters(self, command: dict): # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance # and returns the stream of data @inject_user_details_decorator + @arrow_server_error_handling_decorator def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): key = ast.literal_eval(ticket.ticket.decode()) if key not in self.flights: @@ -337,6 +362,7 @@ def pull_latest_from_table_or_query(self, command: dict): utils.make_tzaware(datetime.fromisoformat(command["end_date"])), ) + @arrow_server_error_handling_decorator def list_actions(self, context): return [ ( @@ -431,12 +457,6 @@ def persist(self, command: dict, key: str): traceback.print_exc() raise e - def do_action(self, context: fl.ServerCallContext, action: fl.Action): - pass - - def do_drop_dataset(self, dataset): - pass - def remove_dummies(fv: FeatureView) -> FeatureView: """ diff --git a/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py index 7ef84fbeae..c3281bfa51 100644 --- a/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py +++ b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py @@ -1,6 +1,5 @@ import pyarrow.flight as fl -from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import AuthConfig from feast.permissions.client.client_auth_token import get_auth_token @@ -28,11 +27,3 @@ def __init__(self, auth_config: AuthConfig): def start_call(self, info): return FlightBearerTokenInterceptor(self.auth_config) - - -def build_arrow_flight_client(host: str, port, auth_config: AuthConfig): - if auth_config.type != AuthType.NONE.value: - middleware_factory = FlightAuthInterceptorFactory(auth_config) - return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory]) - else: - return fl.FlightClient(f"grpc://{host}:{port}") diff --git a/sdk/python/feast/permissions/server/arrow.py b/sdk/python/feast/permissions/server/arrow.py index 5eba7d0916..4f0afc3ee5 100644 --- a/sdk/python/feast/permissions/server/arrow.py +++ b/sdk/python/feast/permissions/server/arrow.py @@ -5,7 +5,7 @@ import asyncio import functools import logging -from typing import Optional, cast +from typing import cast import pyarrow.flight as fl from pyarrow.flight import ServerCallContext @@ -14,41 +14,19 @@ get_auth_manager, ) from feast.permissions.security_manager import get_security_manager -from feast.permissions.server.utils import ( - AuthManagerType, -) from feast.permissions.user import User logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -def arrowflight_middleware( - auth_type: AuthManagerType, -) -> Optional[dict[str, fl.ServerMiddlewareFactory]]: - """ - A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined. - The authorization middleware key is `auth`. - - Returns: - dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns `None`. - """ - - if auth_type == AuthManagerType.NONE: - return None - - return { - "auth": AuthorizationMiddlewareFactory(), - } - - class AuthorizationMiddlewareFactory(fl.ServerMiddlewareFactory): """ A middleware factory to intercept the authorization header and propagate it to the authorization middleware. """ - def __init__(self): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def start_call(self, info, headers): """ @@ -65,7 +43,8 @@ class AuthorizationMiddleware(fl.ServerMiddleware): A server middleware holding the authorization header and offering a method to extract the user credentials. """ - def __init__(self, access_token: str): + def __init__(self, access_token: str, *args, **kwargs): + super().__init__(*args, **kwargs) self.access_token = access_token def call_completed(self, exception): diff --git a/sdk/python/tests/unit/test_arrow_error_decorator.py b/sdk/python/tests/unit/test_arrow_error_decorator.py new file mode 100644 index 0000000000..fc350d34c0 --- /dev/null +++ b/sdk/python/tests/unit/test_arrow_error_decorator.py @@ -0,0 +1,33 @@ +import pyarrow.flight as fl +import pytest + +from feast.arrow_error_handler import arrow_client_error_handling_decorator +from feast.errors import PermissionNotFoundException + +permissionError = PermissionNotFoundException("dummy_name", "dummy_project") + + +@arrow_client_error_handling_decorator +def decorated_method(error): + raise error + + +@pytest.mark.parametrize( + "error, expected_raised_error", + [ + (fl.FlightError("Flight error: "), fl.FlightError("Flight error: ")), + ( + fl.FlightError(f"Flight error: {permissionError.to_error_detail()}"), + permissionError, + ), + (fl.FlightError("Test Error"), fl.FlightError("Test Error")), + (RuntimeError("Flight error: "), RuntimeError("Flight error: ")), + (permissionError, permissionError), + ], +) +def test_rest_error_handling_with_feast_exception(error, expected_raised_error): + with pytest.raises( + type(expected_raised_error), + match=str(expected_raised_error), + ): + decorated_method(error) diff --git a/sdk/python/tests/unit/test_offline_server.py b/sdk/python/tests/unit/test_offline_server.py index 237e2ecad4..7c38d9bfca 100644 --- a/sdk/python/tests/unit/test_offline_server.py +++ b/sdk/python/tests/unit/test_offline_server.py @@ -8,7 +8,8 @@ import pyarrow.flight as flight import pytest -from feast import FeatureStore +from feast import FeatureStore, FeatureView, FileSource +from feast.errors import FeatureViewNotFoundException from feast.feature_logging import FeatureServiceLoggingSource from feast.infra.offline_stores.remote import ( RemoteOfflineStore, @@ -120,6 +121,35 @@ def test_remote_offline_store_apis(): _test_pull_all_from_table_or_query(str(temp_dir), fs) +def test_remote_offline_store_exception_handling(): + with tempfile.TemporaryDirectory() as temp_dir: + store = default_store(str(temp_dir)) + location = "grpc+tcp://localhost:0" + + _init_auth_manager(store=store) + server = OfflineServer(store=store, location=location) + + assertpy.assert_that(server).is_not_none + assertpy.assert_that(server.port).is_not_equal_to(0) + + fs = remote_feature_store(server) + data_file = os.path.join( + temp_dir, fs.project, "feature_repo/data/driver_stats.parquet" + ) + data_df = pd.read_parquet(data_file) + + with pytest.raises( + FeatureViewNotFoundException, + match="Feature view test does not exist in project test_remote_offline", + ): + RemoteOfflineStore.offline_write_batch( + fs.config, + FeatureView(name="test", source=FileSource(path="test")), + pa.Table.from_pandas(data_df), + progress=None, + ) + + def _test_get_historical_features_returns_data(fs: FeatureStore): entity_df = pd.DataFrame.from_dict( {