Skip to content

Commit

Permalink
feat: python hostcall (#860)
Browse files Browse the repository at this point in the history
Dead lock on python worker

<!--
Pull requests are squashed and merged using:
- their title as the commit message
- their description as the commit body

Having a good title and description is important for the users to get
readable changelog.
-->

<!-- 1. Explain WHAT the change is about -->

-

<!-- 2. Explain WHY the change cannot be made simpler -->

-

<!-- 3. Explain HOW users should update their code -->

#### Migration notes

...

- [ ] The change comes with new or modified tests
- [ ] Hard-to-understand functions have explanatory comments
- [ ] End-user documentation is updated to reflect the change
  • Loading branch information
j03-dev authored Oct 10, 2024
1 parent daec029 commit 9aefe89
Show file tree
Hide file tree
Showing 15 changed files with 508 additions and 232 deletions.
2 changes: 1 addition & 1 deletion src/metagen/src/fdk_python/static/main.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from .{{ mod_name }}_types import {{ imports }}

{% for func in funcs %}
@typed_{{ func.name }}
def {{ func.name }}(inp: {{ func.input_name }}) -> {{ func.output_name }}:
def {{ func.name }}(inp: {{ func.input_name }}, ctx: Ctx) -> {{ func.output_name }}:
# TODO: write your logic here
raise Exception("{{ func.name }} not implemented")
{% endfor %}
11 changes: 8 additions & 3 deletions src/metagen/src/fdk_python/static/types.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ from dataclasses import dataclass, asdict, fields

FORWARD_REFS = {}


class Ctx:
def gql(self, query: str, variables: str) -> Any:
pass

class Struct:
def repr(self):
return asdict(self)
Expand Down Expand Up @@ -93,10 +98,10 @@ def __repr(value: Any):


{%for func in funcs %}
def typed_{{ func.name }}(user_fn: Callable[[{{ func.input_name }}], {{ func.output_name }}]):
def exported_wrapper(raw_inp):
def typed_{{ func.name }}(user_fn: Callable[[{{ func.input_name }}, Ctx], {{ func.output_name }}]):
def exported_wrapper(raw_inp, ctx):
inp: {{ func.input_name }} = Struct.new({{ func.input_name }}, raw_inp)
out: {{ func.output_name }} = user_fn(inp)
out: {{ func.output_name }} = user_fn(inp, ctx)
if isinstance(out, list):
return [__repr(v) for v in out]
return __repr(out)
Expand Down
68 changes: 43 additions & 25 deletions src/pyrt_wit_wire/main.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
import wit_wire.exports
import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import json
import os
import inspect
import sys
import traceback
import types
from typing import Any, Callable, Dict, TypeVar

# NOTE: all imports must be toplevel as constrained by `componentize-py`
# https://github.com/bytecodealliance/componentize-py/issues/23
# from pyrt.imports.typegate_wire import hostcall
import wit_wire.exports
from wit_wire.exports.mat_wire import (
Err,
HandleErr_HandlerErr,
HandleErr_InJsonErr,
HandleErr_NoHandler,
HandleReq,
InitArgs,
InitResponse,
InitError_UnexpectedMat,
InitError_Other,
InitError_UnexpectedMat,
InitResponse,
MatInfo,
HandleReq,
HandleErr_NoHandler,
HandleErr_InJsonErr,
HandleErr_HandlerErr,
Err,
)

import json
import types
from typing import Callable, Any, Dict
import importlib
import importlib.util
import importlib.abc
import importlib.machinery
import os
import sys
import traceback
# NOTE: all imports must be toplevel as constrained by `componentize-py`
# https://github.com/bytecodealliance/componentize-py/issues/23
# from pyrt.imports.typegate_wire import hostcall
from wit_wire.imports.typegate_wire import hostcall

# the `MatWire` class is instantiated for each
# external call. We have to put any persisted
# state here.
handlers = {}


T = TypeVar("T")
HandlerFn = Callable[..., T]


class Ctx:
def gql(self, query: str, variables: str) -> Any:
data = json.loads(
hostcall("gql", json=json.dumps({"query": query, "variables": variables}))
)
return data["data"]


class MatWire(wit_wire.exports.MatWire):
def init(self, args: InitArgs):
for op in args.expected_ops:
Expand Down Expand Up @@ -64,12 +78,16 @@ def handle(self, req: HandleReq):


class ErasedHandler:
def __init__(self, handler_fn: Callable[[Any], Any]) -> None:
def __init__(self, handler_fn: HandlerFn[T]) -> None:
self.handler_fn = handler_fn
self.param_count = len(inspect.signature(self.handler_fn).parameters)

def handle(self, req: HandleReq):
in_parsed = json.loads(req.in_json)
out = self.handler_fn(in_parsed)
if self.param_count == 1:
out = self.handler_fn(in_parsed)
else:
out = self.handler_fn(in_parsed, Ctx())
return json.dumps(out)


Expand All @@ -79,7 +97,7 @@ def op_to_handler(op: MatInfo) -> ErasedHandler:
module = types.ModuleType(op.op_name)
exec(data_parsed["source"], module.__dict__)
fn = module.__dict__[data_parsed["func_name"]]
return ErasedHandler(handler_fn=lambda inp: fn(inp))
return ErasedHandler(handler_fn=fn)
elif data_parsed["ty"] == "import_function":
prefix = data_parsed["func_name"]

Expand All @@ -101,7 +119,7 @@ def op_to_handler(op: MatInfo) -> ErasedHandler:
return ErasedHandler(handler_fn=getattr(module, data_parsed["func_name"]))
elif data_parsed["ty"] == "lambda":
fn = eval(data_parsed["source"])
return ErasedHandler(handler_fn=lambda inp: fn(inp))
return ErasedHandler(handler_fn=fn)
else:
raise Err(InitError_UnexpectedMat(op))

Expand Down
24 changes: 22 additions & 2 deletions src/typegate/src/runtimes/wit_wire/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export class WitWireMessenger {

async handle(opName: string, args: ResolverArgs) {
const { _, ...inJson } = args;

let res;
try {
res = await Meta.wit_wire.handle(this.id, {
Expand Down Expand Up @@ -173,8 +174,12 @@ async function hostcall(cx: HostCallCtx, op_name: string, json: string) {
async function gql(cx: HostCallCtx, args: object) {
const argsValidator = zod.object({
query: zod.string(),
variables: zod.record(zod.string(), zod.unknown()),
variables: zod.union([
zod.string(),
zod.record(zod.string(), zod.unknown()),
]),
});

const parseRes = argsValidator.safeParse(args);
if (!parseRes.success) {
throw new Error("error validating gql args", {
Expand All @@ -184,6 +189,19 @@ async function gql(cx: HostCallCtx, args: object) {
});
}
const parsed = parseRes.data;

// Convert variables to an object if it's a string
let variables = parsed.variables;
if (typeof variables === "string") {
try {
variables = JSON.parse(variables);
} catch (error) {
throw new Error("Failed to parse variables string as JSON", {
cause: error,
});
}
}

const request = new Request(cx.typegraphUrl, {
method: "POST",
headers: {
Expand All @@ -193,15 +211,17 @@ async function gql(cx: HostCallCtx, args: object) {
},
body: JSON.stringify({
query: parsed.query,
variables: parsed.variables,
variables: variables,
}),
});

//TODO: make `handle` more friendly to internal requests
const res = await cx.typegate.handle(request, {
port: 0,
hostname: "internal",
transport: "tcp",
});

if (!res.ok) {
const text = await res.text();
throw new Error(`gql fetch on ${cx.typegraphUrl} failed: ${text}`, {
Expand Down
11 changes: 10 additions & 1 deletion tests/internal/internal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typegraph.graph.typegraph import Graph
from typegraph.policy import Policy
from typegraph.runtimes.deno import DenoRuntime
from typegraph.runtimes.python import PythonRuntime

from typegraph import t, typegraph

Expand All @@ -11,6 +12,7 @@ def internal(g: Graph):
internal = Policy.internal()

deno = DenoRuntime()
python = PythonRuntime()

inp = t.struct({"first": t.float(), "second": t.float()})
out = t.float()
Expand All @@ -19,7 +21,14 @@ def internal(g: Graph):
sum=deno.import_(inp, out, module="ts/logic.ts", name="sum").with_policy(
internal
),
remoteSum=deno.import_(
remoteSumDeno=deno.import_(
inp, out, module="ts/logic.ts", name="remoteSum"
).with_policy(public),
remoteSumPy=python.import_(
inp,
out,
module="py/logic.py",
name="remote_sum",
deps=["./py/logic_types.py"],
).with_policy(public),
)
38 changes: 36 additions & 2 deletions tests/internal/internal_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
// SPDX-License-Identifier: Elastic-2.0

import { gql, Meta } from "../utils/mod.ts";
import { join } from "@std/path/join";
import { assertEquals } from "@std/assert";

Meta.test({
name: "client table suite",
}, async (_) => {
const scriptsPath = join(import.meta.dirname!, ".");

assertEquals(
(
await Meta.cli(
{
env: {
// RUST_BACKTRACE: "1",
},
},
...`-C ${scriptsPath} gen`.split(" "),
)
).code,
0,
);
});

Meta.test(
{
Expand All @@ -13,11 +35,23 @@ Meta.test(
await t.should("work on the default worker", async () => {
await gql`
query {
remoteSum(first: 1.2, second: 2.3)
remoteSumDeno(first: 1.2, second: 2.3)
}
`
.expectData({
remoteSumDeno: 3.5,
})
.on(e, `http://localhost:${t.port}`);
});

await t.should("hostcall python work", async () => {
await gql`
query {
remoteSumPy(first: 1.2, second: 2.3)
}
`
.expectData({
remoteSum: 3.5,
remoteSumPy: 3.5,
})
.on(e, `http://localhost:${t.port}`);
});
Expand Down
12 changes: 12 additions & 0 deletions tests/internal/metatype.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
typegates:
dev:
url: "http://localhost:7890"
username: admin
password: password

metagen:
targets:
main:
- generator: fdk_python
path: ./py/
typegraph_path: internal.py
18 changes: 18 additions & 0 deletions tests/internal/py/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Metatype OÜ, licensed under the Elastic License 2.0.
# SPDX-License-Identifier: Elastic-2.0

from .logic_types import Ctx
import json


def remote_sum(inp: dict, ctx: Ctx) -> float:
data = ctx.gql(
query="""
query q($first: Float!, $second: Float!) {
sum(first: $first, second: $second)
}
""",
variables=json.dumps(inp),
)
sum = data["sum"]
return sum
Loading

0 comments on commit 9aefe89

Please sign in to comment.