From 0081436015d05da6875855ee017241fdf6c289d8 Mon Sep 17 00:00:00 2001 From: Luckas Date: Sun, 1 Dec 2024 07:13:12 +0300 Subject: [PATCH] feat(metagen): python client file upload (#931) - Closes [MET-769](https://linear.app/metatypedev/issue/MET-769/add-file-upload-support-for-python-metagen-client) #### Migration notes --- - [x] The change comes with new or modified tests - [ ] Hard-to-understand functions have explanatory comments - [ ] End-user documentation is updated to reflect the change --- src/metagen/src/client_py/mod.rs | 3 +- src/metagen/src/client_py/node_metas.rs | 33 +- src/metagen/src/client_py/static/client.py | 231 ++++- src/metagen/src/client_py/types.rs | 2 +- src/metagen/src/client_rs/node_metas.rs | 2 +- src/metagen/src/shared/files.rs | 66 +- tests/metagen/metagen_test.ts | 302 +++--- tests/metagen/typegraphs/sample/metatype.yml | 3 + tests/metagen/typegraphs/sample/py/client.py | 249 ++++- .../typegraphs/sample/py_upload/client.py | 894 ++++++++++++++++++ .../typegraphs/sample/py_upload/main.py | 39 + 11 files changed, 1612 insertions(+), 212 deletions(-) create mode 100644 tests/metagen/typegraphs/sample/py_upload/client.py create mode 100644 tests/metagen/typegraphs/sample/py_upload/main.py diff --git a/src/metagen/src/client_py/mod.rs b/src/metagen/src/client_py/mod.rs index 7dc401f5de..9f29a93a4a 100644 --- a/src/metagen/src/client_py/mod.rs +++ b/src/metagen/src/client_py/mod.rs @@ -210,7 +210,7 @@ class QueryGraph(QueryGraphBase): {{"{node_name}": NodeDescs.{meta_method}}}, "$q" )[0] - return {node_type}(node.node_name, node.instance_name, node.args, node.sub_nodes) + return {node_type}(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) "# )?; } @@ -280,6 +280,7 @@ fn render_node_metas( Rc::new(node_metas::PyNodeMetasRenderer { name_mapper, named_types: named_types.clone(), + input_files: manifest.input_files.clone(), }), ); for &id in &manifest.node_metas { diff --git a/src/metagen/src/client_py/node_metas.rs b/src/metagen/src/client_py/node_metas.rs index 4ba688ba8f..1c5d29e404 100644 --- a/src/metagen/src/client_py/node_metas.rs +++ b/src/metagen/src/client_py/node_metas.rs @@ -1,16 +1,20 @@ // Copyright Metatype OÜ, licensed under the Mozilla Public License Version 2.0. // SPDX-License-Identifier: MPL-2.0 -use std::fmt::Write; +use std::{collections::HashMap, fmt::Write, ops::Not}; use common::typegraph::*; use super::utils::normalize_type_title; -use crate::{interlude::*, shared::types::*}; +use crate::{ + interlude::*, + shared::{files::TypePath, types::*}, +}; pub struct PyNodeMetasRenderer { pub name_mapper: Rc, pub named_types: Rc>>, + pub input_files: Rc>>, } impl PyNodeMetasRenderer { @@ -52,6 +56,7 @@ impl PyNodeMetasRenderer { ty_name: &str, return_node: &str, argument_fields: Option>>, + input_files: Option, ) -> std::fmt::Result { write!( dest, @@ -84,6 +89,13 @@ impl PyNodeMetasRenderer { }},"# )?; } + if let Some(input_files) = input_files { + write!( + dest, + r#" + input_files={input_files},"# + )?; + } write!( dest, r#" @@ -172,7 +184,22 @@ impl RenderType for PyNodeMetasRenderer { }; let node_name = &base.title; let ty_name = normalize_type_title(node_name).to_pascal_case(); - self.render_for_func(renderer, &ty_name, &return_ty_name, props)?; + let input_files = self + .input_files + .get(&cursor.id) + .map(|files| { + files + .iter() + .map(|path| path.to_vec_str()) + .collect::>() + }) + .and_then(|files| { + files + .is_empty() + .not() + .then_some(serde_json::to_string(&files).unwrap()) + }); + self.render_for_func(renderer, &ty_name, &return_ty_name, props, input_files)?; ty_name } TypeNode::Object { data, base } => { diff --git a/src/metagen/src/client_py/static/client.py b/src/metagen/src/client_py/static/client.py index d943e26f9f..d77766eb43 100644 --- a/src/metagen/src/client_py/static/client.py +++ b/src/metagen/src/client_py/static/client.py @@ -1,9 +1,14 @@ +import io +import re +import uuid import typing import dataclasses as dc import json +import urllib import urllib.request as request import urllib.error import http.client as http_c +import mimetypes def selection_to_nodes( @@ -148,6 +153,7 @@ def selection_to_nodes( instance_name="__typename", args=None, sub_nodes=None, + files=None, ) ) @@ -161,7 +167,9 @@ def selection_to_nodes( ) sub_nodes = union_selections - node = SelectNode(node_name, instance_name, instance_args, sub_nodes) + node = SelectNode( + node_name, instance_name, instance_args, sub_nodes, meta.input_files + ) out.append(node) found_nodes.discard("_") @@ -184,6 +192,13 @@ def selection_to_nodes( SelectionT = typing.TypeVar("SelectionT") +@dc.dataclass +class File: + content: bytes + name: str + mimetype: typing.Optional[str] = None + + # # --- --- Graph node types --- --- # # @@ -197,6 +212,9 @@ def selection_to_nodes( typing.Dict[str, typing.List["SelectNode"]], ] +TypePath = typing.List[typing.Union[typing.Literal["?"], typing.Literal["[]"], str]] +ValuePath = typing.List[typing.Union[typing.Literal[""], str]] + @dc.dataclass class SelectNode(typing.Generic[Out]): @@ -204,6 +222,7 @@ class SelectNode(typing.Generic[Out]): instance_name: str args: typing.Optional["NodeArgs"] sub_nodes: SubNodes + files: typing.Optional[typing.List[TypePath]] @dc.dataclass @@ -224,6 +243,92 @@ class NodeMeta: sub_nodes: typing.Optional[typing.Dict[str, NodeMetaFn]] = None variants: typing.Optional[typing.Dict[str, NodeMetaFn]] = None arg_types: typing.Optional[typing.Dict[str, str]] = None + input_files: typing.Optional[typing.List[TypePath]] = None + + +class FileExtractor: + def __init__(self): + self.path: TypePath = [] + self.current_path: ValuePath = [] + self.result: typing.Dict[str, File] = {} + + def extract_from_value(self, value: typing.Any): + next_segment = self.path[len(self.current_path)] + + if next_segment == "?": + if value is None: + return + self.current_path.append("") + self.extract_from_value(value) + self.current_path.pop() + return + + if next_segment == "[]": + if not isinstance(value, list): + raise Exception(f"Expected array at {self.format_path()}") + + for idx in range(len(value)): + self.current_path.append(f"[{idx}]") + self.extract_from_array(value, idx) + self.current_path.pop() + return + + if next_segment.startswith("."): + if not isinstance(value, dict): + raise Exception(f"Expected dictionary at {self.format_path()}") + + self.current_path.append(next_segment) + self.extract_from_object(value, next_segment[1:]) + self.current_path.pop() + return + + def extract_from_array(self, parent: typing.List[typing.Any], idx: int): + value = parent[idx] + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[idx] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def extract_from_object(self, parent: typing.Dict[str, typing.Any], key: str): + value = parent.get(key) + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[key] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def format_path(self): + res = "" + + for path in self.current_path: + res += f".{path[1:-1]}" if path.startswith("[") else path + + return res + + +def extract_files( + key: str, obj: typing.Dict[str, typing.Any], paths: typing.List[TypePath] +): + extractor = FileExtractor() + + for path in paths: + if path[0] and path[0].startswith("." + key): + extractor.current_path = [] + extractor.path = path + extractor.extract_from_value(obj) + + return extractor.result # @@ -343,6 +448,7 @@ def convert_query_node_gql( ty_to_gql_ty_map: typing.Dict[str, str], node: SelectNode, variables: typing.Dict[str, NodeArgValue], + files: typing.Dict[str, File], ): out = ( f"{node.instance_name}: {node.node_name}" @@ -353,6 +459,16 @@ def convert_query_node_gql( arg_row = "" for key, val in node.args.items(): name = f"in{len(variables)}" + obj = {key: val.value} + + if node.files is not None and len(node.files) > 0: + extracted_files = extract_files(key, obj, node.files) + + for path, file in extracted_files.items(): + path_in_variables = re.sub(r"^\.[^.\[]+", f".{name}", path) + files[path_in_variables] = file + + val.value = obj[key] variables[name] = val arg_row += f"{key}: ${name}, " if len(arg_row): @@ -373,21 +489,75 @@ def convert_query_node_gql( sub_node_list += f"... on {gql_ty} {{ " for node in sub_nodes: - sub_node_list += ( - f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables)} " - ) + sub_node_list += f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " sub_node_list += "}" out += f" {{ {sub_node_list}}}" elif isinstance(node.sub_nodes, list): sub_node_list = "" for node in node.sub_nodes: sub_node_list += ( - f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables)} " + f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " ) out += f" {{ {sub_node_list}}}" return out +class MultiPartForm: + def __init__(self): + self.form_fields: typing.List[typing.Tuple[str, str]] = [] + self.files: typing.List[typing.Tuple[str, File]] = [] + self.boundary = uuid.uuid4().hex.encode("utf-8") + + def add_field(self, name: str, value: str): + self.form_fields.append((name, value)) + + def add_file(self, key, file: File): + self.files.append((key, file)) + + def get_content_type(self): + return f"multipart/form-data; boundary={self.boundary.decode('utf-8')}" + + def _form_data(self, name): + return f'Content-Disposition: form-data; name="{name}"\r\n'.encode("utf-8") + + def _attached_file(self, name, filename): + return f'Content-Disposition: file; name="{name}"; filename="{filename}"\r\n'.encode( + "utf-8" + ) + + def _content_type(self, ct): + return f"Content-Type: {ct}\r\n".encode("utf-8") + + def __bytes__(self): + buffer = io.BytesIO() + boundary = b"--" + self.boundary + b"\r\n" + + for name, value in self.form_fields: + buffer.write(boundary) + buffer.write(self._form_data(name)) + buffer.write(b"\r\n") + buffer.write(value.encode("utf-8")) + buffer.write(b"\r\n") + + for key, file in self.files: + mimetype = ( + file.mimetype + or mimetypes.guess_type(file.name)[0] + or "application/octet-stream" + ) + + buffer.write(boundary) + buffer.write(self._attached_file(key, file.name)) + buffer.write(self._content_type(mimetype)) + buffer.write(b"\r\n") + buffer.write(file.content) + buffer.write(b"\r\n") + + buffer.write(b"--" + self.boundary + b"--\r\n") + + return buffer.getvalue() + + class GraphQLTransportBase: def __init__( self, @@ -406,10 +576,13 @@ def build_gql( name: str = "", ): variables: typing.Dict[str, NodeArgValue] = {} + files: typing.Dict[str, File] = {} root_nodes = "" for key, node in query.items(): - fixed_node = SelectNode(node.node_name, key, node.args, node.sub_nodes) - root_nodes += f" {convert_query_node_gql(self.ty_to_gql_ty_map, fixed_node, variables)}\n" + fixed_node = SelectNode( + node.node_name, key, node.args, node.sub_nodes, node.files + ) + root_nodes += f" {convert_query_node_gql(self.ty_to_gql_ty_map, fixed_node, variables, files)}\n" args_row = "" for key, val in variables.items(): args_row += f"${key}: {self.ty_to_gql_ty_map[val.type_name]}, " @@ -420,30 +593,44 @@ def build_gql( doc = f"{ty} {name}{args_row} {{\n{root_nodes}}}" variables = {key: val.value for key, val in variables.items()} # print(doc, variables) - return (doc, variables) + return (doc, variables, files) def build_req( self, doc: str, variables: typing.Dict[str, typing.Any], opts: typing.Optional[GraphQLTransportOptions] = None, + files: typing.Dict[str, File] = {}, ): headers = {} headers.update(self.opts.headers) if opts: headers.update(opts.headers) - headers.update( - { - "accept": "application/json", - "content-type": "application/json", - } - ) - data = json.dumps({"query": doc, "variables": variables}).encode("utf-8") + headers.update({"accept": "application/json"}) + + body = json.dumps({"query": doc, "variables": variables}) + + if len(files) > 0: + form_data = MultiPartForm() + form_data.add_field("operations", body) + map = {} + + for idx, (path, file) in enumerate(files.items()): + map[idx] = ["variables" + path] + form_data.add_file(f"{idx}", file) + + form_data.add_field("map", json.dumps(map)) + headers.update({"Content-type": form_data.get_content_type()}) + body = bytes(form_data) + else: + headers.update({"Content-type": "application/json"}) + body = body.encode("utf-8") + return GraphQLRequest( addr=self.addr, method="POST", headers=headers, - body=data, + body=body, ) def handle_response(self, res: GraphQLResponse): @@ -463,8 +650,9 @@ def fetch( doc: str, variables: typing.Dict[str, typing.Any], opts: typing.Optional[GraphQLTransportOptions], + files: typing.Dict[str, File] = {}, ): - req = self.build_req(doc, variables, opts) + req = self.build_req(doc, variables, opts, files) try: with request.urlopen( request.Request( @@ -498,7 +686,7 @@ def query( opts: typing.Optional[GraphQLTransportOptions] = None, name: str = "", ) -> typing.Dict[str, Out]: - doc, variables = self.build_gql( + doc, variables, _ = self.build_gql( {key: val for key, val in inp.items()}, "query", name ) return self.fetch(doc, variables, opts) @@ -509,10 +697,10 @@ def mutation( opts: typing.Optional[GraphQLTransportOptions] = None, name: str = "", ) -> typing.Dict[str, Out]: - doc, variables = self.build_gql( + doc, variables, files = self.build_gql( {key: val for key, val in inp.items()}, "mutation", name ) - return self.fetch(doc, variables, opts) + return self.fetch(doc, variables, opts, files) def prepare_query( self, @@ -538,10 +726,11 @@ def __init__( name: str = "", ): dry_run_node = fun(PreparedArgs()) - doc, variables = transport.build_gql(dry_run_node, ty, name) + doc, variables, files = transport.build_gql(dry_run_node, ty, name) self.doc = doc self._mapping = variables self.transport = transport + self.files = files def resolve_vars( self, diff --git a/src/metagen/src/client_py/types.rs b/src/metagen/src/client_py/types.rs index 0e5933154a..490ba7f4e5 100644 --- a/src/metagen/src/client_py/types.rs +++ b/src/metagen/src/client_py/types.rs @@ -164,7 +164,7 @@ impl RenderType for PyTypeRenderer { TypeNode::String { .. } => "str".into(), TypeNode::File { base, .. } if body_required => { let ty_name = normalize_type_title(&base.title); - self.render_alias(renderer, &ty_name, "bytes")?; + self.render_alias(renderer, &ty_name, "File")?; ty_name } TypeNode::File { .. } => "bytes".into(), diff --git a/src/metagen/src/client_rs/node_metas.rs b/src/metagen/src/client_rs/node_metas.rs index 9b60867c1e..c9a5fc34c2 100644 --- a/src/metagen/src/client_rs/node_metas.rs +++ b/src/metagen/src/client_rs/node_metas.rs @@ -204,7 +204,7 @@ impl RenderType for RsNodeMetasRenderer { // .map(|s| serde_json::to_string(&s).unwrap()) // .collect::>() // }) - .map(|path| format!("&[{}]", path.0.join(", "))) + .map(|path| path.serialize_rs()) .collect::>() }) .map(|files| { diff --git a/src/metagen/src/shared/files.rs b/src/metagen/src/shared/files.rs index 4cb3a5ceac..096e5a66c7 100644 --- a/src/metagen/src/shared/files.rs +++ b/src/metagen/src/shared/files.rs @@ -1,7 +1,7 @@ // Copyright Metatype OÜ, licensed under the Mozilla Public License Version 2.0. // SPDX-License-Identifier: MPL-2.0 -use std::{borrow::Cow, collections::HashMap}; +use std::collections::HashMap; use crate::interlude::*; use common::typegraph::{ @@ -11,30 +11,76 @@ use common::typegraph::{ }; #[derive(Debug)] -pub struct TypePath(pub Vec>); +pub enum ObjectPathSegment { + Prop(String), + Array, + Optional, +} + +impl TryFrom<&PathSegment> for ObjectPathSegment { + type Error = anyhow::Error; + + fn try_from(value: &PathSegment) -> Result { + match &value.edge { + Edge::ObjectProp(key) => Ok(ObjectPathSegment::Prop(key.to_owned())), + Edge::ArrayItem => Ok(ObjectPathSegment::Array), + Edge::OptionalItem => Ok(ObjectPathSegment::Optional), + Edge::UnionVariant(_) => bail!("file input is not supported in polymorphic types"), + _ => bail!("unexpected path segment in input type: {:?}", value), + } + } +} + +impl ObjectPathSegment { + pub fn serialize_rs(&self) -> String { + match self { + ObjectPathSegment::Prop(key) => format!("TypePathSegment::ObjectProp({key:?})"), + ObjectPathSegment::Array => "TypePathSegment::ArrayItem".to_owned(), + ObjectPathSegment::Optional => "TypePathSegment::Optional".to_owned(), + } + } -fn serialize_path_segment(seg: &PathSegment) -> Result> { - match &seg.edge { - Edge::ObjectProp(key) => Ok(format!("TypePathSegment::ObjectProp({key:?})").into()), - Edge::ArrayItem => Ok("TypePathSegment::ArrayItem".into()), - Edge::OptionalItem => Ok("TypePathSegment::Optional".into()), - Edge::UnionVariant(_) => bail!("file input is not supported in polymorphic types"), - _ => bail!("unexpected path segment in input type: {:?}", seg), + pub fn serialize(&self) -> String { + match self { + ObjectPathSegment::Prop(key) => format!(".{key}"), + ObjectPathSegment::Array => "[]".to_owned(), + ObjectPathSegment::Optional => "?".to_owned(), + } } } +#[derive(Debug)] +pub struct TypePath(pub Vec); + impl<'a> TryFrom<&'a [PathSegment]> for TypePath { type Error = anyhow::Error; fn try_from(tg_path: &'a [PathSegment]) -> Result { let mut path = Vec::with_capacity(tg_path.len()); for seg in tg_path { - path.push(serialize_path_segment(seg)?); + path.push(ObjectPathSegment::try_from(seg)?); } Ok(TypePath(path)) } } +impl TypePath { + pub fn serialize_rs(&self) -> String { + format!( + "&[{}]", + self.0 + .iter() + .map(|path| path.serialize_rs()) + .collect::>() + .join(", ") + ) + } + + pub fn to_vec_str(&self) -> Vec { + self.0.iter().map(|path| path.serialize()).collect() + } +} + pub fn get_path_to_files(tg: &Typegraph, root: u32) -> Result>> { visitor2::traverse_types( tg, diff --git a/tests/metagen/metagen_test.ts b/tests/metagen/metagen_test.ts index cc8fb36c62..26e6750c2a 100644 --- a/tests/metagen/metagen_test.ts +++ b/tests/metagen/metagen_test.ts @@ -213,8 +213,9 @@ Meta.test("Metagen within sdk", async (t) => { }); Meta.test("Metagen within sdk with custom template", async (t) => { - const workspace = join(import.meta.dirname!, "typegraphs") - .slice(workspaceDir.length); + const workspace = join(import.meta.dirname!, "typegraphs").slice( + workspaceDir.length, + ); const targetName = "my_target"; const genConfig = { @@ -542,158 +543,169 @@ Meta.test("fdk table suite", async (metaTest) => { } }); -Meta.test({ - name: "client table suite", -}, async (metaTest) => { - const scriptsPath = join(import.meta.dirname!, "typegraphs/sample"); +Meta.test( + { + name: "client table suite", + }, + async (metaTest) => { + const scriptsPath = join(import.meta.dirname!, "typegraphs/sample"); - assertEquals( - ( - await Meta.cli( - { - env: { - // RUST_BACKTRACE: "1", + assertEquals( + ( + await Meta.cli( + { + env: { + // RUST_BACKTRACE: "1", + }, }, - }, - ...`-C ${scriptsPath} gen`.split(" "), - ) - ).code, - 0, - ); + ...`-C ${scriptsPath} gen`.split(" "), + ) + ).code, + 0, + ); - const postSchema = zod.object({ - id: zod.string(), - slug: zod.string(), - title: zod.string(), - }); - const userSchema = zod.object({ - id: zod.string(), - email: zod.string(), - }); - const expectedSchemaQ = zod.object({ - user: userSchema.extend({ - post1: postSchema.array(), - post2: zod.object({ - // NOTE: no id - slug: zod.string(), - title: zod.string(), - }).array(), - }), - posts: postSchema, - scalarNoArgs: zod.string(), - }); - const expectedSchemaM = zod.object({ - scalarArgs: zod.string(), - compositeNoArgs: postSchema, - compositeArgs: postSchema, - }); - const expectedSchema = zod.tuple([ - expectedSchemaQ, - expectedSchemaQ, - expectedSchemaM, - expectedSchemaQ, - expectedSchemaM, - zod.object({ - scalarUnion: zod.string(), - compositeUnion1: postSchema, - compositeUnion2: zod.object({}), - mixedUnion: zod.string(), - }), - ]); - const cases = [ - { - skip: false, - name: "client_rs", - command: $`cargo run`.cwd( - join(scriptsPath, "rs"), - ), - expected: expectedSchema, - }, - { - name: "client_ts", - // NOTE: dax replaces commands to deno with - // commands to xtask so we go through bah - command: $`bash -c "deno run -A main.ts"`.cwd( - join(scriptsPath, "ts"), - ), - expected: expectedSchema, - }, - { - name: "client_py", - command: $`python3 main.py`.cwd( - join(scriptsPath, "py"), - ), - expected: expectedSchema, - }, - ]; + const postSchema = zod.object({ + id: zod.string(), + slug: zod.string(), + title: zod.string(), + }); + const userSchema = zod.object({ + id: zod.string(), + email: zod.string(), + }); + const expectedSchemaQ = zod.object({ + user: userSchema.extend({ + post1: postSchema.array(), + post2: zod + .object({ + // NOTE: no id + slug: zod.string(), + title: zod.string(), + }) + .array(), + }), + posts: postSchema, + scalarNoArgs: zod.string(), + }); + const expectedSchemaM = zod.object({ + scalarArgs: zod.string(), + compositeNoArgs: postSchema, + compositeArgs: postSchema, + }); + const expectedSchema = zod.tuple([ + expectedSchemaQ, + expectedSchemaQ, + expectedSchemaM, + expectedSchemaQ, + expectedSchemaM, + zod.object({ + scalarUnion: zod.string(), + compositeUnion1: postSchema, + compositeUnion2: zod.object({}), + mixedUnion: zod.string(), + }), + ]); + const cases = [ + { + skip: false, + name: "client_rs", + command: $`cargo run`.cwd(join(scriptsPath, "rs")), + expected: expectedSchema, + }, + { + name: "client_ts", + // NOTE: dax replaces commands to deno with + // commands to xtask so we go through bah + command: $`bash -c "deno run -A main.ts"`.cwd(join(scriptsPath, "ts")), + expected: expectedSchema, + }, + { + name: "client_py", + command: $`python3 main.py`.cwd(join(scriptsPath, "py")), + expected: expectedSchema, + }, + ]; - await using _engine = await metaTest.engine( - "metagen/typegraphs/sample.ts", - ); - for (const { name, command, expected, skip } of cases) { - if (skip) { - continue; + await using _engine = await metaTest.engine("metagen/typegraphs/sample.ts"); + for (const { name, command, expected, skip } of cases) { + if (skip) { + continue; + } + await metaTest.should(name, async () => { + // const res = await command + // .env({ "TG_PORT": metaTest.port.toString() }); + const res = await command + .env({ TG_PORT: metaTest.port.toString() }) + .text(); + expected.parse(JSON.parse(res)); + }); } - await metaTest.should(name, async () => { - // const res = await command - // .env({ "TG_PORT": metaTest.port.toString() }); - const res = await command - .env({ "TG_PORT": metaTest.port.toString() }).text(); - expected.parse(JSON.parse(res)); + }, +); + +Meta.test( + { + name: "client table suite for file upload", + }, + async (t) => { + const scriptsPath = join(import.meta.dirname!, "typegraphs/sample"); + const res = await Meta.cli({}, ...`-C ${scriptsPath} gen`.split(" ")); + // console.log("--- >>> --- >>> STDERR"); + // console.log(res.stderr); + // console.log("--- >>> --- >>> STDERR end"); + assertEquals(res.code, 0); + + const expectedSchemaU1 = zod.object({ + upload: zod.boolean(), + }); + const expectedSchemaUn = zod.object({ + uploadMany: zod.boolean(), }); - } -}); - -Meta.test({ - name: "client table suite for file upload", -}, async (t) => { - const scriptsPath = join(import.meta.dirname!, "typegraphs/sample"); - const res = await Meta.cli({}, ...`-C ${scriptsPath} gen`.split(" ")); - // console.log("--- >>> --- >>> STDERR"); - // console.log(res.stderr); - // console.log("--- >>> --- >>> STDERR end"); - assertEquals(res.code, 0); - - const expectedSchemaU1 = zod.object({ - upload: zod.boolean(), - }); - const expectedSchemaUn = zod.object({ - uploadMany: zod.boolean(), - }); - const expectedSchema = zod.tuple([ - expectedSchemaU1, - // expectedSchemaU1, - expectedSchemaUn, - expectedSchemaU1, - expectedSchemaUn, - ]); + const expectedSchema = zod.tuple([ + expectedSchemaU1, + // expectedSchemaU1, + expectedSchemaUn, + expectedSchemaU1, + expectedSchemaUn, + ]); + + const cases = [ + { + name: "client_rs_upload", + skip: false, + command: $`cargo run`.cwd(join(scriptsPath, "rs_upload")), + expected: expectedSchema, + }, + { + name: "client_py_upload", + skip: false, + command: $`bash -c "python main.py"`.cwd( + join(scriptsPath, "py_upload"), + ), + expected: zod.tuple([expectedSchemaU1, expectedSchemaUn]), + }, + ]; - const cases = [ - { - name: "client_rs_upload", - skip: false, - command: $`cargo run`.cwd(join(scriptsPath, "rs_upload")), - expected: expectedSchema, - }, - ]; + await using _engine2 = await t.engine( + "metagen/typegraphs/file_upload_sample.ts", + { secrets: { ...s3Secrets } }, + ); - await using _engine2 = await t.engine( - "metagen/typegraphs/file_upload_sample.ts", - { secrets: { ...s3Secrets } }, - ); + await prepareBucket(); - await prepareBucket(); + for (const { name, command, expected, skip } of cases) { + if (skip) { + continue; + } - for (const { name, command, expected, skip } of cases) { - if (skip) { - continue; + await t.should(name, async () => { + const res = await command + .env({ TG_PORT: t.port.toString() }) + .stderr("inherit") + .text(); + expected.parse(JSON.parse(res)); + }); } - - await t.should(name, async () => { - const res = await command - .env({ "TG_PORT": t.port.toString() }).stderr("inherit").text(); - expected.parse(JSON.parse(res)); - }); - } -}); + }, +); diff --git a/tests/metagen/typegraphs/sample/metatype.yml b/tests/metagen/typegraphs/sample/metatype.yml index 540c861185..7d3a3c3863 100644 --- a/tests/metagen/typegraphs/sample/metatype.yml +++ b/tests/metagen/typegraphs/sample/metatype.yml @@ -23,3 +23,6 @@ metagen: typegraph_path: ../file_upload_sample.ts # skip_cargo_toml: true skip_lib_rs: true + - generator: client_py + path: ./py_upload/ + typegraph_path: ../file_upload_sample.ts diff --git a/tests/metagen/typegraphs/sample/py/client.py b/tests/metagen/typegraphs/sample/py/client.py index 10a2eb34dc..b45cd8cb4c 100644 --- a/tests/metagen/typegraphs/sample/py/client.py +++ b/tests/metagen/typegraphs/sample/py/client.py @@ -1,12 +1,17 @@ # This file was @generated by metagen and is intended # to be generated again on subsequent metagen runs. +import io +import re +import uuid import typing import dataclasses as dc import json +import urllib import urllib.request as request import urllib.error import http.client as http_c +import mimetypes def selection_to_nodes( @@ -151,6 +156,7 @@ def selection_to_nodes( instance_name="__typename", args=None, sub_nodes=None, + files=None, ) ) @@ -164,7 +170,9 @@ def selection_to_nodes( ) sub_nodes = union_selections - node = SelectNode(node_name, instance_name, instance_args, sub_nodes) + node = SelectNode( + node_name, instance_name, instance_args, sub_nodes, meta.input_files + ) out.append(node) found_nodes.discard("_") @@ -187,6 +195,13 @@ def selection_to_nodes( SelectionT = typing.TypeVar("SelectionT") +@dc.dataclass +class File: + content: bytes + name: str + mimetype: typing.Optional[str] = None + + # # --- --- Graph node types --- --- # # @@ -200,6 +215,9 @@ def selection_to_nodes( typing.Dict[str, typing.List["SelectNode"]], ] +TypePath = typing.List[typing.Union[typing.Literal["?"], typing.Literal["[]"], str]] +ValuePath = typing.List[typing.Union[typing.Literal[""], str]] + @dc.dataclass class SelectNode(typing.Generic[Out]): @@ -207,6 +225,7 @@ class SelectNode(typing.Generic[Out]): instance_name: str args: typing.Optional["NodeArgs"] sub_nodes: SubNodes + files: typing.Optional[typing.List[TypePath]] @dc.dataclass @@ -227,6 +246,92 @@ class NodeMeta: sub_nodes: typing.Optional[typing.Dict[str, NodeMetaFn]] = None variants: typing.Optional[typing.Dict[str, NodeMetaFn]] = None arg_types: typing.Optional[typing.Dict[str, str]] = None + input_files: typing.Optional[typing.List[TypePath]] = None + + +class FileExtractor: + def __init__(self): + self.path: TypePath = [] + self.current_path: ValuePath = [] + self.result: typing.Dict[str, File] = {} + + def extract_from_value(self, value: typing.Any): + next_segment = self.path[len(self.current_path)] + + if next_segment == "?": + if value is None: + return + self.current_path.append("") + self.extract_from_value(value) + self.current_path.pop() + return + + if next_segment == "[]": + if not isinstance(value, list): + raise Exception(f"Expected array at {self.format_path()}") + + for idx in range(len(value)): + self.current_path.append(f"[{idx}]") + self.extract_from_array(value, idx) + self.current_path.pop() + return + + if next_segment.startswith("."): + if not isinstance(value, dict): + raise Exception(f"Expected dictionary at {self.format_path()}") + + self.current_path.append(next_segment) + self.extract_from_object(value, next_segment[1:]) + self.current_path.pop() + return + + def extract_from_array(self, parent: typing.List[typing.Any], idx: int): + value = parent[idx] + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[idx] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def extract_from_object(self, parent: typing.Dict[str, typing.Any], key: str): + value = parent.get(key) + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[key] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def format_path(self): + res = "" + + for path in self.current_path: + res += f".{path[1:-1]}" if path.startswith("[") else path + + return res + + +def extract_files( + key: str, obj: typing.Dict[str, typing.Any], paths: typing.List[TypePath] +): + extractor = FileExtractor() + + for path in paths: + if path[0] and path[0].startswith("." + key): + extractor.current_path = [] + extractor.path = path + extractor.extract_from_value(obj) + + return extractor.result # @@ -346,6 +451,7 @@ def convert_query_node_gql( ty_to_gql_ty_map: typing.Dict[str, str], node: SelectNode, variables: typing.Dict[str, NodeArgValue], + files: typing.Dict[str, File], ): out = ( f"{node.instance_name}: {node.node_name}" @@ -356,6 +462,16 @@ def convert_query_node_gql( arg_row = "" for key, val in node.args.items(): name = f"in{len(variables)}" + obj = {key: val.value} + + if node.files is not None and len(node.files) > 0: + extracted_files = extract_files(key, obj, node.files) + + for path, file in extracted_files.items(): + path_in_variables = re.sub(r"^\.[^.\[]+", f".{name}", path) + files[path_in_variables] = file + + val.value = obj[key] variables[name] = val arg_row += f"{key}: ${name}, " if len(arg_row): @@ -376,21 +492,75 @@ def convert_query_node_gql( sub_node_list += f"... on {gql_ty} {{ " for node in sub_nodes: - sub_node_list += ( - f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables)} " - ) + sub_node_list += f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " sub_node_list += "}" out += f" {{ {sub_node_list}}}" elif isinstance(node.sub_nodes, list): sub_node_list = "" for node in node.sub_nodes: sub_node_list += ( - f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables)} " + f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " ) out += f" {{ {sub_node_list}}}" return out +class MultiPartForm: + def __init__(self): + self.form_fields: typing.List[typing.Tuple[str, str]] = [] + self.files: typing.List[typing.Tuple[str, File]] = [] + self.boundary = uuid.uuid4().hex.encode("utf-8") + + def add_field(self, name: str, value: str): + self.form_fields.append((name, value)) + + def add_file(self, key, file: File): + self.files.append((key, file)) + + def get_content_type(self): + return f"multipart/form-data; boundary={self.boundary.decode('utf-8')}" + + def _form_data(self, name): + return f'Content-Disposition: form-data; name="{name}"\r\n'.encode("utf-8") + + def _attached_file(self, name, filename): + return f'Content-Disposition: file; name="{name}"; filename="{filename}"\r\n'.encode( + "utf-8" + ) + + def _content_type(self, ct): + return f"Content-Type: {ct}\r\n".encode("utf-8") + + def __bytes__(self): + buffer = io.BytesIO() + boundary = b"--" + self.boundary + b"\r\n" + + for name, value in self.form_fields: + buffer.write(boundary) + buffer.write(self._form_data(name)) + buffer.write(b"\r\n") + buffer.write(value.encode("utf-8")) + buffer.write(b"\r\n") + + for key, file in self.files: + mimetype = ( + file.mimetype + or mimetypes.guess_type(file.name)[0] + or "application/octet-stream" + ) + + buffer.write(boundary) + buffer.write(self._attached_file(key, file.name)) + buffer.write(self._content_type(mimetype)) + buffer.write(b"\r\n") + buffer.write(file.content) + buffer.write(b"\r\n") + + buffer.write(b"--" + self.boundary + b"--\r\n") + + return buffer.getvalue() + + class GraphQLTransportBase: def __init__( self, @@ -409,10 +579,13 @@ def build_gql( name: str = "", ): variables: typing.Dict[str, NodeArgValue] = {} + files: typing.Dict[str, File] = {} root_nodes = "" for key, node in query.items(): - fixed_node = SelectNode(node.node_name, key, node.args, node.sub_nodes) - root_nodes += f" {convert_query_node_gql(self.ty_to_gql_ty_map, fixed_node, variables)}\n" + fixed_node = SelectNode( + node.node_name, key, node.args, node.sub_nodes, node.files + ) + root_nodes += f" {convert_query_node_gql(self.ty_to_gql_ty_map, fixed_node, variables, files)}\n" args_row = "" for key, val in variables.items(): args_row += f"${key}: {self.ty_to_gql_ty_map[val.type_name]}, " @@ -423,30 +596,44 @@ def build_gql( doc = f"{ty} {name}{args_row} {{\n{root_nodes}}}" variables = {key: val.value for key, val in variables.items()} # print(doc, variables) - return (doc, variables) + return (doc, variables, files) def build_req( self, doc: str, variables: typing.Dict[str, typing.Any], opts: typing.Optional[GraphQLTransportOptions] = None, + files: typing.Dict[str, File] = {}, ): headers = {} headers.update(self.opts.headers) if opts: headers.update(opts.headers) - headers.update( - { - "accept": "application/json", - "content-type": "application/json", - } - ) - data = json.dumps({"query": doc, "variables": variables}).encode("utf-8") + headers.update({"accept": "application/json"}) + + body = json.dumps({"query": doc, "variables": variables}) + + if len(files) > 0: + form_data = MultiPartForm() + form_data.add_field("operations", body) + map = {} + + for idx, (path, file) in enumerate(files.items()): + map[idx] = ["variables" + path] + form_data.add_file(f"{idx}", file) + + form_data.add_field("map", json.dumps(map)) + headers.update({"Content-type": form_data.get_content_type()}) + body = bytes(form_data) + else: + headers.update({"Content-type": "application/json"}) + body = body.encode("utf-8") + return GraphQLRequest( addr=self.addr, method="POST", headers=headers, - body=data, + body=body, ) def handle_response(self, res: GraphQLResponse): @@ -466,8 +653,9 @@ def fetch( doc: str, variables: typing.Dict[str, typing.Any], opts: typing.Optional[GraphQLTransportOptions], + files: typing.Dict[str, File] = {}, ): - req = self.build_req(doc, variables, opts) + req = self.build_req(doc, variables, opts, files) try: with request.urlopen( request.Request( @@ -501,7 +689,7 @@ def query( opts: typing.Optional[GraphQLTransportOptions] = None, name: str = "", ) -> typing.Dict[str, Out]: - doc, variables = self.build_gql( + doc, variables, _ = self.build_gql( {key: val for key, val in inp.items()}, "query", name ) return self.fetch(doc, variables, opts) @@ -512,10 +700,10 @@ def mutation( opts: typing.Optional[GraphQLTransportOptions] = None, name: str = "", ) -> typing.Dict[str, Out]: - doc, variables = self.build_gql( + doc, variables, files = self.build_gql( {key: val for key, val in inp.items()}, "mutation", name ) - return self.fetch(doc, variables, opts) + return self.fetch(doc, variables, opts, files) def prepare_query( self, @@ -541,10 +729,11 @@ def __init__( name: str = "", ): dry_run_node = fun(PreparedArgs()) - doc, variables = transport.build_gql(dry_run_node, ty, name) + doc, variables, files = transport.build_gql(dry_run_node, ty, name) self.doc = doc self._mapping = variables self.transport = transport + self.files = files def resolve_vars( self, @@ -825,7 +1014,7 @@ def get_user(self, select: UserSelections) -> QueryNode[User]: {"getUser": NodeDescs.RootGetUserFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def get_posts(self, select: PostSelections) -> QueryNode[Post]: node = selection_to_nodes( @@ -833,7 +1022,7 @@ def get_posts(self, select: PostSelections) -> QueryNode[Post]: {"getPosts": NodeDescs.RootGetPostsFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def scalar_no_args(self) -> QueryNode[PostSlugString]: node = selection_to_nodes( @@ -841,7 +1030,7 @@ def scalar_no_args(self) -> QueryNode[PostSlugString]: {"scalarNoArgs": NodeDescs.RootScalarNoArgsFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def scalar_args(self, args: typing.Union[Post, PlaceholderArgs]) -> MutationNode[PostSlugString]: node = selection_to_nodes( @@ -849,7 +1038,7 @@ def scalar_args(self, args: typing.Union[Post, PlaceholderArgs]) -> MutationNode {"scalarArgs": NodeDescs.RootScalarArgsFn}, "$q" )[0] - return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def composite_no_args(self, select: PostSelections) -> MutationNode[Post]: node = selection_to_nodes( @@ -857,7 +1046,7 @@ def composite_no_args(self, select: PostSelections) -> MutationNode[Post]: {"compositeNoArgs": NodeDescs.RootCompositeNoArgsFn}, "$q" )[0] - return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def composite_args(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderArgs], select: PostSelections) -> MutationNode[Post]: node = selection_to_nodes( @@ -865,7 +1054,7 @@ def composite_args(self, args: typing.Union[RootCompositeArgsFnInput, Placeholde {"compositeArgs": NodeDescs.RootCompositeArgsFn}, "$q" )[0] - return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return MutationNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def scalar_union(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderArgs]) -> QueryNode[RootScalarUnionFnOutput]: node = selection_to_nodes( @@ -873,7 +1062,7 @@ def scalar_union(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderA {"scalarUnion": NodeDescs.RootScalarUnionFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def composite_union(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderArgs], select: RootCompositeUnionFnOutputSelections) -> QueryNode[RootCompositeUnionFnOutput]: node = selection_to_nodes( @@ -881,7 +1070,7 @@ def composite_union(self, args: typing.Union[RootCompositeArgsFnInput, Placehold {"compositeUnion": NodeDescs.RootCompositeUnionFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) def mixed_union(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderArgs], select: RootMixedUnionFnOutputSelections) -> QueryNode[RootMixedUnionFnOutput]: node = selection_to_nodes( @@ -889,4 +1078,4 @@ def mixed_union(self, args: typing.Union[RootCompositeArgsFnInput, PlaceholderAr {"mixedUnion": NodeDescs.RootMixedUnionFn}, "$q" )[0] - return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes) + return QueryNode(node.node_name, node.instance_name, node.args, node.sub_nodes, node.files) diff --git a/tests/metagen/typegraphs/sample/py_upload/client.py b/tests/metagen/typegraphs/sample/py_upload/client.py new file mode 100644 index 0000000000..ec62c7a73c --- /dev/null +++ b/tests/metagen/typegraphs/sample/py_upload/client.py @@ -0,0 +1,894 @@ +# This file was @generated by metagen and is intended +# to be generated again on subsequent metagen runs. + +import io +import re +import uuid +import typing +import dataclasses as dc +import json +import urllib +import urllib.request as request +import urllib.error +import http.client as http_c +import mimetypes + + +def selection_to_nodes( + selection: "SelectionErased", + metas: typing.Dict[str, "NodeMetaFn"], + parent_path: str, +) -> typing.List["SelectNode[typing.Any]"]: + out = [] + flags = selection.get("_") + if flags is not None and not isinstance(flags, SelectionFlags): + raise Exception( + f"selection field '_' should be of type SelectionFlags but found {type(flags)}" + ) + select_all = True if flags is not None and flags.select_all else False + found_nodes = set(selection.keys()) + for node_name, meta_fn in metas.items(): + found_nodes.discard(node_name) + + node_selection = selection.get(node_name) + if node_selection is False or (node_selection is None and not select_all): + # this node was not selected + continue + + meta = meta_fn() + + # we splat out any aliasing of nodes here + node_instances = ( + [(key, val) for key, val in node_selection.items.items()] + if isinstance(node_selection, Alias) + else [(node_name, node_selection)] + ) + + for instance_name, instance_selection in node_instances: + # print(parent_path, instance_selection, meta.sub_nodes, instance_selection, flags) + if instance_selection is False or ( + instance_selection is None and not select_all + ): + # this instance was not selected + continue + if isinstance(instance_selection, Alias): + raise Exception( + f"nested Alias node discovered at {parent_path}.{instance_name}" + ) + + instance_args: typing.Optional[NodeArgs] = None + if meta.arg_types is not None: + arg = instance_selection + + if isinstance(arg, tuple): + arg = arg[0] + + # arg types are always TypedDicts + if not isinstance(arg, dict): + raise Exception( + f"node at {parent_path}.{instance_name} is a node that " + + "requires arguments " + + f"but detected argument is typeof {type(arg)}" + ) + + # convert arg dict to NodeArgs + expected_args = {key: val for key, val in meta.arg_types.items()} + instance_args = {} + for key, val in arg.items(): + ty_name = expected_args.pop(key) + if ty_name is None: + raise Exception( + f"unexpected argument ${key} at {parent_path}.{instance_name}" + ) + instance_args[key] = NodeArgValue(ty_name, val) + + sub_nodes: SubNodes = None + if meta.sub_nodes is not None or meta.variants is not None: + sub_selections = instance_selection + + # if node requires both selection and arg, it must be + # a CompositeSelectArgs which is a tuple selection + if meta.arg_types is not None: + if not isinstance(sub_selections, tuple): + raise Exception( + f"node at {parent_path}.{instance_name} is a composite " + + "that requires an argument object " + + f"but selection is typeof {type(sub_selections)}" + ) + sub_selections = sub_selections[1] + + # we got a tuple selection when this shouldn't be the case + elif isinstance(sub_selections, tuple): + raise Exception( + f"node at {parent_path}.{instance_name} " + + "is a composite that takes no arguments " + + f"but selection is typeof {type(instance_selection)}", + ) + + # flags are recursive for any subnode that's not specified + if sub_selections is None: + sub_selections = {"_": flags} + + # selection types are always TypedDicts as well + if not isinstance(sub_selections, dict): + raise Exception( + f"node at {parent_path}.{instance_name} " + + "is a no argument composite but first element of " + + f"selection is typeof {type(instance_selection)}", + ) + + if meta.sub_nodes is not None: + if meta.variants is not None: + raise Exception( + "unreachable: union/either NodeMetas can't have subnodes" + ) + sub_nodes = selection_to_nodes( + typing.cast("SelectionErased", sub_selections), + meta.sub_nodes, + f"{parent_path}.{instance_name}", + ) + else: + assert meta.variants is not None + union_selections: typing.Dict[str, typing.List[SelectNode]] = {} + for variant_ty, variant_meta in meta.variants.items(): + variant_meta = variant_meta() + + # this union member is a scalar + if variant_meta.sub_nodes is None: + continue + + variant_select = sub_selections.pop(variant_ty, None) + nodes = ( + selection_to_nodes( + typing.cast("SelectionErased", variant_select), + variant_meta.sub_nodes, + f"{parent_path}.{instance_name}.variant({variant_ty})", + ) + if variant_select is not None + else [] + ) + + # we select __typename for each variant + # even if the user is not interested in the variant + nodes.append( + SelectNode( + node_name="__typename", + instance_name="__typename", + args=None, + sub_nodes=None, + files=None, + ) + ) + + union_selections[variant_ty] = nodes + + if len(sub_selections) > 0: + raise Exception( + f"node at {parent_path}.{instance_name} " + + "has none of the variants called " + + str(sub_selections.keys()), + ) + sub_nodes = union_selections + + node = SelectNode( + node_name, instance_name, instance_args, sub_nodes, meta.input_files + ) + out.append(node) + + found_nodes.discard("_") + if len(found_nodes) > 0: + raise Exception( + f"unexpected nodes found in selection set at {parent_path}: {found_nodes}", + ) + return out + + +# +# --- --- Util types --- --- # +# + +Out = typing.TypeVar("Out", covariant=True) + +T = typing.TypeVar("T") + +ArgT = typing.TypeVar("ArgT", bound=typing.Mapping[str, typing.Any]) +SelectionT = typing.TypeVar("SelectionT") + + +@dc.dataclass +class File: + content: bytes + name: str + mimetype: typing.Optional[str] = None + + +# +# --- --- Graph node types --- --- # +# + + +SubNodes = typing.Union[ + None, + # atomic composite + typing.List["SelectNode"], + # union/either selection + typing.Dict[str, typing.List["SelectNode"]], +] + +TypePath = typing.List[typing.Union[typing.Literal["?"], typing.Literal["[]"], str]] +ValuePath = typing.List[typing.Union[typing.Literal[""], str]] + + +@dc.dataclass +class SelectNode(typing.Generic[Out]): + node_name: str + instance_name: str + args: typing.Optional["NodeArgs"] + sub_nodes: SubNodes + files: typing.Optional[typing.List[TypePath]] + + +@dc.dataclass +class QueryNode(SelectNode[Out]): + pass + + +@dc.dataclass +class MutationNode(SelectNode[Out]): + pass + + +NodeMetaFn = typing.Callable[[], "NodeMeta"] + + +@dc.dataclass +class NodeMeta: + sub_nodes: typing.Optional[typing.Dict[str, NodeMetaFn]] = None + variants: typing.Optional[typing.Dict[str, NodeMetaFn]] = None + arg_types: typing.Optional[typing.Dict[str, str]] = None + input_files: typing.Optional[typing.List[TypePath]] = None + + +class FileExtractor: + def __init__(self): + self.path: TypePath = [] + self.current_path: ValuePath = [] + self.result: typing.Dict[str, File] = {} + + def extract_from_value(self, value: typing.Any): + next_segment = self.path[len(self.current_path)] + + if next_segment == "?": + if value is None: + return + self.current_path.append("") + self.extract_from_value(value) + self.current_path.pop() + return + + if next_segment == "[]": + if not isinstance(value, list): + raise Exception(f"Expected array at {self.format_path()}") + + for idx in range(len(value)): + self.current_path.append(f"[{idx}]") + self.extract_from_array(value, idx) + self.current_path.pop() + return + + if next_segment.startswith("."): + if not isinstance(value, dict): + raise Exception(f"Expected dictionary at {self.format_path()}") + + self.current_path.append(next_segment) + self.extract_from_object(value, next_segment[1:]) + self.current_path.pop() + return + + def extract_from_array(self, parent: typing.List[typing.Any], idx: int): + value = parent[idx] + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[idx] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def extract_from_object(self, parent: typing.Dict[str, typing.Any], key: str): + value = parent.get(key) + + if len(self.current_path) == len(self.path): + if isinstance(value, File): + self.result[self.format_path()] = value + parent[key] = None + return + + raise Exception(f"Expected File at {self.format_path()}") + + self.extract_from_value(value) + + def format_path(self): + res = "" + + for path in self.current_path: + res += f".{path[1:-1]}" if path.startswith("[") else path + + return res + + +def extract_files( + key: str, obj: typing.Dict[str, typing.Any], paths: typing.List[TypePath] +): + extractor = FileExtractor() + + for path in paths: + if path[0] and path[0].startswith("." + key): + extractor.current_path = [] + extractor.path = path + extractor.extract_from_value(obj) + + return extractor.result + + +# +# --- --- Argument types --- --- # +# + + +@dc.dataclass +class NodeArgValue: + type_name: str + value: typing.Any + + +NodeArgs = typing.Dict[str, NodeArgValue] + + +class PlaceholderValue(typing.Generic[T]): + def __init__(self, key: str): + self.key = key + + +PlaceholderArgs = typing.Dict[str, PlaceholderValue] + + +class PreparedArgs: + def get(self, key: str) -> PlaceholderValue: + return PlaceholderValue(key) + + +# +# --- --- Selection types --- --- # +# + + +class Alias(typing.Generic[SelectionT]): + """ + Request multiple instances of a single node under different + aliases. + """ + + def __init__(self, **aliases: SelectionT): + self.items = aliases + + +ScalarSelectNoArgs = typing.Union[bool, Alias[typing.Literal[True]], None] +ScalarSelectArgs = typing.Union[ + ArgT, + PlaceholderArgs, + Alias[typing.Union[ArgT, PlaceholderArgs]], + typing.Literal[False], + None, +] +CompositeSelectNoArgs = typing.Union[ + SelectionT, Alias[SelectionT], typing.Literal[False], None +] +CompositeSelectArgs = typing.Union[ + typing.Tuple[typing.Union[ArgT, PlaceholderArgs], SelectionT], + Alias[typing.Tuple[typing.Union[ArgT, PlaceholderArgs], SelectionT]], + typing.Literal[False], + None, +] + + +# FIXME: ideally this would be a TypedDict +# to allow full dict based queries but +# we need to reliably identify SelectionFlags at runtime +# but TypedDicts don't allow instanceof +@dc.dataclass +class SelectionFlags: + select_all: typing.Union[bool, None] = None + + +class Selection(typing.TypedDict, total=False): + _: SelectionFlags + + +SelectionErased = typing.Mapping[ + str, + typing.Union[ + SelectionFlags, + ScalarSelectNoArgs, + ScalarSelectArgs[typing.Mapping[str, typing.Any]], + CompositeSelectNoArgs["SelectionErased"], + # FIXME: should be possible to make SelectionT here `SelectionErased` recursively + # but something breaks + CompositeSelectArgs[typing.Mapping[str, typing.Any], typing.Any], + ], +] + +# +# --- --- GraphQL types --- --- # +# + + +@dc.dataclass +class GraphQLTransportOptions: + headers: typing.Dict[str, str] + + +@dc.dataclass +class GraphQLRequest: + addr: str + method: str + headers: typing.Dict[str, str] + body: bytes + + +@dc.dataclass +class GraphQLResponse: + req: GraphQLRequest + status: int + headers: typing.Dict[str, str] + body: bytes + + +def convert_query_node_gql( + ty_to_gql_ty_map: typing.Dict[str, str], + node: SelectNode, + variables: typing.Dict[str, NodeArgValue], + files: typing.Dict[str, File], +): + out = ( + f"{node.instance_name}: {node.node_name}" + if node.instance_name != node.node_name + else node.node_name + ) + if node.args is not None: + arg_row = "" + for key, val in node.args.items(): + name = f"in{len(variables)}" + obj = {key: val.value} + + if node.files is not None and len(node.files) > 0: + extracted_files = extract_files(key, obj, node.files) + + for path, file in extracted_files.items(): + path_in_variables = re.sub(r"^\.[^.\[]+", f".{name}", path) + files[path_in_variables] = file + + val.value = obj[key] + variables[name] = val + arg_row += f"{key}: ${name}, " + if len(arg_row): + out += f"({arg_row[:-2]})" + + # if it's a dict, it'll be a union selection + if isinstance(node.sub_nodes, dict): + sub_node_list = "" + for variant_ty, sub_nodes in node.sub_nodes.items(): + # fetch the gql variant name so we can do + # type assertions + gql_ty = ty_to_gql_ty_map[variant_ty] + if gql_ty is None: + raise Exception( + f"unreachable: no graphql type found for variant {variant_ty}" + ) + gql_ty = gql_ty.strip("!") + + sub_node_list += f"... on {gql_ty} {{ " + for node in sub_nodes: + sub_node_list += f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " + sub_node_list += "}" + out += f" {{ {sub_node_list}}}" + elif isinstance(node.sub_nodes, list): + sub_node_list = "" + for node in node.sub_nodes: + sub_node_list += ( + f"{convert_query_node_gql(ty_to_gql_ty_map, node, variables, files)} " + ) + out += f" {{ {sub_node_list}}}" + return out + + +class MultiPartForm: + def __init__(self): + self.form_fields: typing.List[typing.Tuple[str, str]] = [] + self.files: typing.List[typing.Tuple[str, File]] = [] + self.boundary = uuid.uuid4().hex.encode("utf-8") + + def add_field(self, name: str, value: str): + self.form_fields.append((name, value)) + + def add_file(self, key, file: File): + self.files.append((key, file)) + + def get_content_type(self): + return f"multipart/form-data; boundary={self.boundary.decode('utf-8')}" + + def _form_data(self, name): + return f'Content-Disposition: form-data; name="{name}"\r\n'.encode("utf-8") + + def _attached_file(self, name, filename): + return f'Content-Disposition: file; name="{name}"; filename="{filename}"\r\n'.encode( + "utf-8" + ) + + def _content_type(self, ct): + return f"Content-Type: {ct}\r\n".encode("utf-8") + + def __bytes__(self): + buffer = io.BytesIO() + boundary = b"--" + self.boundary + b"\r\n" + + for name, value in self.form_fields: + buffer.write(boundary) + buffer.write(self._form_data(name)) + buffer.write(b"\r\n") + buffer.write(value.encode("utf-8")) + buffer.write(b"\r\n") + + for key, file in self.files: + mimetype = ( + file.mimetype + or mimetypes.guess_type(file.name)[0] + or "application/octet-stream" + ) + + buffer.write(boundary) + buffer.write(self._attached_file(key, file.name)) + buffer.write(self._content_type(mimetype)) + buffer.write(b"\r\n") + buffer.write(file.content) + buffer.write(b"\r\n") + + buffer.write(b"--" + self.boundary + b"--\r\n") + + return buffer.getvalue() + + +class GraphQLTransportBase: + def __init__( + self, + addr: str, + opts: GraphQLTransportOptions, + ty_to_gql_ty_map: typing.Dict[str, str], + ): + self.addr = addr + self.opts = opts + self.ty_to_gql_ty_map = ty_to_gql_ty_map + + def build_gql( + self, + query: typing.Mapping[str, SelectNode], + ty: typing.Union[typing.Literal["query"], typing.Literal["mutation"]], + name: str = "", + ): + variables: typing.Dict[str, NodeArgValue] = {} + files: typing.Dict[str, File] = {} + root_nodes = "" + for key, node in query.items(): + fixed_node = SelectNode( + node.node_name, key, node.args, node.sub_nodes, node.files + ) + root_nodes += f" {convert_query_node_gql(self.ty_to_gql_ty_map, fixed_node, variables, files)}\n" + args_row = "" + for key, val in variables.items(): + args_row += f"${key}: {self.ty_to_gql_ty_map[val.type_name]}, " + + if len(args_row): + args_row = f"({args_row[:-2]})" + + doc = f"{ty} {name}{args_row} {{\n{root_nodes}}}" + variables = {key: val.value for key, val in variables.items()} + # print(doc, variables) + return (doc, variables, files) + + def build_req( + self, + doc: str, + variables: typing.Dict[str, typing.Any], + opts: typing.Optional[GraphQLTransportOptions] = None, + files: typing.Dict[str, File] = {}, + ): + headers = {} + headers.update(self.opts.headers) + if opts: + headers.update(opts.headers) + headers.update({"accept": "application/json"}) + + body = json.dumps({"query": doc, "variables": variables}) + + if len(files) > 0: + form_data = MultiPartForm() + form_data.add_field("operations", body) + map = {} + + for idx, (path, file) in enumerate(files.items()): + map[idx] = ["variables" + path] + form_data.add_file(f"{idx}", file) + + form_data.add_field("map", json.dumps(map)) + headers.update({"Content-type": form_data.get_content_type()}) + body = bytes(form_data) + else: + headers.update({"Content-type": "application/json"}) + body = body.encode("utf-8") + + return GraphQLRequest( + addr=self.addr, + method="POST", + headers=headers, + body=body, + ) + + def handle_response(self, res: GraphQLResponse): + if res.status != 200: + raise Exception(f"graphql request failed with status {res.status}", res) + if res.headers.get("content-type") != "application/json": + raise Exception("unexpected content-type in graphql response", res) + parsed = json.loads(res.body) + if parsed.get("errors"): + raise Exception("graphql errors in response", parsed) + return parsed["data"] + + +class GraphQLTransportUrlib(GraphQLTransportBase): + def fetch( + self, + doc: str, + variables: typing.Dict[str, typing.Any], + opts: typing.Optional[GraphQLTransportOptions], + files: typing.Dict[str, File] = {}, + ): + req = self.build_req(doc, variables, opts, files) + try: + with request.urlopen( + request.Request( + url=req.addr, method=req.method, headers=req.headers, data=req.body + ) + ) as res: + http_res: http_c.HTTPResponse = res + return self.handle_response( + GraphQLResponse( + req, + status=http_res.status, + body=http_res.read(), + headers={key: val for key, val in http_res.headers.items()}, + ) + ) + except request.HTTPError as res: + return self.handle_response( + GraphQLResponse( + req, + status=res.status or 599, + body=res.read(), + headers={key: val for key, val in res.headers.items()}, + ) + ) + except urllib.error.URLError as err: + raise Exception(f"URL error: {err.reason}") + + def query( + self, + inp: typing.Dict[str, QueryNode[Out]], + opts: typing.Optional[GraphQLTransportOptions] = None, + name: str = "", + ) -> typing.Dict[str, Out]: + doc, variables, _ = self.build_gql( + {key: val for key, val in inp.items()}, "query", name + ) + return self.fetch(doc, variables, opts) + + def mutation( + self, + inp: typing.Dict[str, MutationNode[Out]], + opts: typing.Optional[GraphQLTransportOptions] = None, + name: str = "", + ) -> typing.Dict[str, Out]: + doc, variables, files = self.build_gql( + {key: val for key, val in inp.items()}, "mutation", name + ) + return self.fetch(doc, variables, opts, files) + + def prepare_query( + self, + fun: typing.Callable[[PreparedArgs], typing.Dict[str, QueryNode[Out]]], + name: str = "", + ) -> "PreparedRequestUrlib[Out]": + return PreparedRequestUrlib(self, fun, "query", name) + + def prepare_mutation( + self, + fun: typing.Callable[[PreparedArgs], typing.Dict[str, MutationNode[Out]]], + name: str = "", + ) -> "PreparedRequestUrlib[Out]": + return PreparedRequestUrlib(self, fun, "mutation", name) + + +class PreparedRequestBase(typing.Generic[Out]): + def __init__( + self, + transport: GraphQLTransportBase, + fun: typing.Callable[[PreparedArgs], typing.Mapping[str, SelectNode[Out]]], + ty: typing.Union[typing.Literal["query"], typing.Literal["mutation"]], + name: str = "", + ): + dry_run_node = fun(PreparedArgs()) + doc, variables, files = transport.build_gql(dry_run_node, ty, name) + self.doc = doc + self._mapping = variables + self.transport = transport + self.files = files + + def resolve_vars( + self, + args: typing.Mapping[str, typing.Any], + mappings: typing.Dict[str, typing.Any], + ): + resolved: typing.Dict[str, typing.Any] = {} + for key, val in mappings.items(): + if isinstance(val, PlaceholderValue): + resolved[key] = args[val.key] + elif isinstance(val, dict): + self.resolve_vars(args, val) + else: + resolved[key] = val + return resolved + + +class PreparedRequestUrlib(PreparedRequestBase[Out]): + def __init__( + self, + transport: GraphQLTransportUrlib, + fun: typing.Callable[[PreparedArgs], typing.Mapping[str, SelectNode[Out]]], + ty: typing.Union[typing.Literal["query"], typing.Literal["mutation"]], + name: str = "", + ): + super().__init__(transport, fun, ty, name) + self.transport = transport + + def perform( + self, + args: typing.Mapping[str, typing.Any], + opts: typing.Optional[GraphQLTransportOptions] = None, + ) -> typing.Dict[str, Out]: + resolved_vars = self.resolve_vars(args, self._mapping) + return self.transport.fetch(self.doc, resolved_vars, opts) + + +# +# --- --- QueryGraph types --- --- # +# + + +class QueryGraphBase: + def __init__(self, ty_to_gql_ty_map: typing.Dict[str, str]): + self.ty_to_gql_ty_map = ty_to_gql_ty_map + + def graphql_sync( + self, addr: str, opts: typing.Optional[GraphQLTransportOptions] = None + ): + return GraphQLTransportUrlib( + addr, opts or GraphQLTransportOptions({}), self.ty_to_gql_ty_map + ) + + +# +# --- --- Typegraph types --- --- # +# + + +class NodeDescs: + @staticmethod + def scalar(): + return NodeMeta() + + @staticmethod + def RootUploadFn(): + return_node = NodeDescs.scalar() + return NodeMeta( + sub_nodes=return_node.sub_nodes, + variants=return_node.variants, + arg_types={ + "file": "RootUploadFnInputFileFile", + "path": "RootUploadFnInputPathRootUploadFnInputPathStringOptional", + }, + input_files=[[".file"]], + ) + + @staticmethod + def RootUploadManyFn(): + return_node = NodeDescs.scalar() + return NodeMeta( + sub_nodes=return_node.sub_nodes, + variants=return_node.variants, + arg_types={ + "prefix": "RootUploadManyFnInputPrefixRootUploadFnInputPathStringOptional", + "files": "RootUploadManyFnInputFilesRootUploadFnInputFileFileList", + }, + input_files=[[".files", "[]"]], + ) + + +RootUploadFnInputFileFile = File + +RootUploadFnInputPathString = str + +RootUploadFnInputPathRootUploadFnInputPathStringOptional = typing.Union[ + RootUploadFnInputPathString, None +] + +RootUploadFnInput = typing.TypedDict( + "RootUploadFnInput", + { + "file": RootUploadFnInputFileFile, + "path": RootUploadFnInputPathRootUploadFnInputPathStringOptional, + }, + total=False, +) + +RootUploadManyFnInputPrefixRootUploadFnInputPathStringOptional = typing.Union[ + RootUploadFnInputPathString, None +] + +RootUploadManyFnInputFilesRootUploadFnInputFileFileList = typing.List[ + RootUploadFnInputFileFile +] + +RootUploadManyFnInput = typing.TypedDict( + "RootUploadManyFnInput", + { + "prefix": RootUploadManyFnInputPrefixRootUploadFnInputPathStringOptional, + "files": RootUploadManyFnInputFilesRootUploadFnInputFileFileList, + }, + total=False, +) + +RootUploadFnOutput = bool + + +class QueryGraph(QueryGraphBase): + def __init__(self): + super().__init__( + { + "RootUploadFnInputFileFile": "root_upload_fn_input_file_file!", + "RootUploadFnInputPathRootUploadFnInputPathStringOptional": "String", + "RootUploadManyFnInputPrefixRootUploadFnInputPathStringOptional": "String", + "RootUploadManyFnInputFilesRootUploadFnInputFileFileList": "[root_upload_fn_input_file_file]!", + } + ) + + def upload( + self, args: typing.Union[RootUploadFnInput, PlaceholderArgs] + ) -> MutationNode[RootUploadFnOutput]: + node = selection_to_nodes( + {"upload": args}, {"upload": NodeDescs.RootUploadFn}, "$q" + )[0] + return MutationNode( + node.node_name, node.instance_name, node.args, node.sub_nodes, node.files + ) + + def upload_many( + self, args: typing.Union[RootUploadManyFnInput, PlaceholderArgs] + ) -> MutationNode[RootUploadFnOutput]: + node = selection_to_nodes( + {"uploadMany": args}, {"uploadMany": NodeDescs.RootUploadManyFn}, "$q" + )[0] + return MutationNode( + node.node_name, node.instance_name, node.args, node.sub_nodes, node.files + ) diff --git a/tests/metagen/typegraphs/sample/py_upload/main.py b/tests/metagen/typegraphs/sample/py_upload/main.py new file mode 100644 index 0000000000..6282d3321d --- /dev/null +++ b/tests/metagen/typegraphs/sample/py_upload/main.py @@ -0,0 +1,39 @@ +# Copyright Metatype OÜ, licensed under the Mozilla Public License Version 2.0. +# SPDX-License-Identifier: MPL-2.0 + +import os +import json + +from client import File, QueryGraph + + +port = os.environ.get("TG_PORT") + +api = QueryGraph() +gql = api.graphql_sync(f"http://localhost:{port}/sample") + +res1 = gql.mutation( + { + "upload": api.upload( + { + "file": File(b"Hello", "hello.txt", "text/plain"), + "path": "python/hello.txt", + } + ) + } +) + +res2 = gql.mutation( + { + "uploadMany": api.upload_many( + { + "files": list( + map(lambda i: File(b"Hello", f"{i}", "text/plain"), range(5)) + ), + "prefix": "python/", + } + ) + } +) + +print(json.dumps([res1, res2]))