diff --git a/cylc/flow/async_util.py b/cylc/flow/async_util.py
index aa62b39acc7..39eb1250514 100644
--- a/cylc/flow/async_util.py
+++ b/cylc/flow/async_util.py
@@ -165,7 +165,7 @@ async def _chain(self, item, coros, completed):
ret = await coro.func(item, *coro.args, **coro.kwargs)
except Exception as exc:
# if something goes wrong log the error and skip the item
- LOG.warning(exc)
+ LOG.warning(f"{type(exc).__name__}: {exc}")
ret = False
if ret is True:
# filter passed -> continue
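A standalone illustration of why the added exception type matters (a sketch, not part of the diff): `str()` of a `KeyError` is just the repr of the missing key, so the bare message gives no hint of what actually went wrong.

```python
import logging

logging.basicConfig(level=logging.WARNING)
LOG = logging.getLogger("demo")

try:
    {}["user"]
except Exception as exc:
    LOG.warning(exc)                             # WARNING:demo:'user'
    LOG.warning(f"{type(exc).__name__}: {exc}")  # WARNING:demo:KeyError: 'user'
```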
diff --git a/cylc/flow/network/__init__.py b/cylc/flow/network/__init__.py
index d14e11c6487..6c1ff7de621 100644
--- a/cylc/flow/network/__init__.py
+++ b/cylc/flow/network/__init__.py
@@ -16,9 +16,7 @@
"""Package for network interfaces to Cylc scheduler objects."""
import asyncio
-import getpass
-import json
-from typing import Any, Dict
+from typing import NamedTuple, Optional
import zmq
import zmq.asyncio
@@ -45,21 +43,16 @@
MSG_TIMEOUT = "TIMEOUT"
-def encode_(message: object) -> str:
- """Convert the structure holding a message field from JSON to a string."""
- try:
- return json.dumps(message)
- except TypeError as exc:
- return json.dumps({'errors': [{'message': str(exc)}]})
-
-
-def decode_(message: str) -> Dict[str, Any]:
- """Convert an encoded message string to JSON with an added 'user' field."""
- msg: object = json.loads(message)
- if not isinstance(msg, dict):
- raise ValueError(f"Expected message to be dict but got {type(msg)}")
- msg['user'] = getpass.getuser() # assume this is the user
- return msg
+class ResponseTuple(NamedTuple):
+ """Structure of server response messages."""
+ content: Optional[object] = None
+ err: Optional['ResponseErrTuple'] = None
+ user: Optional[str] = None
+
+
+class ResponseErrTuple(NamedTuple):
+    """Structure of the error field in a server response message."""
+    message: str
+    traceback: Optional[str] = None
def get_location(workflow: str):
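Worth noting how these tuples travel over the wire: `json.dumps` serialises a `NamedTuple` as a plain JSON array and the receiving side rebuilds it positionally, so field order is part of the protocol. A self-contained sketch mirroring the classes above:

```python
import json
from typing import NamedTuple, Optional

class ResponseErrTuple(NamedTuple):
    message: str
    traceback: Optional[str] = None

class ResponseTuple(NamedTuple):
    content: Optional[object] = None
    err: Optional[ResponseErrTuple] = None
    user: Optional[str] = None

wire = json.dumps(ResponseTuple(content={"workflows": []}))
print(wire)  # [{"workflows": []}, null, null]

rebuilt = ResponseTuple(*json.loads(wire))
assert rebuilt.content == {"workflows": []}
assert rebuilt.err is None
```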
diff --git a/cylc/flow/network/client.py b/cylc/flow/network/client.py
index ae917c41aa9..ea5b6aad51f 100644
--- a/cylc/flow/network/client.py
+++ b/cylc/flow/network/client.py
@@ -16,6 +16,7 @@
"""Client for workflow runtime API."""
from functools import partial
+import json
import os
from shutil import which
import socket
@@ -34,9 +35,8 @@
WorkflowStopped,
)
from cylc.flow.network import (
- encode_,
- decode_,
get_location,
+ ResponseTuple,
ZMQSocketBase
)
from cylc.flow.network.client_factory import CommsMeth
@@ -165,7 +165,7 @@ async def async_request(
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
- ) -> object:
+ ) -> Union[bytes, object]:
"""Send an asynchronous request using asyncio.
Has the same arguments and return values as ``serial_request``.
@@ -187,12 +187,12 @@ async def async_request(
if req_meta:
msg['meta'].update(req_meta)
LOG.debug('zmq:send %s', msg)
- message = encode_(msg)
+ message = json.dumps(msg)
self.socket.send_string(message)
# receive response
if self.poller.poll(timeout):
- res = await self.socket.recv()
+ res: bytes = await self.socket.recv()
else:
if callable(self.timeout_handler):
self.timeout_handler()
@@ -204,24 +204,20 @@ async def async_request(
' This could be due to network or server issues.'
' Check the workflow log.'
)
-        if msg['command'] in PB_METHOD_MAP:
-            response = {'data': res}
-        else:
-            response = decode_(res.decode())
-        LOG.debug('zmq:recv %s', response)
+        LOG.debug('zmq:recv %s', res)
+        if command in PB_METHOD_MAP:
+            return res
- try:
- return response['data']
- except KeyError:
- error = response.get(
- 'error',
- {'message': f'Received invalid response: {response}'},
- )
- raise ClientError(
- error.get('message'),
- error.get('traceback'),
- )
+ response = ResponseTuple( # type: ignore[misc]
+ *json.loads(res.decode())
+ )
+
+ if response.content is not None:
+ return response.content
+ if response.err:
+ raise ClientError(*response.err)
+ raise ClientError(f"Received invalid response: {response}")
def serial_request(
self,
@@ -229,7 +225,7 @@ def serial_request(
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
- ) -> object:
+ ) -> Union[bytes, object]:
"""Send a request.
For convenience use ``__call__`` to call this method.
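For reference, a standalone sketch of the new unpacking logic (`ClientError` below is a stand-in for `cylc.flow.exceptions.ClientError`, which per the code above accepts `(message, traceback)`). Note that after the JSON round trip `response.err` arrives as a plain list rather than a `ResponseErrTuple`, but it still unpacks positionally:

```python
import json
from typing import NamedTuple, Optional

class ResponseTuple(NamedTuple):
    content: Optional[object] = None
    err: Optional[object] = None
    user: Optional[str] = None

class ClientError(Exception):  # stand-in for the real exception class
    def __init__(self, message, traceback=None):
        super().__init__(message)
        self.traceback = traceback

def unpack(res: bytes) -> object:
    response = ResponseTuple(*json.loads(res.decode()))
    if response.content is not None:
        return response.content
    if response.err:
        raise ClientError(*response.err)  # (message, traceback)
    raise ClientError(f"Received invalid response: {response}")

assert unpack(b'[{"answer": 42}, null, null]') == {"answer": 42}
```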
diff --git a/cylc/flow/network/graphql.py b/cylc/flow/network/graphql.py
index 7650a58b154..ad20bd8730f 100644
--- a/cylc/flow/network/graphql.py
+++ b/cylc/flow/network/graphql.py
@@ -21,7 +21,7 @@
from functools import partial
import logging
-from typing import TYPE_CHECKING, Any, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Tuple, Union
from inspect import isclass, iscoroutinefunction
@@ -29,6 +29,7 @@
from graphql.execution.utils import (
get_operation_root_type, get_field_def
)
+from graphql.execution import ExecutionResult
from graphql.execution.values import get_argument_values, get_variable_values
from graphql.language.base import parse, print_ast
from graphql.language import ast
@@ -42,7 +43,6 @@
from cylc.flow.network.schema import NODE_MAP, get_type_str
if TYPE_CHECKING:
- from graphql.execution import ExecutionResult
from graphql.language.ast import Document
from graphql.type import GraphQLSchema
@@ -146,6 +146,14 @@ def null_stripper(exe_result):
return exe_result
+def format_execution_result(
+ result: Union[ExecutionResult, Dict[str, Any]]
+) -> Dict[str, Any]:
+ if isinstance(result, ExecutionResult):
+ result = result.to_dict()
+ return strip_null(result)
+
+
class AstDocArguments:
"""Request doc Argument inspection."""
@@ -254,7 +262,7 @@ def execute_and_validate_and_strip(
document_ast: 'Document',
*args: Any,
**kwargs: Any
-) -> Union['ExecutionResult', Observable]:
+) -> Union[ExecutionResult, Observable]:
"""Wrapper around graphql ``execute_and_validate()`` that adds
null stripping."""
result = execute_and_validate(schema, document_ast, *args, **kwargs)
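The upshot is that the server now returns the standard GraphQL envelope rather than bare data, which is what the test updates further down account for. A small sketch (`strip_null` is then applied on top, recursively dropping null leaves):

```python
from graphql.execution import ExecutionResult

result = ExecutionResult(data={'workflows': [{'id': '~user/one'}]})
print(result.to_dict())
# {'data': {'workflows': [{'id': '~user/one'}]}}
```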
diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py
index 62e9361a8cc..16044111968 100644
--- a/cylc/flow/network/replier.py
+++ b/cylc/flow/network/replier.py
@@ -15,13 +15,21 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""
-import getpass # noqa: F401
+import getpass
+import json
from queue import Queue
+from typing import TYPE_CHECKING
+from typing_extensions import Literal
import zmq
from cylc.flow import LOG
-from cylc.flow.network import encode_, decode_, ZMQSocketBase
+from cylc.flow.network import (
+ ResponseErrTuple, ResponseTuple, ZMQSocketBase
+)
+
+if TYPE_CHECKING:
+ from cylc.flow.network.server import WorkflowRuntimeServer
class WorkflowReplier(ZMQSocketBase):
@@ -46,11 +54,15 @@ class WorkflowReplier(ZMQSocketBase):
"""
- def __init__(self, server, context=None):
+ def __init__(
+ self,
+ server: 'WorkflowRuntimeServer',
+ context=None
+ ):
super().__init__(zmq.REP, bind=True, context=context)
self.server = server
self.workflow = server.schd.workflow
- self.queue = Queue()
+ self.queue: Queue[Literal['STOP']] = Queue()
def _bespoke_stop(self) -> None:
"""Stop the listener and Authenticator.
@@ -92,27 +104,39 @@ def listener(self) -> None:
continue
# attempt to decode the message, authenticating the user in the
# process
+ response: bytes
try:
- message = decode_(msg)
+ message = json.loads(msg)
+ user = getpass.getuser() # assume this is the user
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
- LOG.exception('failed to decode message: "%s"', exc)
+ LOG.exception(exc)
+ LOG.error(f'failed to decode message: "{msg}"')
import traceback
- response = encode_(
- {
- 'error': {
- 'message': 'failed to decode message: "%s"' % msg,
- 'traceback': traceback.format_exc(),
- }
- }
+ response = json.dumps(
+ ResponseTuple(
+ err=ResponseErrTuple(
+                        f'failed to decode message: "{msg}"',
+ traceback.format_exc(),
+ )
+ )
).encode()
else:
# success case - serve the request
- res = self.server.receiver(message)
- # send back the string to bytes response
- if isinstance(res.get('data'), bytes):
- response = res['data']
+ res = self.server.receiver(message, user)
+ if isinstance(res.content, bytes): # is protobuf method
+ # just return bytes, as cannot serialize bytes to JSON
+ response = res.content
else:
- response = encode_(res).encode()
+ try:
+ response = json.dumps(res).encode()
+ except TypeError as exc:
+ err_msg = f"failed to encode response: {res}\n{exc}"
+ LOG.warning(err_msg)
+ res = ResponseTuple(
+ err=ResponseErrTuple(err_msg)
+ )
+ response = json.dumps(res).encode()
+            # send the response (already encoded to bytes) back
self.socket.send(response)
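A quick illustration of the failure mode the new `try`/`except TypeError` guards against: anything `json.dumps` cannot serialise raises `TypeError`, which is now turned into an error `ResponseTuple` instead of crashing the listener.

```python
import json

try:
    json.dumps({'blob': b'\x00'})  # bytes are not JSON-serialisable
except TypeError as exc:
    print(exc)  # Object of type bytes is not JSON serializable
```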
diff --git a/cylc/flow/network/resolvers.py b/cylc/flow/network/resolvers.py
index dbb170d3eda..cb6633b60b8 100644
--- a/cylc/flow/network/resolvers.py
+++ b/cylc/flow/network/resolvers.py
@@ -45,6 +45,7 @@
from cylc.flow.id import Tokens
from cylc.flow.network.schema import (
DEF_TYPES,
+ GenericResponse,
NodesEdges,
PROXY_NODES,
SUB_RESOLVERS,
@@ -629,7 +630,7 @@ async def mutator(
w_args: Dict[str, Any],
kwargs: Dict[str, Any],
meta: Dict[str, Any]
- ) -> List[Dict[str, Any]]:
+ ) -> List[GenericResponse]:
...
@@ -650,19 +651,22 @@ async def mutator(
w_args: Dict[str, Any],
kwargs: Dict[str, Any],
meta: Dict[str, Any]
- ) -> List[Dict[str, Any]]:
+ ) -> List[GenericResponse]:
"""Mutate workflow."""
w_ids = [flow[WORKFLOW].id
for flow in await self.get_workflows_data(w_args)]
if not w_ids:
workflows = list(self.data_store_mgr.data.keys())
- return [{
- 'response': (False, f'No matching workflow in {workflows}')}]
+ ret = GenericResponse(
+ success=False, message=f'No matching workflow in {workflows}'
+ )
+ return [ret]
w_id = w_ids[0]
result = await self._mutation_mapper(command, kwargs, meta)
if result is None:
result = (True, 'Command queued')
- return [{'id': w_id, 'response': result}]
+ ret = GenericResponse(w_id, *result)
+ return [ret]
async def _mutation_mapper(
self, command: str, kwargs: Dict[str, Any], meta: Dict[str, Any]
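A sketch of the unpacking above, using a `NamedTuple` stand-in for the graphene `GenericResponse` defined in `schema.py` below: the `(success, message)` pair returned by `_mutation_mapper` slots in positionally after the workflow id.

```python
from typing import NamedTuple, Optional

class GenericResponse(NamedTuple):  # stand-in for the graphene ObjectType
    workflowId: Optional[str] = None
    success: Optional[bool] = None
    message: Optional[str] = None

result = (True, 'Command queued')
print(GenericResponse('~user/one', *result))
# GenericResponse(workflowId='~user/one', success=True, message='Command queued')
```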
diff --git a/cylc/flow/network/scan.py b/cylc/flow/network/scan.py
index 12e03ef3f45..e42e5ac6cd0 100644
--- a/cylc/flow/network/scan.py
+++ b/cylc/flow/network/scan.py
@@ -50,6 +50,7 @@
import re
from typing import AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union
+from graphql.execution import ExecutionResult
from pkg_resources import (
parse_requirements,
parse_version
@@ -417,7 +418,9 @@ def format_query(fields, filters=None):
@pipe(preproc=format_query)
-async def graphql_query(flow, fields, filters=None):
+async def graphql_query(
+ flow: dict, fields: Iterable[str], filters: Optional[list] = None
+) -> Union[bool, dict]:
"""Obtain information from a GraphQL request to the flow.
Requires:
@@ -425,9 +428,9 @@ async def graphql_query(flow, fields, filters=None):
* contact_info
Args:
- flow (dict):
+ flow:
Flow information dictionary, provided by scan through the pipe.
- fields (iterable):
+ fields:
Iterable containing the fields to request e.g::
['id', 'name']
@@ -435,7 +438,7 @@ async def graphql_query(flow, fields, filters=None):
One level of nesting is supported e.g::
{'name': None, 'meta': ['title']}
- filters (list):
+ filters:
Filter by the data returned from the query.
List in the form ``[(key, ...), value]``, e.g::
@@ -458,7 +461,8 @@ async def graphql_query(flow, fields, filters=None):
LOG.warning(f'Workflow not running: {flow["name"]}')
return False
try:
- ret = await client.async_request(
+ ret: dict = await client.async_request( # type: ignore[assignment]
+ # (graphql request gives dict)
'graphql',
{
'request_string': query,
@@ -476,13 +480,20 @@ async def graphql_query(flow, fields, filters=None):
LOG.exception(exc)
return False
else:
+ response = ExecutionResult(**ret)
# stick the result into the flow object
- for item in ret:
- if 'error' in item:
- LOG.exception(item['error']['message'])
- return False
- for workflow in ret.get('workflows', []):
- flow.update(workflow)
+ if not response.data:
+ if response.errors:
+ LOG.error("Scan error(s)")
+ for err in response.errors:
+ LOG.error(err)
+ else:
+                LOG.error("Scan error: empty response")
+ return False
+ for workflow in response.data.get('workflows', []):
+ flow.update(workflow)
+ # TODO: what if no items in workflows list, will this cause
+ # KeyError below when trying to access flow[field_]?
# process filters
for field, value in filters or []:
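For context, a minimal sketch of the new scan parsing, assuming the server reply is the dict form of an `ExecutionResult` (a `data` key plus an optional `errors` list, per `format_execution_result` above):

```python
from graphql.execution import ExecutionResult

ret = {'data': {'workflows': [{'id': '~user/one', 'status': 'running'}]}}
response = ExecutionResult(**ret)
if response.data:
    for workflow in response.data.get('workflows', []):
        print(workflow['id'])  # ~user/one
```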
diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py
index bed66b86680..16a464fbae1 100644
--- a/cylc/flow/network/schema.py
+++ b/cylc/flow/network/schema.py
@@ -28,6 +28,7 @@
Any,
List,
Optional,
+ Type,
)
import graphene
@@ -1187,9 +1188,22 @@ class Meta:
# Generic containers
class GenericResponse(ObjectType):
class Meta:
- description = """Container for command queued response"""
+ description = """Container for workflow command queued response"""
- result = GenericScalar()
+ workflowId = String()
+ success = Boolean(required=True)
+ message = String(required=True)
+
+ # Define __init__ for benefit of static type checking:
+ def __init__(
+ self,
+ workflowId: Optional[str] = None,
+ success: Optional[bool] = None,
+ message: Optional[str] = None
+ ):
+ # Note: all args optional here to allow for not requesting them in a
+ # mutation
+ ObjectType.__init__(self, workflowId, success, message)
# Mutators are used to call the internals of the parent program in the
@@ -1207,7 +1221,7 @@ async def mutator(
workflows: Optional[List[str]] = None,
exworkflows: Optional[List[str]] = None,
**kwargs: Any
-) -> GenericResponse:
+) -> List[GenericResponse]:
"""Call the resolver method that act on the workflow service
via the internal command queue.
@@ -1239,7 +1253,9 @@ async def mutator(
)
meta = info.context.get('meta') # type: ignore[union-attr]
res = await resolvers.mutator(info, command, w_args, kwargs, meta)
- return GenericResponse(result=res)
+ return info.return_type.graphene_type( # type: ignore[union-attr]
+ results=res
+ )
# Input types:
@@ -1409,7 +1425,15 @@ class Flow(String):
# Mutations:
-class Broadcast(Mutation):
+class WorkflowsMutation:
+ """Base class for mutations involving workflows."""
+ class Arguments:
+ workflows = graphene.List(WorkflowID, required=True)
+
+ results = graphene.List(GenericResponse)
+
+
+class Broadcast(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Override `[runtime]` configurations in a running workflow.
@@ -1444,9 +1468,7 @@ class Meta:
''')
resolver = partial(mutator, command='broadcast')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
-
+ class Arguments(WorkflowsMutation.Arguments):
mode = BroadcastMode(
# use the enum name as the default value
# https://github.com/graphql-python/graphql-core-legacy/issues/166
@@ -1492,10 +1514,8 @@ class Arguments:
# ''')
# )
- result = GenericScalar()
-
-class SetHoldPoint(Mutation):
+class SetHoldPoint(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Set workflow hold after cycle point. All tasks after this point
@@ -1503,17 +1523,14 @@ class Meta:
''')
resolver = partial(mutator, command='set_hold_point')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
point = CyclePoint(
description='Hold all tasks after the specified cycle point.',
required=True
)
- result = GenericScalar()
-
-class Pause(Mutation):
+class Pause(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Pause a workflow.
@@ -1522,13 +1539,8 @@ class Meta:
''')
resolver = partial(mutator, command='pause')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
-
- result = GenericScalar()
-
-class Message(Mutation):
+class Message(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Record task job messages.
@@ -1545,8 +1557,7 @@ class Meta:
''')
resolver = partial(mutator, command='put_messages')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
task_job = String(required=True)
event_time = String(default_value=None)
messages = graphene.List(
@@ -1555,10 +1566,8 @@ class Arguments:
default_value=None
)
- result = GenericScalar()
-
-class ReleaseHoldPoint(Mutation):
+class ReleaseHoldPoint(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Release all tasks and unset the workflow hold point, if set.
@@ -1567,13 +1576,8 @@ class Meta:
''')
resolver = partial(mutator, command='release_hold_point')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
-
- result = GenericScalar()
-
-class Resume(Mutation):
+class Resume(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Resume a paused workflow.
@@ -1582,13 +1586,8 @@ class Meta:
''')
resolver = partial(mutator, command='resume')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
- result = GenericScalar()
-
-
-class Reload(Mutation):
+class Reload(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Reload the configuration of a running workflow.
@@ -1608,13 +1607,8 @@ class Meta:
''')
resolver = partial(mutator, command='reload_workflow')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
- result = GenericScalar()
-
-
-class SetVerbosity(Mutation):
+class SetVerbosity(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Change the logging severity level of a running workflow.
@@ -1625,14 +1619,11 @@ class Meta:
''')
resolver = partial(mutator, command='set_verbosity')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
level = LogLevels(required=True)
- result = GenericScalar()
-
-class SetGraphWindowExtent(Mutation):
+class SetGraphWindowExtent(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Set the maximum graph distance (n) from an active node
@@ -1641,14 +1632,11 @@ class Meta:
''')
resolver = partial(mutator, command='set_graph_window_extent')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
n_edge_distance = Int(required=True)
- result = GenericScalar()
-
-class Stop(Mutation):
+class Stop(Mutation, WorkflowsMutation):
class Meta:
description = sstrip(f'''
Tell a workflow to shut down or stop a specified
@@ -1666,8 +1654,7 @@ class Meta:
''')
resolver = partial(mutator, command='stop')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
mode = WorkflowStopMode(
default_value=WorkflowStopMode.Clean.name
)
@@ -1684,10 +1671,8 @@ class Arguments:
description='Number of flow to stop.'
)
- result = GenericScalar()
-
-class ExtTrigger(Mutation):
+class ExtTrigger(Mutation, WorkflowsMutation):
class Meta:
description = sstrip('''
Report an external event message to a scheduler.
@@ -1714,8 +1699,7 @@ class Meta:
''')
resolver = partial(mutator, command='put_ext_trigger')
- class Arguments:
- workflows = graphene.List(WorkflowID, required=True)
+ class Arguments(WorkflowsMutation.Arguments):
message = String(
description='External trigger message.',
required=True
@@ -1725,21 +1709,11 @@ class Arguments:
required=True
)
- result = GenericScalar()
-
-class TaskMutation:
- class Arguments:
- workflows = graphene.List(
- WorkflowID,
- required=True
- )
- tasks = graphene.List(
- NamespaceIDGlob,
- required=True
- )
-
- result = GenericScalar()
+class TasksMutation(WorkflowsMutation):
+ """Base class for mutations involving tasks."""
+ class Arguments(WorkflowsMutation.Arguments):
+ tasks = graphene.List(NamespaceIDGlob, required=True)
class FlowMutationArguments:
@@ -1763,7 +1737,7 @@ class FlowMutationArguments:
)
-class Hold(Mutation, TaskMutation):
+class Hold(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Hold tasks within a workflow.
@@ -1773,7 +1747,7 @@ class Meta:
resolver = partial(mutator, command='hold')
-class Release(Mutation, TaskMutation):
+class Release(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Release held tasks within a workflow.
@@ -1783,7 +1757,7 @@ class Meta:
resolver = partial(mutator, command='release')
-class Kill(Mutation, TaskMutation):
+class Kill(Mutation, TasksMutation):
# TODO: This should be a job mutation?
class Meta:
description = sstrip('''
@@ -1792,7 +1766,7 @@ class Meta:
resolver = partial(mutator, command='kill_tasks')
-class Poll(Mutation, TaskMutation):
+class Poll(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Poll (query) task jobs to verify and update their statuses.
@@ -1806,11 +1780,8 @@ class Meta:
''')
resolver = partial(mutator, command='poll_tasks')
- class Arguments(TaskMutation.Arguments):
- ...
-
-class Remove(Mutation, TaskMutation):
+class Remove(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Remove one or more task instances from a running workflow.
@@ -1819,7 +1790,7 @@ class Meta:
resolver = partial(mutator, command='remove_tasks')
-class SetOutputs(Mutation, TaskMutation):
+class SetOutputs(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Artificially mark task outputs as completed.
@@ -1831,7 +1802,7 @@ class Meta:
''')
resolver = partial(mutator, command='force_spawn_children')
- class Arguments(TaskMutation.Arguments):
+ class Arguments(TasksMutation.Arguments):
outputs = graphene.List(
String,
default_value=[TASK_OUTPUT_SUCCEEDED],
@@ -1840,7 +1811,7 @@ class Arguments(TaskMutation.Arguments):
flow_num = Int()
-class Trigger(Mutation, TaskMutation):
+class Trigger(Mutation, TasksMutation):
class Meta:
description = sstrip('''
Manually trigger tasks.
@@ -1853,7 +1824,7 @@ class Meta:
''')
resolver = partial(mutator, command='force_trigger_tasks')
- class Arguments(TaskMutation.Arguments, FlowMutationArguments):
+ class Arguments(TasksMutation.Arguments, FlowMutationArguments):
flow_wait = Boolean(
default_value=False,
description=sstrip('''
@@ -1879,17 +1850,13 @@ class Arguments(TaskMutation.Arguments, FlowMutationArguments):
)
-def _mut_field(cls):
+def _mut_field(cls: Type[Mutation]) -> Field:
"""Convert a mutation class into a field.
Sets the field metadata appropriately.
Args:
- field (class):
- Subclass of graphene.Mutation
-
- Returns:
- graphene.Field
+ cls: Subclass of graphene.Mutation
"""
return cls.Field(description=cls._meta.description)
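With the shared base classes in place, a new workflow mutation only needs to declare its own arguments. A sketch of the pattern as it would appear inside `schema.py` (`Example` and its `do_example` command are hypothetical, not part of this diff):

```python
class Example(Mutation, WorkflowsMutation):
    class Meta:
        description = 'A hypothetical workflow mutation.'
        resolver = partial(mutator, command='do_example')

    class Arguments(WorkflowsMutation.Arguments):
        # extra arguments on top of the inherited `workflows` list
        message = String(required=True)
```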
diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py
index 6802bf0cedb..1649345864e 100644
--- a/cylc/flow/network/server.py
+++ b/cylc/flow/network/server.py
@@ -22,6 +22,7 @@
from time import sleep
from typing import Any, Dict, List, Optional, Union
+from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphql.execution.executors.asyncio import AsyncioExecutor
import zmq
@@ -29,9 +30,13 @@
from cylc.flow import LOG, workflow_files
from cylc.flow.cfgspec.glbl_cfg import glbl_cfg
+from cylc.flow.network import ResponseErrTuple, ResponseTuple
from cylc.flow.network.authorisation import authorise
from cylc.flow.network.graphql import (
- CylcGraphQLBackend, IgnoreFieldMiddleware, instantiate_middleware
+ CylcGraphQLBackend,
+ IgnoreFieldMiddleware,
+ format_execution_result,
+ instantiate_middleware
)
from cylc.flow.network.publisher import WorkflowPublisher
from cylc.flow.network.replier import WorkflowReplier
@@ -252,44 +257,56 @@ def operate(self):
# Yield control to other threads
sleep(self.OPERATE_SLEEP_INTERVAL)
- def receiver(self, message):
+ def receiver(
+ self, message: Dict[str, Any], user: str
+ ) -> ResponseTuple:
"""Process incoming messages and coordinate response.
Wrap incoming messages, dispatch them to exposed methods and/or
coordinate a publishing stream.
Args:
- message (dict): message contents
+ message: message contents
"""
# TODO: If requested, coordinate publishing response/stream.
# determine the server method to call
+ if not isinstance(message, dict):
+ return ResponseTuple(
+ err=ResponseErrTuple(
+ f'Expected dict but request is: {message}'
+ )
+ )
try:
method = getattr(self, message['command'])
- args = message['args']
- args.update({'user': message['user']})
+ args: dict = message['args']
+ args.update({'user': user})
if 'meta' in message:
args['meta'] = message['meta']
except KeyError:
# malformed message
- return {'error': {
- 'message': 'Request missing required field(s).'}}
+ return ResponseTuple(
+ err=ResponseErrTuple('Request missing required field(s).')
+ )
except AttributeError:
# no exposed method by that name
- return {'error': {
- 'message': 'No method by the name "%s"' % message['command']}}
-
+ return ResponseTuple(
+ err=ResponseErrTuple(
+ f"No method by the name '{message['command']}'"
+ )
+ )
# generate response
try:
response = method(**args)
except Exception as exc:
# includes incorrect arguments (TypeError)
- LOG.exception(exc) # note the error server side
+ LOG.exception(exc) # log the error server side
import traceback
- return {'error': {
- 'message': str(exc), 'traceback': traceback.format_exc()}}
+ return ResponseTuple(
+ err=ResponseErrTuple(str(exc), traceback.format_exc())
+ )
- return {'data': response}
+ return ResponseTuple(content=response)
def register_endpoints(self):
"""Register all exposed methods."""
@@ -341,16 +358,13 @@ def graphql(
request_string: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None
- ):
+ ) -> Dict[str, Any]:
"""Return the GraphQL schema execution result.
Args:
request_string: GraphQL request passed to Graphene.
variables: Dict of variables passed to Graphene.
meta: Dict containing auth user etc.
-
- Returns:
- object: Execution result, or a list with errors.
"""
try:
executed: ExecutionResult = schema.execute(
@@ -367,20 +381,22 @@ def graphql(
return_promise=False,
)
except Exception as exc:
- return 'ERROR: GraphQL execution error \n%s' % exc
+ raise GraphQLError(f"ERROR: GraphQL execution error \n{exc}")
if executed.errors:
- errors: List[Any] = []
- for error in executed.errors:
- if hasattr(error, '__traceback__'):
+ for i, excp in enumerate(executed.errors):
+ if isinstance(excp, GraphQLError):
+ error = excp
+ else:
+ error = GraphQLError(message=str(excp))
+ if hasattr(excp, '__traceback__'):
import traceback
- errors.append({'error': {
- 'message': str(error),
- 'traceback': traceback.format_exception(
- error.__class__, error, error.__traceback__)}})
- continue
- errors.append(getattr(error, 'message', None))
- return errors
- return executed.data
+ extensions = error.extensions or {}
+ extensions['traceback'] = traceback.format_exception(
+ excp.__class__, excp, excp.__traceback__
+ )
+ error.extensions = extensions
+ executed.errors[i] = error
+ return format_execution_result(executed)
# UIServer Data Commands
@authorise()
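A standalone sketch of the error normalisation above: every entry in `executed.errors` ends up as a `GraphQLError` carrying its traceback in `extensions`, so the detail survives into the standard `{'errors': [...]}` payload produced by `format_execution_result`.

```python
import traceback
from graphql.error import GraphQLError

try:
    raise ValueError('boom')
except ValueError as excp:
    error = GraphQLError(message=str(excp))
    error.extensions = {
        'traceback': traceback.format_exception(
            excp.__class__, excp, excp.__traceback__
        )
    }
    print(error.extensions['traceback'][-1].strip())  # ValueError: boom
```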
diff --git a/tests/integration/graphql/test_root.py b/tests/integration/graphql/test_root.py
index 12b16822abb..a177e1085ca 100644
--- a/tests/integration/graphql/test_root.py
+++ b/tests/integration/graphql/test_root.py
@@ -62,9 +62,11 @@ async def test_workflows(harness):
schd, client, query = harness
ret = await query('workflows(ids: ["%s"]) { id }' % schd.workflow)
assert ret == {
- 'workflows': [
- {'id': f'~{schd.owner}/{schd.workflow}'}
- ]
+ 'data': {
+ 'workflows': [
+ {'id': f'~{schd.owner}/{schd.workflow}'}
+ ]
+ }
}
@@ -73,7 +75,9 @@ async def test_jobs(harness):
schd, client, query = harness
ret = await query('workflows(ids: ["%s"]) { id }' % schd.workflow)
assert ret == {
- 'workflows': [
- {'id': f'~{schd.owner}/{schd.workflow}'}
- ]
+ 'data': {
+ 'workflows': [
+ {'id': f'~{schd.owner}/{schd.workflow}'}
+ ]
+ }
}
diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py
index 07485d2dcf1..0303aeb8e51 100644
--- a/tests/integration/test_client.py
+++ b/tests/integration/test_client.py
@@ -37,7 +37,7 @@ async def test_graphql(harness):
'graphql',
{'request_string': 'query { workflows { id } }'}
)
- workflows = ret['workflows']
+ workflows = ret['data']['workflows']
assert len(workflows) == 1
workflow = workflows[0]
assert schd.workflow in workflow['id']
diff --git a/tests/integration/test_graphql.py b/tests/integration/test_graphql.py
index 81d0c806069..43cb53cfcab 100644
--- a/tests/integration/test_graphql.py
+++ b/tests/integration/test_graphql.py
@@ -116,11 +116,11 @@ async def test_workflows(harness):
{'request_string': 'query { workflows { id } }'}
)
assert ret == {
- 'workflows': [
- {
- 'id': f'{w_tokens}'
- }
- ]
+ 'data': {
+ 'workflows': [
+ {'id': f'{w_tokens}'}
+ ]
+ }
}
@@ -132,6 +132,7 @@ async def test_tasks(harness):
'graphql',
{'request_string': 'query { tasks { id } }'}
)
+ ret = ret['data']
ids = [
w_tokens.duplicate(cycle=f'$namespace|{namespace}').id
for namespace in ('a', 'b', 'c', 'd')
@@ -150,7 +151,7 @@ async def test_tasks(harness):
'graphql',
{'request_string': 'query { task(id: "%s") { id } }' % id_}
)
- assert ret == {
+ assert ret['data'] == {
'task': {'id': id_}
}
@@ -163,6 +164,7 @@ async def test_families(harness):
'graphql',
{'request_string': 'query { families { id } }'}
)
+ ret = ret['data']
ids = [
w_tokens.duplicate(
cycle=f'$namespace|{namespace}'
@@ -183,7 +185,7 @@ async def test_families(harness):
'graphql',
{'request_string': 'query { family(id: "%s") { id } }' % id_}
)
- assert ret == {
+ assert ret['data'] == {
'family': {'id': id_}
}
@@ -196,6 +198,7 @@ async def test_task_proxies(harness):
'graphql',
{'request_string': 'query { taskProxies { id } }'}
)
+ ret = ret['data']
ids = [
w_tokens.duplicate(
cycle='1',
@@ -217,7 +220,7 @@ async def test_task_proxies(harness):
'graphql',
{'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0]}
)
- assert ret == {
+ assert ret['data'] == {
'taskProxy': {'id': ids[0]}
}
@@ -230,6 +233,7 @@ async def test_family_proxies(harness):
'graphql',
{'request_string': 'query { familyProxies { id } }'}
)
+ ret = ret['data']
ids = [
w_tokens.duplicate(
cycle='1',
@@ -252,7 +256,7 @@ async def test_family_proxies(harness):
'graphql',
{'request_string': 'query { familyProxy(id: "%s") { id } }' % id_}
)
- assert ret == {
+ assert ret['data'] == {
'familyProxy': {'id': id_}
}
@@ -288,7 +292,7 @@ async def test_edges(harness):
'graphql',
{'request_string': 'query { edges { id } }'}
)
- assert ret == {
+ assert ret['data'] == {
'edges': [
{'id': id_}
for id_ in e_ids
@@ -300,6 +304,7 @@ async def test_edges(harness):
'graphql',
{'request_string': 'query { nodesEdges { nodes {id}\nedges {id} } }'}
)
+ ret = ret['data']
ret['nodesEdges']['nodes'].sort(key=lambda x: x['id'])
ret['nodesEdges']['edges'].sort(key=lambda x: x['id'])
assert ret == {
@@ -334,7 +339,7 @@ async def test_jobs(harness):
'graphql',
{'request_string': 'query { jobs { id } }'}
)
- assert ret == {
+ assert ret['data'] == {
'jobs': [
{'id': f'{j_id}'}
]
@@ -345,6 +350,6 @@ async def test_jobs(harness):
'graphql',
{'request_string': 'query { job(id: "%s") { id } }' % j_id}
)
- assert ret == {
+ assert ret['data'] == {
'job': {'id': f'{j_id}'}
}
diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py
index ce0b53fdaa8..8ae873d6fae 100644
--- a/tests/integration/test_replier.py
+++ b/tests/integration/test_replier.py
@@ -14,11 +14,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
-from async_timeout import timeout
-from cylc.flow.network import decode_
+import json
+from cylc.flow.network import ResponseTuple
from cylc.flow.network.client import WorkflowRuntimeClient
import asyncio
+from async_timeout import timeout
import pytest
@@ -28,7 +29,9 @@ async def test_listener(one, start, ):
client = WorkflowRuntimeClient(one.workflow)
client.socket.send_string(r'Not JSON')
res = await client.socket.recv()
- assert 'error' in decode_(res.decode())
+ response = ResponseTuple(*json.loads(res.decode()))
+ assert response.content is None
+ assert 'failed to decode message' in response.err[0]
one.server.replier.queue.put('STOP')
async with timeout(2):
diff --git a/tests/integration/test_resolvers.py b/tests/integration/test_resolvers.py
index 6301861b058..d58c6b89212 100644
--- a/tests/integration/test_resolvers.py
+++ b/tests/integration/test_resolvers.py
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging
-from typing import AsyncGenerator, Callable
+from typing import Callable
from unittest.mock import Mock
import pytest
@@ -54,7 +54,7 @@ async def mock_flow(
mod_flow: Callable[..., str],
mod_scheduler: Callable[..., Scheduler],
mod_start,
-) -> AsyncGenerator[Scheduler, None]:
+):
ret = Mock()
ret.reg = mod_flow({
'scheduler': {
@@ -202,14 +202,15 @@ async def test_mutator(mock_flow, flow_args):
})
args = {}
meta = {}
- response = await mock_flow.resolvers.mutator(
+ resolvers: Resolvers = mock_flow.resolvers
+ response = await resolvers.mutator(
None,
'pause',
flow_args,
args,
meta
)
- assert response[0]['id'] == mock_flow.id
+ assert response[0].workflowId == mock_flow.id
async def test_mutation_mapper(mock_flow):
diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py
index c6625b37c77..f5730c97f84 100644
--- a/tests/integration/test_server.py
+++ b/tests/integration/test_server.py
@@ -22,6 +22,8 @@
from cylc.flow.network.server import PB_METHOD_MAP
+from cylc.flow.scheduler import Scheduler
+
@pytest.fixture(scope='module')
async def myflow(mod_flow, mod_scheduler, mod_run, mod_one_conf):
@@ -50,7 +52,7 @@ def test_graphql(myflow):
}}
}}
'''
- data = call_server_method(myflow.server.graphql, request_string)
+ data = call_server_method(myflow.server.graphql, request_string)['data']
assert myflow.id == data['workflows'][0]['id']
@@ -106,32 +108,29 @@ async def test_operate(one, start):
one.server.operate()
-async def test_receiver(one, start):
+async def test_receiver(one: Scheduler, start):
"""Test the receiver with different message objects."""
+ user = 'troi'
async with timeout(5):
async with start(one):
# start with a message that works
- msg = {'command': 'api', 'user': '', 'args': {}}
- assert 'error' not in one.server.receiver(msg)
- assert 'data' in one.server.receiver(msg)
-
- # remove the user field - should error
- msg2 = dict(msg)
- msg2.pop('user')
- assert 'error' in one.server.receiver(msg2)
+ msg = {'command': 'api', 'args': {}}
+ response = one.server.receiver(msg, user)
+ assert response.err is None
+ assert response.content is not None
# remove the command field - should error
msg3 = dict(msg)
msg3.pop('command')
- assert 'error' in one.server.receiver(msg3)
+ assert one.server.receiver(msg3, user).err is not None
# provide an invalid command - should error
msg4 = {**msg, 'command': 'foobar'}
- assert 'error' in one.server.receiver(msg4)
+ assert one.server.receiver(msg4, user).err is not None
# simulate a command failure with the original message
# (the one which worked earlier) - should error
def _api(*args, **kwargs):
raise Exception('foo')
one.server.api = _api
- assert 'error' in one.server.receiver(msg)
+ assert one.server.receiver(msg, user).err is not None