From 1d59f4a3830a322c5c479c5103af40f9a369c794 Mon Sep 17 00:00:00 2001 From: Carson Lam Date: Wed, 16 Aug 2023 19:02:02 -0700 Subject: [PATCH] mypy errors addressed --- src/together/files.py | 61 +++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/src/together/files.py b/src/together/files.py index 97b1b010..a3ecd17e 100644 --- a/src/together/files.py +++ b/src/together/files.py @@ -2,7 +2,7 @@ import os import posixpath import urllib.parse -from typing import Optional, Dict, List, Union +from typing import Dict, List, Mapping, Optional, Union, cast import requests from tqdm import tqdm @@ -46,7 +46,9 @@ def list(self) -> Dict[str, List[Dict[str, Union[str, int]]]]: return response_json @classmethod - def check(self, file: str, model: Optional[str] = None) -> Dict[str, Union[str, int]]: + def check( + self, file: str, model: Optional[str] = None + ) -> Dict[str, object]: return check_json(file, model) @classmethod @@ -55,11 +57,9 @@ def upload( file: str, check: bool = True, model: Optional[str] = None, - ) -> Dict[str, Union[str, int]]: + ) -> Mapping[str, Union[str, int, Mapping]]: data = {"purpose": "fine-tune", "file_name": os.path.basename(file)} - output_dict = {} - headers = { "Authorization": f"Bearer {together.api_key}", "User-Agent": together.user_agent, @@ -67,10 +67,11 @@ def upload( if check: report_dict = check_json(file, model) - output_dict["check"] = report_dict if not report_dict["is_check_passed"]: print(report_dict) raise together.FileTypeError("Invalid file supplied. Failed to upload.") + else: + report_dict = {} session = requests.Session() @@ -144,11 +145,16 @@ def upload( logger.critical(f"Response error raised: {e}") raise together.ResponseError(e) - output_dict["filename"] = os.path.basename(file) - output_dict["id"] = str(file_id) - output_dict["object"] = "file" + # output_dict["filename"] = os.path.basename(file) + # output_dict["id"] = str(file_id) + # output_dict["object"] = "file" - return output_dict + return { + "filename": os.path.basename(file), + "id": str(file_id), + "object": "file", + "report_dict": report_dict, + } @classmethod def delete(self, file_id: str) -> Dict[str, str]: @@ -250,7 +256,7 @@ def retrieve_content(self, file_id: str, output: Union[str, None] = None) -> str return output # this should be null @classmethod - def save_jsonl(self, data: dict, output_path: str, append: bool = False): + def save_jsonl(self, data: dict, output_path: str, append: bool = False) -> None: """ Write list of objects to a JSON lines file. """ @@ -277,25 +283,20 @@ def load_jsonl(self, input_path: str) -> List[Dict[str, str]]: def check_json( file: str, model: Optional[str] = None, -) -> Dict[str, Union[str, int, bool, list, dict]]: - report_dict = {"is_check_passed": True} +) -> Dict[str, object]: + + report_dict = {"is_check_passed": True, "model_special_tokens": "we are not yet checking end of sentence tokens for this model"} + num_samples_w_eos_token = 0 + + model_info_dict = cast(dict,together.model_info_dict) eos_token = None - if model is not None and model in together.model_info_dict: - if "eos_token" in together.model_info_dict[model]: - eos_token = together.model_info_dict[model]["eos_token"] - report_dict["model_info"] = { - "special_tokens": [ - f"the end of sentence token for this model is {eos_token}" - ] - } - report_dict["model_info"]["num_samples_w_eos_token"] = 0 - else: - report_dict["model_info"] = { - "special_tokens": [ - "we are not yet checking end of sentence tokens for this model" - ] - } + if model is not None and model in model_info_dict: + if "eos_token" in model_info_dict[model]: + eos_token = model_info_dict[model]["eos_token"] + report_dict[ + "model_special_tokens" + ] = f"the end of sentence token for this model is {eos_token}" if not os.path.isfile(file): report_dict["file_present"] = f"File not found at given file path {file}" @@ -350,7 +351,7 @@ def check_json( elif eos_token: if eos_token in json_line["text"]: - report_dict["model_info"]["num_samples_w_eos_token"] += 1 + num_samples_w_eos_token += 1 # make sure this is outside the for idx, line in enumerate(f): for loop if idx + 1 < together.min_samples: @@ -366,4 +367,6 @@ def check_json( report_dict["load_json"] = "Could not load JSONL file. Invalid format" report_dict["is_check_passed"] = False + report_dict["num_samples_w_eos_token"] = num_samples_w_eos_token + return report_dict