diff --git a/cylc/flow/async_util.py b/cylc/flow/async_util.py index aa62b39acc7..39eb1250514 100644 --- a/cylc/flow/async_util.py +++ b/cylc/flow/async_util.py @@ -165,7 +165,7 @@ async def _chain(self, item, coros, completed): ret = await coro.func(item, *coro.args, **coro.kwargs) except Exception as exc: # if something goes wrong log the error and skip the item - LOG.warning(exc) + LOG.warning(f"{type(exc).__name__}: {exc}") ret = False if ret is True: # filter passed -> continue diff --git a/cylc/flow/network/__init__.py b/cylc/flow/network/__init__.py index d14e11c6487..6c1ff7de621 100644 --- a/cylc/flow/network/__init__.py +++ b/cylc/flow/network/__init__.py @@ -16,9 +16,7 @@ """Package for network interfaces to Cylc scheduler objects.""" import asyncio -import getpass -import json -from typing import Any, Dict +from typing import NamedTuple, Optional import zmq import zmq.asyncio @@ -45,21 +43,16 @@ MSG_TIMEOUT = "TIMEOUT" -def encode_(message: object) -> str: - """Convert the structure holding a message field from JSON to a string.""" - try: - return json.dumps(message) - except TypeError as exc: - return json.dumps({'errors': [{'message': str(exc)}]}) - - -def decode_(message: str) -> Dict[str, Any]: - """Convert an encoded message string to JSON with an added 'user' field.""" - msg: object = json.loads(message) - if not isinstance(msg, dict): - raise ValueError(f"Expected message to be dict but got {type(msg)}") - msg['user'] = getpass.getuser() # assume this is the user - return msg +class ResponseTuple(NamedTuple): + """Structure of server response messages.""" + content: Optional[object] = None + err: Optional['ResponseErrTuple'] = None + user: Optional[str] = None + + +class ResponseErrTuple(NamedTuple): + message: str + traceback: Optional[str] = None def get_location(workflow: str): diff --git a/cylc/flow/network/client.py b/cylc/flow/network/client.py index ae917c41aa9..ea5b6aad51f 100644 --- a/cylc/flow/network/client.py +++ b/cylc/flow/network/client.py @@ -16,6 +16,7 @@ """Client for workflow runtime API.""" from functools import partial +import json import os from shutil import which import socket @@ -34,9 +35,8 @@ WorkflowStopped, ) from cylc.flow.network import ( - encode_, - decode_, get_location, + ResponseTuple, ZMQSocketBase ) from cylc.flow.network.client_factory import CommsMeth @@ -165,7 +165,7 @@ async def async_request( args: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, req_meta: Optional[Dict[str, Any]] = None - ) -> object: + ) -> Union[bytes, object]: """Send an asynchronous request using asyncio. Has the same arguments and return values as ``serial_request``. @@ -187,12 +187,12 @@ async def async_request( if req_meta: msg['meta'].update(req_meta) LOG.debug('zmq:send %s', msg) - message = encode_(msg) + message = json.dumps(msg) self.socket.send_string(message) # receive response if self.poller.poll(timeout): - res = await self.socket.recv() + res: bytes = await self.socket.recv() else: if callable(self.timeout_handler): self.timeout_handler() @@ -204,24 +204,20 @@ async def async_request( ' This could be due to network or server issues.' ' Check the workflow log.' ) + LOG.debug('zmq:recv %s', res) - if msg['command'] in PB_METHOD_MAP: - response = {'data': res} - else: - response = decode_(res.decode()) - LOG.debug('zmq:recv %s', response) + if command in PB_METHOD_MAP: + return res - try: - return response['data'] - except KeyError: - error = response.get( - 'error', - {'message': f'Received invalid response: {response}'}, - ) - raise ClientError( - error.get('message'), - error.get('traceback'), - ) + response = ResponseTuple( # type: ignore[misc] + *json.loads(res.decode()) + ) + + if response.content is not None: + return response.content + if response.err: + raise ClientError(*response.err) + raise ClientError(f"Received invalid response: {response}") def serial_request( self, @@ -229,7 +225,7 @@ def serial_request( args: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, req_meta: Optional[Dict[str, Any]] = None - ) -> object: + ) -> Union[bytes, object]: """Send a request. For convenience use ``__call__`` to call this method. diff --git a/cylc/flow/network/graphql.py b/cylc/flow/network/graphql.py index 7650a58b154..ad20bd8730f 100644 --- a/cylc/flow/network/graphql.py +++ b/cylc/flow/network/graphql.py @@ -21,7 +21,7 @@ from functools import partial import logging -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Tuple, Union from inspect import isclass, iscoroutinefunction @@ -29,6 +29,7 @@ from graphql.execution.utils import ( get_operation_root_type, get_field_def ) +from graphql.execution import ExecutionResult from graphql.execution.values import get_argument_values, get_variable_values from graphql.language.base import parse, print_ast from graphql.language import ast @@ -42,7 +43,6 @@ from cylc.flow.network.schema import NODE_MAP, get_type_str if TYPE_CHECKING: - from graphql.execution import ExecutionResult from graphql.language.ast import Document from graphql.type import GraphQLSchema @@ -146,6 +146,14 @@ def null_stripper(exe_result): return exe_result +def format_execution_result( + result: Union[ExecutionResult, Dict[str, Any]] +) -> Dict[str, Any]: + if isinstance(result, ExecutionResult): + result = result.to_dict() + return strip_null(result) + + class AstDocArguments: """Request doc Argument inspection.""" @@ -254,7 +262,7 @@ def execute_and_validate_and_strip( document_ast: 'Document', *args: Any, **kwargs: Any -) -> Union['ExecutionResult', Observable]: +) -> Union[ExecutionResult, Observable]: """Wrapper around graphql ``execute_and_validate()`` that adds null stripping.""" result = execute_and_validate(schema, document_ast, *args, **kwargs) diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py index 62e9361a8cc..16044111968 100644 --- a/cylc/flow/network/replier.py +++ b/cylc/flow/network/replier.py @@ -15,13 +15,21 @@ # along with this program. If not, see . """Server for workflow runtime API.""" -import getpass # noqa: F401 +import getpass +import json from queue import Queue +from typing import TYPE_CHECKING +from typing_extensions import Literal import zmq from cylc.flow import LOG -from cylc.flow.network import encode_, decode_, ZMQSocketBase +from cylc.flow.network import ( + ResponseErrTuple, ResponseTuple, ZMQSocketBase +) + +if TYPE_CHECKING: + from cylc.flow.network.server import WorkflowRuntimeServer class WorkflowReplier(ZMQSocketBase): @@ -46,11 +54,15 @@ class WorkflowReplier(ZMQSocketBase): """ - def __init__(self, server, context=None): + def __init__( + self, + server: 'WorkflowRuntimeServer', + context=None + ): super().__init__(zmq.REP, bind=True, context=context) self.server = server self.workflow = server.schd.workflow - self.queue = Queue() + self.queue: Queue[Literal['STOP']] = Queue() def _bespoke_stop(self) -> None: """Stop the listener and Authenticator. @@ -92,27 +104,39 @@ def listener(self) -> None: continue # attempt to decode the message, authenticating the user in the # process + response: bytes try: - message = decode_(msg) + message = json.loads(msg) + user = getpass.getuser() # assume this is the user except Exception as exc: # purposefully catch generic exception # failed to decode message, possibly resulting from failed # authentication - LOG.exception('failed to decode message: "%s"', exc) + LOG.exception(exc) + LOG.error(f'failed to decode message: "{msg}"') import traceback - response = encode_( - { - 'error': { - 'message': 'failed to decode message: "%s"' % msg, - 'traceback': traceback.format_exc(), - } - } + response = json.dumps( + ResponseTuple( + err=ResponseErrTuple( + f'failed to decode message: {msg}"', + traceback.format_exc(), + ) + ) ).encode() else: # success case - serve the request - res = self.server.receiver(message) - # send back the string to bytes response - if isinstance(res.get('data'), bytes): - response = res['data'] + res = self.server.receiver(message, user) + if isinstance(res.content, bytes): # is protobuf method + # just return bytes, as cannot serialize bytes to JSON + response = res.content else: - response = encode_(res).encode() + try: + response = json.dumps(res).encode() + except TypeError as exc: + err_msg = f"failed to encode response: {res}\n{exc}" + LOG.warning(err_msg) + res = ResponseTuple( + err=ResponseErrTuple(err_msg) + ) + response = json.dumps(res).encode() + # send back the string to bytes response self.socket.send(response) diff --git a/cylc/flow/network/resolvers.py b/cylc/flow/network/resolvers.py index dbb170d3eda..cb6633b60b8 100644 --- a/cylc/flow/network/resolvers.py +++ b/cylc/flow/network/resolvers.py @@ -45,6 +45,7 @@ from cylc.flow.id import Tokens from cylc.flow.network.schema import ( DEF_TYPES, + GenericResponse, NodesEdges, PROXY_NODES, SUB_RESOLVERS, @@ -629,7 +630,7 @@ async def mutator( w_args: Dict[str, Any], kwargs: Dict[str, Any], meta: Dict[str, Any] - ) -> List[Dict[str, Any]]: + ) -> List[GenericResponse]: ... @@ -650,19 +651,22 @@ async def mutator( w_args: Dict[str, Any], kwargs: Dict[str, Any], meta: Dict[str, Any] - ) -> List[Dict[str, Any]]: + ) -> List[GenericResponse]: """Mutate workflow.""" w_ids = [flow[WORKFLOW].id for flow in await self.get_workflows_data(w_args)] if not w_ids: workflows = list(self.data_store_mgr.data.keys()) - return [{ - 'response': (False, f'No matching workflow in {workflows}')}] + ret = GenericResponse( + success=False, message=f'No matching workflow in {workflows}' + ) + return [ret] w_id = w_ids[0] result = await self._mutation_mapper(command, kwargs, meta) if result is None: result = (True, 'Command queued') - return [{'id': w_id, 'response': result}] + ret = GenericResponse(w_id, *result) + return [ret] async def _mutation_mapper( self, command: str, kwargs: Dict[str, Any], meta: Dict[str, Any] diff --git a/cylc/flow/network/scan.py b/cylc/flow/network/scan.py index 12e03ef3f45..e42e5ac6cd0 100644 --- a/cylc/flow/network/scan.py +++ b/cylc/flow/network/scan.py @@ -50,6 +50,7 @@ import re from typing import AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union +from graphql.execution import ExecutionResult from pkg_resources import ( parse_requirements, parse_version @@ -417,7 +418,9 @@ def format_query(fields, filters=None): @pipe(preproc=format_query) -async def graphql_query(flow, fields, filters=None): +async def graphql_query( + flow: dict, fields: Iterable[str], filters: Optional[list] = None +) -> Union[bool, dict]: """Obtain information from a GraphQL request to the flow. Requires: @@ -425,9 +428,9 @@ async def graphql_query(flow, fields, filters=None): * contact_info Args: - flow (dict): + flow: Flow information dictionary, provided by scan through the pipe. - fields (iterable): + fields: Iterable containing the fields to request e.g:: ['id', 'name'] @@ -435,7 +438,7 @@ async def graphql_query(flow, fields, filters=None): One level of nesting is supported e.g:: {'name': None, 'meta': ['title']} - filters (list): + filters: Filter by the data returned from the query. List in the form ``[(key, ...), value]``, e.g:: @@ -458,7 +461,8 @@ async def graphql_query(flow, fields, filters=None): LOG.warning(f'Workflow not running: {flow["name"]}') return False try: - ret = await client.async_request( + ret: dict = await client.async_request( # type: ignore[assignment] + # (graphql request gives dict) 'graphql', { 'request_string': query, @@ -476,13 +480,20 @@ async def graphql_query(flow, fields, filters=None): LOG.exception(exc) return False else: + response = ExecutionResult(**ret) # stick the result into the flow object - for item in ret: - if 'error' in item: - LOG.exception(item['error']['message']) - return False - for workflow in ret.get('workflows', []): - flow.update(workflow) + if not response.data: + if response.errors: + LOG.error("Scan error(s)") + for err in response.errors: + LOG.error(err) + else: + LOG.exception("Scan error: empty response") + return False + for workflow in response.data.get('workflows', []): + flow.update(workflow) + # TODO: what if no items in workflows list, will this cause + # KeyError below when trying to access flow[field_]? # process filters for field, value in filters or []: diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py index bed66b86680..16a464fbae1 100644 --- a/cylc/flow/network/schema.py +++ b/cylc/flow/network/schema.py @@ -28,6 +28,7 @@ Any, List, Optional, + Type, ) import graphene @@ -1187,9 +1188,22 @@ class Meta: # Generic containers class GenericResponse(ObjectType): class Meta: - description = """Container for command queued response""" + description = """Container for workflow command queued response""" - result = GenericScalar() + workflowId = String() + success = Boolean(required=True) + message = String(required=True) + + # Define __init__ for benefit of static type checking: + def __init__( + self, + workflowId: Optional[str] = None, + success: Optional[bool] = None, + message: Optional[str] = None + ): + # Note: all args optional here to allow for not requesting them in a + # mutation + ObjectType.__init__(self, workflowId, success, message) # Mutators are used to call the internals of the parent program in the @@ -1207,7 +1221,7 @@ async def mutator( workflows: Optional[List[str]] = None, exworkflows: Optional[List[str]] = None, **kwargs: Any -) -> GenericResponse: +) -> List[GenericResponse]: """Call the resolver method that act on the workflow service via the internal command queue. @@ -1239,7 +1253,9 @@ async def mutator( ) meta = info.context.get('meta') # type: ignore[union-attr] res = await resolvers.mutator(info, command, w_args, kwargs, meta) - return GenericResponse(result=res) + return info.return_type.graphene_type( # type: ignore[union-attr] + results=res + ) # Input types: @@ -1409,7 +1425,15 @@ class Flow(String): # Mutations: -class Broadcast(Mutation): +class WorkflowsMutation: + """Base class for mutations involving workflows.""" + class Arguments: + workflows = graphene.List(WorkflowID, required=True) + + results = graphene.List(GenericResponse) + + +class Broadcast(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Override `[runtime]` configurations in a running workflow. @@ -1444,9 +1468,7 @@ class Meta: ''') resolver = partial(mutator, command='broadcast') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) - + class Arguments(WorkflowsMutation.Arguments): mode = BroadcastMode( # use the enum name as the default value # https://github.com/graphql-python/graphql-core-legacy/issues/166 @@ -1492,10 +1514,8 @@ class Arguments: # ''') # ) - result = GenericScalar() - -class SetHoldPoint(Mutation): +class SetHoldPoint(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Set workflow hold after cycle point. All tasks after this point @@ -1503,17 +1523,14 @@ class Meta: ''') resolver = partial(mutator, command='set_hold_point') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): point = CyclePoint( description='Hold all tasks after the specified cycle point.', required=True ) - result = GenericScalar() - -class Pause(Mutation): +class Pause(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Pause a workflow. @@ -1522,13 +1539,8 @@ class Meta: ''') resolver = partial(mutator, command='pause') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) - - result = GenericScalar() - -class Message(Mutation): +class Message(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Record task job messages. @@ -1545,8 +1557,7 @@ class Meta: ''') resolver = partial(mutator, command='put_messages') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): task_job = String(required=True) event_time = String(default_value=None) messages = graphene.List( @@ -1555,10 +1566,8 @@ class Arguments: default_value=None ) - result = GenericScalar() - -class ReleaseHoldPoint(Mutation): +class ReleaseHoldPoint(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Release all tasks and unset the workflow hold point, if set. @@ -1567,13 +1576,8 @@ class Meta: ''') resolver = partial(mutator, command='release_hold_point') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) - - result = GenericScalar() - -class Resume(Mutation): +class Resume(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Resume a paused workflow. @@ -1582,13 +1586,8 @@ class Meta: ''') resolver = partial(mutator, command='resume') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) - result = GenericScalar() - - -class Reload(Mutation): +class Reload(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Reload the configuration of a running workflow. @@ -1608,13 +1607,8 @@ class Meta: ''') resolver = partial(mutator, command='reload_workflow') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) - result = GenericScalar() - - -class SetVerbosity(Mutation): +class SetVerbosity(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Change the logging severity level of a running workflow. @@ -1625,14 +1619,11 @@ class Meta: ''') resolver = partial(mutator, command='set_verbosity') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): level = LogLevels(required=True) - result = GenericScalar() - -class SetGraphWindowExtent(Mutation): +class SetGraphWindowExtent(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Set the maximum graph distance (n) from an active node @@ -1641,14 +1632,11 @@ class Meta: ''') resolver = partial(mutator, command='set_graph_window_extent') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): n_edge_distance = Int(required=True) - result = GenericScalar() - -class Stop(Mutation): +class Stop(Mutation, WorkflowsMutation): class Meta: description = sstrip(f''' Tell a workflow to shut down or stop a specified @@ -1666,8 +1654,7 @@ class Meta: ''') resolver = partial(mutator, command='stop') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): mode = WorkflowStopMode( default_value=WorkflowStopMode.Clean.name ) @@ -1684,10 +1671,8 @@ class Arguments: description='Number of flow to stop.' ) - result = GenericScalar() - -class ExtTrigger(Mutation): +class ExtTrigger(Mutation, WorkflowsMutation): class Meta: description = sstrip(''' Report an external event message to a scheduler. @@ -1714,8 +1699,7 @@ class Meta: ''') resolver = partial(mutator, command='put_ext_trigger') - class Arguments: - workflows = graphene.List(WorkflowID, required=True) + class Arguments(WorkflowsMutation.Arguments): message = String( description='External trigger message.', required=True @@ -1725,21 +1709,11 @@ class Arguments: required=True ) - result = GenericScalar() - -class TaskMutation: - class Arguments: - workflows = graphene.List( - WorkflowID, - required=True - ) - tasks = graphene.List( - NamespaceIDGlob, - required=True - ) - - result = GenericScalar() +class TasksMutation(WorkflowsMutation): + """Base class for mutations involving tasks.""" + class Arguments(WorkflowsMutation.Arguments): + tasks = graphene.List(NamespaceIDGlob, required=True) class FlowMutationArguments: @@ -1763,7 +1737,7 @@ class FlowMutationArguments: ) -class Hold(Mutation, TaskMutation): +class Hold(Mutation, TasksMutation): class Meta: description = sstrip(''' Hold tasks within a workflow. @@ -1773,7 +1747,7 @@ class Meta: resolver = partial(mutator, command='hold') -class Release(Mutation, TaskMutation): +class Release(Mutation, TasksMutation): class Meta: description = sstrip(''' Release held tasks within a workflow. @@ -1783,7 +1757,7 @@ class Meta: resolver = partial(mutator, command='release') -class Kill(Mutation, TaskMutation): +class Kill(Mutation, TasksMutation): # TODO: This should be a job mutation? class Meta: description = sstrip(''' @@ -1792,7 +1766,7 @@ class Meta: resolver = partial(mutator, command='kill_tasks') -class Poll(Mutation, TaskMutation): +class Poll(Mutation, TasksMutation): class Meta: description = sstrip(''' Poll (query) task jobs to verify and update their statuses. @@ -1806,11 +1780,8 @@ class Meta: ''') resolver = partial(mutator, command='poll_tasks') - class Arguments(TaskMutation.Arguments): - ... - -class Remove(Mutation, TaskMutation): +class Remove(Mutation, TasksMutation): class Meta: description = sstrip(''' Remove one or more task instances from a running workflow. @@ -1819,7 +1790,7 @@ class Meta: resolver = partial(mutator, command='remove_tasks') -class SetOutputs(Mutation, TaskMutation): +class SetOutputs(Mutation, TasksMutation): class Meta: description = sstrip(''' Artificially mark task outputs as completed. @@ -1831,7 +1802,7 @@ class Meta: ''') resolver = partial(mutator, command='force_spawn_children') - class Arguments(TaskMutation.Arguments): + class Arguments(TasksMutation.Arguments): outputs = graphene.List( String, default_value=[TASK_OUTPUT_SUCCEEDED], @@ -1840,7 +1811,7 @@ class Arguments(TaskMutation.Arguments): flow_num = Int() -class Trigger(Mutation, TaskMutation): +class Trigger(Mutation, TasksMutation): class Meta: description = sstrip(''' Manually trigger tasks. @@ -1853,7 +1824,7 @@ class Meta: ''') resolver = partial(mutator, command='force_trigger_tasks') - class Arguments(TaskMutation.Arguments, FlowMutationArguments): + class Arguments(TasksMutation.Arguments, FlowMutationArguments): flow_wait = Boolean( default_value=False, description=sstrip(''' @@ -1879,17 +1850,13 @@ class Arguments(TaskMutation.Arguments, FlowMutationArguments): ) -def _mut_field(cls): +def _mut_field(cls: Type[Mutation]) -> Field: """Convert a mutation class into a field. Sets the field metadata appropriately. Args: - field (class): - Subclass of graphene.Mutation - - Returns: - graphene.Field + cls: Subclass of graphene.Mutation """ return cls.Field(description=cls._meta.description) diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py index 6802bf0cedb..1649345864e 100644 --- a/cylc/flow/network/server.py +++ b/cylc/flow/network/server.py @@ -22,6 +22,7 @@ from time import sleep from typing import Any, Dict, List, Optional, Union +from graphql.error import GraphQLError from graphql.execution import ExecutionResult from graphql.execution.executors.asyncio import AsyncioExecutor import zmq @@ -29,9 +30,13 @@ from cylc.flow import LOG, workflow_files from cylc.flow.cfgspec.glbl_cfg import glbl_cfg +from cylc.flow.network import ResponseErrTuple, ResponseTuple from cylc.flow.network.authorisation import authorise from cylc.flow.network.graphql import ( - CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware + CylcGraphQLBackend, + IgnoreFieldMiddleware, + format_execution_result, + instantiate_middleware ) from cylc.flow.network.publisher import WorkflowPublisher from cylc.flow.network.replier import WorkflowReplier @@ -252,44 +257,56 @@ def operate(self): # Yield control to other threads sleep(self.OPERATE_SLEEP_INTERVAL) - def receiver(self, message): + def receiver( + self, message: Dict[str, Any], user: str + ) -> ResponseTuple: """Process incoming messages and coordinate response. Wrap incoming messages, dispatch them to exposed methods and/or coordinate a publishing stream. Args: - message (dict): message contents + message: message contents """ # TODO: If requested, coordinate publishing response/stream. # determine the server method to call + if not isinstance(message, dict): + return ResponseTuple( + err=ResponseErrTuple( + f'Expected dict but request is: {message}' + ) + ) try: method = getattr(self, message['command']) - args = message['args'] - args.update({'user': message['user']}) + args: dict = message['args'] + args.update({'user': user}) if 'meta' in message: args['meta'] = message['meta'] except KeyError: # malformed message - return {'error': { - 'message': 'Request missing required field(s).'}} + return ResponseTuple( + err=ResponseErrTuple('Request missing required field(s).') + ) except AttributeError: # no exposed method by that name - return {'error': { - 'message': 'No method by the name "%s"' % message['command']}} - + return ResponseTuple( + err=ResponseErrTuple( + f"No method by the name '{message['command']}'" + ) + ) # generate response try: response = method(**args) except Exception as exc: # includes incorrect arguments (TypeError) - LOG.exception(exc) # note the error server side + LOG.exception(exc) # log the error server side import traceback - return {'error': { - 'message': str(exc), 'traceback': traceback.format_exc()}} + return ResponseTuple( + err=ResponseErrTuple(str(exc), traceback.format_exc()) + ) - return {'data': response} + return ResponseTuple(content=response) def register_endpoints(self): """Register all exposed methods.""" @@ -341,16 +358,13 @@ def graphql( request_string: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None - ): + ) -> Dict[str, Any]: """Return the GraphQL schema execution result. Args: request_string: GraphQL request passed to Graphene. variables: Dict of variables passed to Graphene. meta: Dict containing auth user etc. - - Returns: - object: Execution result, or a list with errors. """ try: executed: ExecutionResult = schema.execute( @@ -367,20 +381,22 @@ def graphql( return_promise=False, ) except Exception as exc: - return 'ERROR: GraphQL execution error \n%s' % exc + raise GraphQLError(f"ERROR: GraphQL execution error \n{exc}") if executed.errors: - errors: List[Any] = [] - for error in executed.errors: - if hasattr(error, '__traceback__'): + for i, excp in enumerate(executed.errors): + if isinstance(excp, GraphQLError): + error = excp + else: + error = GraphQLError(message=str(excp)) + if hasattr(excp, '__traceback__'): import traceback - errors.append({'error': { - 'message': str(error), - 'traceback': traceback.format_exception( - error.__class__, error, error.__traceback__)}}) - continue - errors.append(getattr(error, 'message', None)) - return errors - return executed.data + extensions = error.extensions or {} + extensions['traceback'] = traceback.format_exception( + excp.__class__, excp, excp.__traceback__ + ) + error.extensions = extensions + executed.errors[i] = error + return format_execution_result(executed) # UIServer Data Commands @authorise() diff --git a/tests/integration/graphql/test_root.py b/tests/integration/graphql/test_root.py index 12b16822abb..a177e1085ca 100644 --- a/tests/integration/graphql/test_root.py +++ b/tests/integration/graphql/test_root.py @@ -62,9 +62,11 @@ async def test_workflows(harness): schd, client, query = harness ret = await query('workflows(ids: ["%s"]) { id }' % schd.workflow) assert ret == { - 'workflows': [ - {'id': f'~{schd.owner}/{schd.workflow}'} - ] + 'data': { + 'workflows': [ + {'id': f'~{schd.owner}/{schd.workflow}'} + ] + } } @@ -73,7 +75,9 @@ async def test_jobs(harness): schd, client, query = harness ret = await query('workflows(ids: ["%s"]) { id }' % schd.workflow) assert ret == { - 'workflows': [ - {'id': f'~{schd.owner}/{schd.workflow}'} - ] + 'data': { + 'workflows': [ + {'id': f'~{schd.owner}/{schd.workflow}'} + ] + } } diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 07485d2dcf1..0303aeb8e51 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -37,7 +37,7 @@ async def test_graphql(harness): 'graphql', {'request_string': 'query { workflows { id } }'} ) - workflows = ret['workflows'] + workflows = ret['data']['workflows'] assert len(workflows) == 1 workflow = workflows[0] assert schd.workflow in workflow['id'] diff --git a/tests/integration/test_graphql.py b/tests/integration/test_graphql.py index 81d0c806069..43cb53cfcab 100644 --- a/tests/integration/test_graphql.py +++ b/tests/integration/test_graphql.py @@ -116,11 +116,11 @@ async def test_workflows(harness): {'request_string': 'query { workflows { id } }'} ) assert ret == { - 'workflows': [ - { - 'id': f'{w_tokens}' - } - ] + 'data': { + 'workflows': [ + {'id': f'{w_tokens}'} + ] + } } @@ -132,6 +132,7 @@ async def test_tasks(harness): 'graphql', {'request_string': 'query { tasks { id } }'} ) + ret = ret['data'] ids = [ w_tokens.duplicate(cycle=f'$namespace|{namespace}').id for namespace in ('a', 'b', 'c', 'd') @@ -150,7 +151,7 @@ async def test_tasks(harness): 'graphql', {'request_string': 'query { task(id: "%s") { id } }' % id_} ) - assert ret == { + assert ret['data'] == { 'task': {'id': id_} } @@ -163,6 +164,7 @@ async def test_families(harness): 'graphql', {'request_string': 'query { families { id } }'} ) + ret = ret['data'] ids = [ w_tokens.duplicate( cycle=f'$namespace|{namespace}' @@ -183,7 +185,7 @@ async def test_families(harness): 'graphql', {'request_string': 'query { family(id: "%s") { id } }' % id_} ) - assert ret == { + assert ret['data'] == { 'family': {'id': id_} } @@ -196,6 +198,7 @@ async def test_task_proxies(harness): 'graphql', {'request_string': 'query { taskProxies { id } }'} ) + ret = ret['data'] ids = [ w_tokens.duplicate( cycle='1', @@ -217,7 +220,7 @@ async def test_task_proxies(harness): 'graphql', {'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0]} ) - assert ret == { + assert ret['data'] == { 'taskProxy': {'id': ids[0]} } @@ -230,6 +233,7 @@ async def test_family_proxies(harness): 'graphql', {'request_string': 'query { familyProxies { id } }'} ) + ret = ret['data'] ids = [ w_tokens.duplicate( cycle='1', @@ -252,7 +256,7 @@ async def test_family_proxies(harness): 'graphql', {'request_string': 'query { familyProxy(id: "%s") { id } }' % id_} ) - assert ret == { + assert ret['data'] == { 'familyProxy': {'id': id_} } @@ -288,7 +292,7 @@ async def test_edges(harness): 'graphql', {'request_string': 'query { edges { id } }'} ) - assert ret == { + assert ret['data'] == { 'edges': [ {'id': id_} for id_ in e_ids @@ -300,6 +304,7 @@ async def test_edges(harness): 'graphql', {'request_string': 'query { nodesEdges { nodes {id}\nedges {id} } }'} ) + ret = ret['data'] ret['nodesEdges']['nodes'].sort(key=lambda x: x['id']) ret['nodesEdges']['edges'].sort(key=lambda x: x['id']) assert ret == { @@ -334,7 +339,7 @@ async def test_jobs(harness): 'graphql', {'request_string': 'query { jobs { id } }'} ) - assert ret == { + assert ret['data'] == { 'jobs': [ {'id': f'{j_id}'} ] @@ -345,6 +350,6 @@ async def test_jobs(harness): 'graphql', {'request_string': 'query { job(id: "%s") { id } }' % j_id} ) - assert ret == { + assert ret['data'] == { 'job': {'id': f'{j_id}'} } diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py index ce0b53fdaa8..8ae873d6fae 100644 --- a/tests/integration/test_replier.py +++ b/tests/integration/test_replier.py @@ -14,11 +14,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from async_timeout import timeout -from cylc.flow.network import decode_ +import json +from cylc.flow.network import ResponseTuple from cylc.flow.network.client import WorkflowRuntimeClient import asyncio +from async_timeout import timeout import pytest @@ -28,7 +29,9 @@ async def test_listener(one, start, ): client = WorkflowRuntimeClient(one.workflow) client.socket.send_string(r'Not JSON') res = await client.socket.recv() - assert 'error' in decode_(res.decode()) + response = ResponseTuple(*json.loads(res.decode())) + assert response.content is None + assert 'failed to decode message' in response.err[0] one.server.replier.queue.put('STOP') async with timeout(2): diff --git a/tests/integration/test_resolvers.py b/tests/integration/test_resolvers.py index 6301861b058..d58c6b89212 100644 --- a/tests/integration/test_resolvers.py +++ b/tests/integration/test_resolvers.py @@ -15,7 +15,7 @@ # along with this program. If not, see . import logging -from typing import AsyncGenerator, Callable +from typing import Callable from unittest.mock import Mock import pytest @@ -54,7 +54,7 @@ async def mock_flow( mod_flow: Callable[..., str], mod_scheduler: Callable[..., Scheduler], mod_start, -) -> AsyncGenerator[Scheduler, None]: +): ret = Mock() ret.reg = mod_flow({ 'scheduler': { @@ -202,14 +202,15 @@ async def test_mutator(mock_flow, flow_args): }) args = {} meta = {} - response = await mock_flow.resolvers.mutator( + resolvers: Resolvers = mock_flow.resolvers + response = await resolvers.mutator( None, 'pause', flow_args, args, meta ) - assert response[0]['id'] == mock_flow.id + assert response[0].workflowId == mock_flow.id async def test_mutation_mapper(mock_flow): diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py index c6625b37c77..f5730c97f84 100644 --- a/tests/integration/test_server.py +++ b/tests/integration/test_server.py @@ -22,6 +22,8 @@ from cylc.flow.network.server import PB_METHOD_MAP +from cylc.flow.scheduler import Scheduler + @pytest.fixture(scope='module') async def myflow(mod_flow, mod_scheduler, mod_run, mod_one_conf): @@ -50,7 +52,7 @@ def test_graphql(myflow): }} }} ''' - data = call_server_method(myflow.server.graphql, request_string) + data = call_server_method(myflow.server.graphql, request_string)['data'] assert myflow.id == data['workflows'][0]['id'] @@ -106,32 +108,29 @@ async def test_operate(one, start): one.server.operate() -async def test_receiver(one, start): +async def test_receiver(one: Scheduler, start): """Test the receiver with different message objects.""" + user = 'troi' async with timeout(5): async with start(one): # start with a message that works - msg = {'command': 'api', 'user': '', 'args': {}} - assert 'error' not in one.server.receiver(msg) - assert 'data' in one.server.receiver(msg) - - # remove the user field - should error - msg2 = dict(msg) - msg2.pop('user') - assert 'error' in one.server.receiver(msg2) + msg = {'command': 'api', 'args': {}} + response = one.server.receiver(msg, user) + assert response.err is None + assert response.content is not None # remove the command field - should error msg3 = dict(msg) msg3.pop('command') - assert 'error' in one.server.receiver(msg3) + assert one.server.receiver(msg3, user).err is not None # provide an invalid command - should error msg4 = {**msg, 'command': 'foobar'} - assert 'error' in one.server.receiver(msg4) + assert one.server.receiver(msg4, user).err is not None # simulate a command failure with the original message # (the one which worked earlier) - should error def _api(*args, **kwargs): raise Exception('foo') one.server.api = _api - assert 'error' in one.server.receiver(msg) + assert one.server.receiver(msg, user).err is not None