Skip to content

Commit

Permalink
mypy errors addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
Carson Lam authored and Carson Lam committed Aug 17, 2023
1 parent f732656 commit 1d59f4a
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions src/together/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import posixpath
import urllib.parse
from typing import Optional, Dict, List, Union
from typing import Dict, List, Mapping, Optional, Union, cast

import requests
from tqdm import tqdm
Expand Down Expand Up @@ -46,7 +46,9 @@ def list(self) -> Dict[str, List[Dict[str, Union[str, int]]]]:
return response_json

@classmethod
def check(self, file: str, model: Optional[str] = None) -> Dict[str, Union[str, int]]:
def check(
self, file: str, model: Optional[str] = None
) -> Dict[str, object]:
return check_json(file, model)

@classmethod
Expand All @@ -55,22 +57,21 @@ def upload(
file: str,
check: bool = True,
model: Optional[str] = None,
) -> Dict[str, Union[str, int]]:
) -> Mapping[str, Union[str, int, Mapping]]:
data = {"purpose": "fine-tune", "file_name": os.path.basename(file)}

output_dict = {}

headers = {
"Authorization": f"Bearer {together.api_key}",
"User-Agent": together.user_agent,
}

if check:
report_dict = check_json(file, model)
output_dict["check"] = report_dict
if not report_dict["is_check_passed"]:
print(report_dict)
raise together.FileTypeError("Invalid file supplied. Failed to upload.")
else:
report_dict = {}

session = requests.Session()

Expand Down Expand Up @@ -144,11 +145,16 @@ def upload(
logger.critical(f"Response error raised: {e}")
raise together.ResponseError(e)

output_dict["filename"] = os.path.basename(file)
output_dict["id"] = str(file_id)
output_dict["object"] = "file"
# output_dict["filename"] = os.path.basename(file)
# output_dict["id"] = str(file_id)
# output_dict["object"] = "file"

return output_dict
return {
"filename": os.path.basename(file),
"id": str(file_id),
"object": "file",
"report_dict": report_dict,
}

@classmethod
def delete(self, file_id: str) -> Dict[str, str]:
Expand Down Expand Up @@ -250,7 +256,7 @@ def retrieve_content(self, file_id: str, output: Union[str, None] = None) -> str
return output # this should be null

@classmethod
def save_jsonl(self, data: dict, output_path: str, append: bool = False):
def save_jsonl(self, data: dict, output_path: str, append: bool = False) -> None:
"""
Write list of objects to a JSON lines file.
"""
Expand All @@ -277,25 +283,20 @@ def load_jsonl(self, input_path: str) -> List[Dict[str, str]]:
def check_json(
file: str,
model: Optional[str] = None,
) -> Dict[str, Union[str, int, bool, list, dict]]:
report_dict = {"is_check_passed": True}
) -> Dict[str, object]:

report_dict = {"is_check_passed": True, "model_special_tokens": "we are not yet checking end of sentence tokens for this model"}
num_samples_w_eos_token = 0

model_info_dict = cast(dict,together.model_info_dict)

eos_token = None
if model is not None and model in together.model_info_dict:
if "eos_token" in together.model_info_dict[model]:
eos_token = together.model_info_dict[model]["eos_token"]
report_dict["model_info"] = {
"special_tokens": [
f"the end of sentence token for this model is {eos_token}"
]
}
report_dict["model_info"]["num_samples_w_eos_token"] = 0
else:
report_dict["model_info"] = {
"special_tokens": [
"we are not yet checking end of sentence tokens for this model"
]
}
if model is not None and model in model_info_dict:
if "eos_token" in model_info_dict[model]:
eos_token = model_info_dict[model]["eos_token"]
report_dict[
"model_special_tokens"
] = f"the end of sentence token for this model is {eos_token}"

if not os.path.isfile(file):
report_dict["file_present"] = f"File not found at given file path {file}"
Expand Down Expand Up @@ -350,7 +351,7 @@ def check_json(

elif eos_token:
if eos_token in json_line["text"]:
report_dict["model_info"]["num_samples_w_eos_token"] += 1
num_samples_w_eos_token += 1

# make sure this is outside the for idx, line in enumerate(f): for loop
if idx + 1 < together.min_samples:
Expand All @@ -366,4 +367,6 @@ def check_json(
report_dict["load_json"] = "Could not load JSONL file. Invalid format"
report_dict["is_check_passed"] = False

report_dict["num_samples_w_eos_token"] = num_samples_w_eos_token

return report_dict

0 comments on commit 1d59f4a

Please sign in to comment.