Skip to content

Commit

Permalink
fix: bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
luckasRanarison committed Nov 29, 2024
1 parent 141ba93 commit febe31f
Show file tree
Hide file tree
Showing 6 changed files with 1,344 additions and 211 deletions.
68 changes: 32 additions & 36 deletions src/metagen/src/client_py/static/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,10 @@ class NodeMeta:


class FileExtractor:
path: TypePath = []
current_path: ValuePath = []
result: typing.Dict[str, File] = {}
def __init__(self):
self.path: TypePath = []
self.current_path: ValuePath = []
self.result: typing.Dict[str, File] = {}

def extract_from_value(self, value: typing.Any):
next_segment = self.path[len(self.current_path)]
Expand Down Expand Up @@ -295,7 +296,7 @@ def extract_from_array(self, parent: typing.List[typing.Any], idx: int):
self.extract_from_value(value)

def extract_from_object(self, parent: typing.Dict[str, typing.Any], key: str):
value = parent[key]
value = parent.get(key)

if len(self.current_path) == len(self.path):
if isinstance(value, File):
Expand All @@ -316,13 +317,16 @@ def format_path(self):
return res


def extract_files(obj: typing.Dict[str, typing.Any], paths: typing.List[TypePath]):
def extract_files(
key: str, obj: typing.Dict[str, typing.Any], paths: typing.List[TypePath]
):
extractor = FileExtractor()

for path in paths:
extractor.current_path = []
extractor.path = path
extractor.extract_from_value(obj)
if path[0] and path[0].startswith("." + key):
extractor.current_path = []
extractor.path = path
extractor.extract_from_value(obj)

return extractor.result

Expand Down Expand Up @@ -458,7 +462,7 @@ def convert_query_node_gql(
obj = {key: val.value}

if node.files is not None and len(node.files) > 0:
extracted_files = extract_files(obj, node.files)
extracted_files = extract_files(key, obj, node.files)

for path, file in extracted_files.items():
path_in_variables = re.sub(r"^\.[^.\[]+", f".{name}", path)
Expand Down Expand Up @@ -499,9 +503,10 @@ def convert_query_node_gql(


class MultiPartForm:
form_fields: typing.List[typing.Tuple[str, str]] = []
files: typing.List[typing.Tuple[str, File]] = []
boundary = uuid.uuid4().hex.encode("utf-8")
def __init__(self):
self.form_fields: typing.List[typing.Tuple[str, str]] = []
self.files: typing.List[typing.Tuple[str, File]] = []
self.boundary = uuid.uuid4().hex.encode("utf-8")

def add_field(self, name: str, value: str):
self.form_fields.append((name, value))
Expand All @@ -512,25 +517,16 @@ def add_file(self, key, file: File):
def get_content_type(self):
return f"multipart/form-data; boundary={self.boundary.decode('utf-8')}"

@staticmethod
def _form_data(name):
return (
("Content-Disposition: form-data; " 'name="{}"\r\n')
.format(name)
.encode("utf-8")
)
def _form_data(self, name):
return f'Content-Disposition: form-data; name="{name}"\r\n'.encode("utf-8")

@staticmethod
def _attached_file(name, filename):
return (
("Content-Disposition: file; " 'name="{}"; filename="{}"\r\n')
.format(name, filename)
.encode("utf-8")
def _attached_file(self, name, filename):
return f'Content-Disposition: file; name="{name}"; filename="{filename}"\r\n'.encode(
"utf-8"
)

@staticmethod
def _content_type(ct):
return "Content-Type: {}\r\n".format(ct).encode("utf-8")
def _content_type(self, ct):
return f"Content-Type: {ct}\r\n".encode("utf-8")

def __bytes__(self):
buffer = io.BytesIO()
Expand All @@ -544,20 +540,21 @@ def __bytes__(self):
buffer.write(b"\r\n")

for key, file in self.files:
if file.mimetype is None:
file.mimetype = (
mimetypes.guess_type(file.name)[0] or "application/octet-stream"
)
mimetype = (
file.mimetype
or mimetypes.guess_type(file.name)[0]
or "application/octet-stream"
)

buffer.write(boundary)
buffer.write(self._attached_file(key, file.name))
buffer.write(self._content_type(file.mimetype))
buffer.write(self._content_type(mimetype))
buffer.write(b"\r\n")
buffer.write(file.content)
buffer.write(b"\r\n")

buffer.write(b"--" + self.boundary + b"--\r\n")
print("buffer: ", buffer.getvalue().decode())

return buffer.getvalue()


Expand Down Expand Up @@ -615,7 +612,7 @@ def build_req(

if len(files) > 0:
form_data = MultiPartForm()
form_data.add_field("operation", body)
form_data.add_field("operations", body)
map = {}

for idx, (path, file) in enumerate(files.items()):
Expand All @@ -624,7 +621,6 @@ def build_req(

form_data.add_field("map", json.dumps(map))
headers.update({"Content-type": form_data.get_content_type()})
# print(f"form_data: {form_data}")
body = bytes(form_data)
else:
headers.update({"Content-type": "application/json"})
Expand Down
Loading

0 comments on commit febe31f

Please sign in to comment.