Skip to content

Commit

Permalink
Refactor client/server to better handle GraphQL
Browse files Browse the repository at this point in the history
  • Loading branch information
MetRonnie committed Aug 10, 2022
1 parent eaf7d21 commit 53a04bc
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 241 deletions.
2 changes: 1 addition & 1 deletion cylc/flow/async_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def _chain(self, item, coros, completed):
ret = await coro.func(item, *coro.args, **coro.kwargs)
except Exception as exc:
# if something goes wrong log the error and skip the item
LOG.warning(exc)
LOG.warning(f"{type(exc).__name__}: {exc}")
ret = False
if ret is True:
# filter passed -> continue
Expand Down
29 changes: 11 additions & 18 deletions cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
"""Package for network interfaces to Cylc scheduler objects."""

import asyncio
import getpass
import json
from typing import Any, Dict
from typing import NamedTuple, Optional

import zmq
import zmq.asyncio
Expand All @@ -45,21 +43,16 @@
MSG_TIMEOUT = "TIMEOUT"


def encode_(message: object) -> str:
"""Convert the structure holding a message field from JSON to a string."""
try:
return json.dumps(message)
except TypeError as exc:
return json.dumps({'errors': [{'message': str(exc)}]})


def decode_(message: str) -> Dict[str, Any]:
"""Convert an encoded message string to JSON with an added 'user' field."""
msg: object = json.loads(message)
if not isinstance(msg, dict):
raise ValueError(f"Expected message to be dict but got {type(msg)}")
msg['user'] = getpass.getuser() # assume this is the user
return msg
class ResponseTuple(NamedTuple):
"""Structure of server response messages."""
content: Optional[object] = None
err: Optional['ResponseErrTuple'] = None
user: Optional[str] = None


class ResponseErrTuple(NamedTuple):
message: str
traceback: Optional[str] = None


def get_location(workflow: str):
Expand Down
40 changes: 18 additions & 22 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Client for workflow runtime API."""

from functools import partial
import json
import os
from shutil import which
import socket
Expand All @@ -34,9 +35,8 @@
WorkflowStopped,
)
from cylc.flow.network import (
encode_,
decode_,
get_location,
ResponseTuple,
ZMQSocketBase
)
from cylc.flow.network.client_factory import CommsMeth
Expand Down Expand Up @@ -165,7 +165,7 @@ async def async_request(
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
) -> Union[bytes, object]:
"""Send an asynchronous request using asyncio.
Has the same arguments and return values as ``serial_request``.
Expand All @@ -187,12 +187,12 @@ async def async_request(
if req_meta:
msg['meta'].update(req_meta)
LOG.debug('zmq:send %s', msg)
message = encode_(msg)
message = json.dumps(msg)
self.socket.send_string(message)

# receive response
if self.poller.poll(timeout):
res = await self.socket.recv()
res: bytes = await self.socket.recv()
else:
if callable(self.timeout_handler):
self.timeout_handler()
Expand All @@ -204,32 +204,28 @@ async def async_request(
' This could be due to network or server issues.'
' Check the workflow log.'
)
LOG.debug('zmq:recv %s', res)

if msg['command'] in PB_METHOD_MAP:
response = {'data': res}
else:
response = decode_(res.decode())
LOG.debug('zmq:recv %s', response)
if command in PB_METHOD_MAP:
return res

try:
return response['data']
except KeyError:
error = response.get(
'error',
{'message': f'Received invalid response: {response}'},
)
raise ClientError(
error.get('message'),
error.get('traceback'),
)
response = ResponseTuple( # type: ignore[misc]
*json.loads(res.decode())
)

if response.content is not None:
return response.content
if response.err:
raise ClientError(*response.err)
raise ClientError(f"Received invalid response: {response}")

def serial_request(
self,
command: str,
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
) -> Union[bytes, object]:
"""Send a request.
For convenience use ``__call__`` to call this method.
Expand Down
14 changes: 11 additions & 3 deletions cylc/flow/network/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@

from functools import partial
import logging
from typing import TYPE_CHECKING, Any, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union

from inspect import isclass, iscoroutinefunction

from graphene.utils.str_converters import to_snake_case
from graphql.execution.utils import (
get_operation_root_type, get_field_def
)
from graphql.execution import ExecutionResult
from graphql.execution.values import get_argument_values, get_variable_values
from graphql.language.base import parse, print_ast
from graphql.language import ast
Expand All @@ -42,7 +43,6 @@
from cylc.flow.network.schema import NODE_MAP, get_type_str

if TYPE_CHECKING:
from graphql.execution import ExecutionResult
from graphql.language.ast import Document
from graphql.type import GraphQLSchema

Expand Down Expand Up @@ -146,6 +146,14 @@ def null_stripper(exe_result):
return exe_result


def format_execution_result(
result: Union[ExecutionResult, Dict[str, Any]]
) -> Dict[str, Any]:
if isinstance(result, ExecutionResult):
result = result.to_dict()
return strip_null(result)


class AstDocArguments:
"""Request doc Argument inspection."""

Expand Down Expand Up @@ -254,7 +262,7 @@ def execute_and_validate_and_strip(
document_ast: 'Document',
*args: Any,
**kwargs: Any
) -> Union['ExecutionResult', Observable]:
) -> Union[ExecutionResult, Observable]:
"""Wrapper around graphql ``execute_and_validate()`` that adds
null stripping."""
result = execute_and_validate(schema, document_ast, *args, **kwargs)
Expand Down
60 changes: 42 additions & 18 deletions cylc/flow/network/replier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""

import getpass # noqa: F401
import getpass
import json
from queue import Queue
from typing import TYPE_CHECKING
from typing_extensions import Literal

import zmq

from cylc.flow import LOG
from cylc.flow.network import encode_, decode_, ZMQSocketBase
from cylc.flow.network import (
ResponseErrTuple, ResponseTuple, ZMQSocketBase
)

if TYPE_CHECKING:
from cylc.flow.network.server import WorkflowRuntimeServer


class WorkflowReplier(ZMQSocketBase):
Expand All @@ -46,11 +54,15 @@ class WorkflowReplier(ZMQSocketBase):
"""

def __init__(self, server, context=None):
def __init__(
self,
server: 'WorkflowRuntimeServer',
context=None
):
super().__init__(zmq.REP, bind=True, context=context)
self.server = server
self.workflow = server.schd.workflow
self.queue = Queue()
self.queue: Queue[Literal['STOP']] = Queue()

def _bespoke_stop(self) -> None:
"""Stop the listener and Authenticator.
Expand Down Expand Up @@ -92,27 +104,39 @@ def listener(self) -> None:
continue
# attempt to decode the message, authenticating the user in the
# process
response: bytes
try:
message = decode_(msg)
message = json.loads(msg)
user = getpass.getuser() # assume this is the user
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
LOG.exception('failed to decode message: "%s"', exc)
LOG.exception(exc)
LOG.error(f'failed to decode message: "{msg}"')
import traceback
response = encode_(
{
'error': {
'message': 'failed to decode message: "%s"' % msg,
'traceback': traceback.format_exc(),
}
}
response = json.dumps(
ResponseTuple(
err=ResponseErrTuple(
f'failed to decode message: {msg}"',
traceback.format_exc(),
)
)
).encode()
else:
# success case - serve the request
res = self.server.receiver(message)
# send back the string to bytes response
if isinstance(res.get('data'), bytes):
response = res['data']
res = self.server.receiver(message, user)
if isinstance(res.content, bytes): # is protobuf method
# just return bytes, as cannot serialize bytes to JSON
response = res.content
else:
response = encode_(res).encode()
try:
response = json.dumps(res).encode()
except TypeError as exc:
err_msg = f"failed to encode response: {res}\n{exc}"
LOG.warning(err_msg)
res = ResponseTuple(
err=ResponseErrTuple(err_msg)
)
response = json.dumps(res).encode()
# send back the string to bytes response
self.socket.send(response)
14 changes: 9 additions & 5 deletions cylc/flow/network/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from cylc.flow.id import Tokens
from cylc.flow.network.schema import (
DEF_TYPES,
GenericResponse,
NodesEdges,
PROXY_NODES,
SUB_RESOLVERS,
Expand Down Expand Up @@ -629,7 +630,7 @@ async def mutator(
w_args: Dict[str, Any],
kwargs: Dict[str, Any],
meta: Dict[str, Any]
) -> List[Dict[str, Any]]:
) -> List[GenericResponse]:
...


Expand All @@ -650,19 +651,22 @@ async def mutator(
w_args: Dict[str, Any],
kwargs: Dict[str, Any],
meta: Dict[str, Any]
) -> List[Dict[str, Any]]:
) -> List[GenericResponse]:
"""Mutate workflow."""
w_ids = [flow[WORKFLOW].id
for flow in await self.get_workflows_data(w_args)]
if not w_ids:
workflows = list(self.data_store_mgr.data.keys())
return [{
'response': (False, f'No matching workflow in {workflows}')}]
ret = GenericResponse(
success=False, message=f'No matching workflow in {workflows}'
)
return [ret]
w_id = w_ids[0]
result = await self._mutation_mapper(command, kwargs, meta)
if result is None:
result = (True, 'Command queued')
return [{'id': w_id, 'response': result}]
ret = GenericResponse(w_id, *result)
return [ret]

async def _mutation_mapper(
self, command: str, kwargs: Dict[str, Any], meta: Dict[str, Any]
Expand Down
Loading

0 comments on commit 53a04bc

Please sign in to comment.