Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor client/server & GraphQL workflow mutations #4529

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cylc/flow/async_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def _chain(self, item, coros, completed):
ret = await coro.func(item, *coro.args, **coro.kwargs)
except Exception as exc:
# if something goes wrong log the error and skip the item
LOG.warning(exc)
LOG.warning(f"{type(exc).__name__}: {exc}")
ret = False
if ret is True:
# filter passed -> continue
Expand Down
16 changes: 14 additions & 2 deletions cylc/flow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,15 @@ def __str__(self):

class ClientError(CylcError):

def __init__(self, message: str, traceback: Optional[str] = None):
def __init__(
self,
message: str,
traceback: Optional[str] = None,
workflow: Optional[str] = None
):
self.message = message
self.traceback = traceback
self.workflow = workflow

def __str__(self) -> str:
ret = self.message
Expand All @@ -277,7 +283,13 @@ def __str__(self):


class ClientTimeout(CylcError):
pass

def __init__(self, message: str, workflow: Optional[str] = None):
self.message = message
self.workflow = workflow

def __str__(self) -> str:
return self.message


class CyclingError(CylcError):
Expand Down
22 changes: 9 additions & 13 deletions cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
"""Package for network interfaces to Cylc scheduler objects."""

import asyncio
import getpass
import json
from typing import NamedTuple, Optional

import zmq
import zmq.asyncio
Expand All @@ -44,19 +43,16 @@
MSG_TIMEOUT = "TIMEOUT"


def encode_(message):
"""Convert the structure holding a message field from JSON to a string."""
try:
return json.dumps(message)
except TypeError as exc:
return json.dumps({'errors': [{'message': str(exc)}]})
class ResponseTuple(NamedTuple):
"""Structure of server response messages."""
content: Optional[object] = None
err: Optional['ResponseErrTuple'] = None
user: Optional[str] = None


def decode_(message):
"""Convert an encoded message string to JSON with an added 'user' field."""
msg = json.loads(message)
msg['user'] = getpass.getuser() # assume this is the user
return msg
class ResponseErrTuple(NamedTuple):
message: str
traceback: Optional[str] = None


def get_location(workflow: str):
Expand Down
40 changes: 18 additions & 22 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Client for workflow runtime API."""

from functools import partial
import json
import os
from shutil import which
import socket
Expand All @@ -34,9 +35,8 @@
WorkflowStopped,
)
from cylc.flow.network import (
encode_,
decode_,
get_location,
ResponseTuple,
ZMQSocketBase
)
from cylc.flow.network.client_factory import CommsMeth
Expand Down Expand Up @@ -165,7 +165,7 @@ async def async_request(
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
) -> Union[bytes, object]:
"""Send an asynchronous request using asyncio.

Has the same arguments and return values as ``serial_request``.
Expand All @@ -187,12 +187,12 @@ async def async_request(
if req_meta:
msg['meta'].update(req_meta)
LOG.debug('zmq:send %s', msg)
message = encode_(msg)
message = json.dumps(msg)
self.socket.send_string(message)

# receive response
if self.poller.poll(timeout):
res = await self.socket.recv()
res: bytes = await self.socket.recv()
else:
if callable(self.timeout_handler):
self.timeout_handler()
Expand All @@ -204,32 +204,28 @@ async def async_request(
' This could be due to network or server issues.'
' Check the workflow log.'
)
LOG.debug('zmq:recv %s', res)

if msg['command'] in PB_METHOD_MAP:
response = {'data': res}
else:
response = decode_(res.decode())
LOG.debug('zmq:recv %s', response)
if command in PB_METHOD_MAP:
return res

try:
return response['data']
except KeyError:
error = response.get(
'error',
{'message': f'Received invalid response: {response}'},
)
raise ClientError(
error.get('message'),
error.get('traceback'),
)
response = ResponseTuple( # type: ignore[misc]
*json.loads(res.decode())
)

if response.content is not None:
return response.content
if response.err:
raise ClientError(*response.err)
raise ClientError("No response from server. Check the workflow log.")

def serial_request(
self,
command: str,
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
) -> Union[bytes, object]:
"""Send a request.

For convenience use ``__call__`` to call this method.
Expand Down
14 changes: 11 additions & 3 deletions cylc/flow/network/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@

from functools import partial
import logging
from typing import TYPE_CHECKING, Any, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union

from inspect import isclass, iscoroutinefunction

from graphene.utils.str_converters import to_snake_case
from graphql.execution.utils import (
get_operation_root_type, get_field_def
)
from graphql.execution import ExecutionResult
from graphql.execution.values import get_argument_values, get_variable_values
from graphql.language.base import parse, print_ast
from graphql.language import ast
Expand All @@ -42,7 +43,6 @@
from cylc.flow.network.schema import NODE_MAP, get_type_str

if TYPE_CHECKING:
from graphql.execution import ExecutionResult
from graphql.language.ast import Document
from graphql.type import GraphQLSchema

Expand Down Expand Up @@ -146,6 +146,14 @@ def null_stripper(exe_result):
return exe_result


def format_execution_result(
result: Union[ExecutionResult, Dict[str, Any]]
) -> Dict[str, Any]:
if isinstance(result, ExecutionResult):
result = result.to_dict()
return strip_null(result)


class AstDocArguments:
"""Request doc Argument inspection."""

Expand Down Expand Up @@ -254,7 +262,7 @@ def execute_and_validate_and_strip(
document_ast: 'Document',
*args: Any,
**kwargs: Any
) -> Union['ExecutionResult', Observable]:
) -> Union[ExecutionResult, Observable]:
"""Wrapper around graphql ``execute_and_validate()`` that adds
null stripping."""
result = execute_and_validate(schema, document_ast, *args, **kwargs)
Expand Down
64 changes: 44 additions & 20 deletions cylc/flow/network/replier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""

import getpass # noqa: F401
import getpass
import json
from queue import Queue
from typing import TYPE_CHECKING
from typing_extensions import Literal

import zmq

from cylc.flow import LOG
from cylc.flow.network import encode_, decode_, ZMQSocketBase
from cylc.flow.network import (
ResponseErrTuple, ResponseTuple, ZMQSocketBase
)

if TYPE_CHECKING:
from cylc.flow.network.server import WorkflowRuntimeServer


class WorkflowReplier(ZMQSocketBase):
Expand All @@ -46,13 +54,17 @@ class WorkflowReplier(ZMQSocketBase):

"""

def __init__(self, server, context=None):
def __init__(
self,
server: 'WorkflowRuntimeServer',
context=None
):
super().__init__(zmq.REP, bind=True, context=context)
self.server = server
self.workflow = server.schd.workflow
self.queue = Queue()
self.queue: Queue[Literal['STOP']] = Queue()

def _bespoke_stop(self):
def _bespoke_stop(self) -> None:
"""Stop the listener and Authenticator.

Overwrites Base method.
Expand All @@ -62,7 +74,7 @@ def _bespoke_stop(self):
if self.queue is not None:
self.queue.put('STOP')

def listener(self):
def listener(self) -> None:
"""The server main loop, listen for and serve requests.

When called, this method will receive and respond until there are no
Expand Down Expand Up @@ -92,27 +104,39 @@ def listener(self):
continue
# attempt to decode the message, authenticating the user in the
# process
response: bytes
try:
message = decode_(msg)
message = json.loads(msg)
user = getpass.getuser() # assume this is the user
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
LOG.exception('failed to decode message: "%s"', exc)
LOG.exception(exc)
LOG.error(f'failed to decode message: "{msg}"')
import traceback
response = encode_(
{
'error': {
'message': 'failed to decode message: "%s"' % msg,
'traceback': traceback.format_exc(),
}
}
response = json.dumps(
ResponseTuple(
err=ResponseErrTuple(
f'failed to decode message: {msg}"',
traceback.format_exc(),
)
)
).encode()
else:
# success case - serve the request
res = self.server.receiver(message)
# send back the string to bytes response
if isinstance(res.get('data'), bytes):
response = res['data']
res = self.server.receiver(message, user)
if isinstance(res.content, bytes): # is protobuf method
# just return bytes, as cannot serialize bytes to JSON
response = res.content
else:
response = encode_(res).encode()
try:
response = json.dumps(res).encode()
except TypeError as exc:
err_msg = f"failed to encode response: {res}\n{exc}"
LOG.warning(err_msg)
res = ResponseTuple(
err=ResponseErrTuple(err_msg)
)
response = json.dumps(res).encode()
# send back the string to bytes response
self.socket.send(response)
Loading