Skip to content

Commit

Permalink
Merge pull request #1 from vTuanpham/feat/Dialogs_config
Browse files Browse the repository at this point in the history
Feat/dialogs config
  • Loading branch information
vTuanpham authored Dec 10, 2023
2 parents 55d3193 + 3849032 commit 1edb2de
Show file tree
Hide file tree
Showing 18 changed files with 515 additions and 94 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_translate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches:
- main
- dev
- feat/*
jobs:
test:
runs-on: ubuntu-latest
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ cython_debug/
#.idea/

*.json
.idea/
.idea/
*.Identifier
3 changes: 2 additions & 1 deletion configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base_config import BaseConfig
from .qa_config import QAConfig
from .qa_config import QAConfig
from .dialogs_config import DialogsConfig
101 changes: 101 additions & 0 deletions configs/dialogs_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import sys
sys.path.insert(0,r'./')
import pprint
from typing import List, Dict
from dataclasses import dataclass, asdict, fields


@dataclass
class DialogsConfig:
    """
    A single training/test example for a multi-turn conversation.

    Attributes:
        qas_id: Unique identifier of the example.
        system_prompt: System-level instruction shown to the agent.
        user_prompts: Ordered list of user turns.
        agent_responses: Ordered list of agent turns; may be None/empty.
        answer_lengths: Character length of each agent response (derived
            in __post_init__; any value passed in is overwritten).
        prompt_lengths: Character length of each user prompt (derived
            in __post_init__; any value passed in is overwritten).
    """
    qas_id: str
    system_prompt: str

    user_prompts: list

    agent_responses: list = None

    answer_lengths: List[int] = None
    prompt_lengths: List[int] = None

    def __post_init__(self) -> None:
        # Post validate: derive the per-turn lengths. Guard with `or []`
        # so construction does not crash when agent_responses (or
        # user_prompts) is left as None — callers such as the ShareGPT
        # parser pass the length fields as None and rely on this.
        self.prompt_lengths = [len(prompt) for prompt in self.user_prompts or []]
        self.answer_lengths = [len(answer) for answer in self.agent_responses or []]

    def __str__(self) -> str:
        return self.__repr__()

    @staticmethod
    def intersect_lists(list1, list2):
        """Interleave two lists element-wise, then append the unmatched
        tail of whichever list is longer."""
        intersected = []
        min_length = min(len(list1), len(list2))

        for i in range(min_length):
            intersected.append(list1[i])
            intersected.append(list2[i])

        # Add remaining elements if any list is longer
        if len(list1) > len(list2):
            intersected.extend(list1[min_length:])
        elif len(list2) > len(list1):
            intersected.extend(list2[min_length:])

        return intersected

    def __repr__(self) -> str:
        # Plain method, not @property: decorating __repr__ as a property
        # shadows the dunder protocol and makes repr(obj) raise.
        s = ""
        s += f"\n Question id: {self.qas_id}"
        s += f"\n System prompt: {self.system_prompt}"
        s += f"\n Dialogs: \n"

        if self.user_prompts and self.agent_responses:
            final_dialogs = self.intersect_lists(self.user_prompts, self.agent_responses)
            final_dialogs_length = self.intersect_lists(self.prompt_lengths, self.answer_lengths)
            for idx, (dialog, length) in enumerate(zip(final_dialogs, final_dialogs_length)):
                s += f"Dialog {idx}: {dialog} \n"
                s += f"Dialog {idx} length: {length}\n"

        return s

    @property
    def get_dict(self) -> Dict:
        """All fields as a plain dict (via dataclasses.asdict)."""
        return asdict(self)

    @staticmethod
    def get_keys() -> List[str]:
        """Names of all dataclass fields, in declaration order."""
        all_fields = fields(DialogsConfig)
        return [v.name for v in all_fields]

    @property
    def get_dict_str(self) -> None:
        # Accessed as an attribute, so it cannot take arguments; the
        # original's `indent` parameter was unusable dead code.
        """Pretty-print the field dict to stdout."""
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(self.get_dict)


if __name__ == "__main__":
    # Manual smoke test: build one DialogsConfig from a raw dict
    # (the same field layout a parser's convert() produces) and print it.
    example_dialog = {"qas_id": 10,
                      "system_prompt": "You are an AI assistant, help as much as you can",
                      "user_prompts": ["Tell me a bit about AI", "How does AI learn"],
                      "agent_responses": ["Artificial Intelligence (AI) is a broad field focusing on creating systems or machines that can perform tasks that typically require human intelligence. It encompasses various subfields like machine learning, natural language processing, computer vision, robotics, and more. AI aims to simulate human cognitive functions, such as learning, problem-solving, perception, reasoning, and decision-making.",
                                          '''AI learning primarily occurs through machine learning algorithms. There are a few key ways in which AI learns:
Supervised Learning: This method involves training AI models on labeled data. The algorithm learns patterns and associations between input data and corresponding output labels. For instance, in image recognition, showing the AI images labeled as "cat" or "dog" helps it learn to differentiate between the two.
Unsupervised Learning: Here, the AI learns from data without labeled outcomes. It looks for patterns, structures, or relationships within the data. Clustering algorithms, for example, can group similar data points together without prior labeling.
Reinforcement Learning: This method involves the AI learning through trial and error by interacting with an environment. It receives feedback in the form of rewards or penalties based on its actions. The AI's goal is to maximize cumulative reward, learning optimal strategies by exploring different actions.
Transfer Learning: This technique involves transferring knowledge learned from one task to another. Models pre-trained on vast amounts of data for one task can be fine-tuned or adapted to perform related tasks more effectively with smaller datasets.
AI learns by adjusting internal parameters or features in its algorithms to minimize errors or differences between predicted and actual outputs. This adjustment process, often referred to as "training," involves feeding the AI large amounts of data, iterating through algorithms, and refining the model's predictions or actions over time.
As AI continues to evolve, researchers are exploring new learning methodologies to enhance its capabilities, making it more adaptable, efficient, and capable of handling complex tasks across various domains.'''],
                      }
    # **-unpacking maps the dict keys onto the dataclass fields;
    # printing exercises __str__/__repr__.
    dialog_config_data = DialogsConfig(**example_dialog)
    print(dialog_config_data)


15 changes: 6 additions & 9 deletions examples/ELI5/ELI5_10_docs_QAConfigParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@ def __init__(self, file_path: str, output_path: str, target_lang: str="vi",
max_example_per_thread=400, large_chunks_threshold=20000):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=QAConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/qa_config.py
target_fields=['question_text', 'context_list', 'answers_list'], # The data fields to be translated (The fields belong to QAConfig)
do_translate=True,
max_list_length_per_thread=3,
target_lang=target_lang,
max_example_per_thread=max_example_per_thread,
large_chunks_threshold=large_chunks_threshold)
self.max_ctxs = 3
self.max_answers = 2

# The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/qa_config.py
self.target_config = QAConfig

# The data fields to be translated (The fields belong to QAConfig)
self.target_fields = ['question_text', 'context_list', 'answers_list']
self.max_ctxs = 5
self.max_answers = 3

# Read function must assign data that has been read to self.data_read
def read(self) -> None:
Expand Down
13 changes: 5 additions & 8 deletions examples/ELI5/ELI5_10docs_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@ def __init__(self, file_path: str, output_path: str, target_lang: str="vi",
max_example_per_thread=400, large_chunks_threshold=20000):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=BaseConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
target_fields=['question_text', 'orig_answer_texts'], # The data fields to be translated (The fields belong to BaseConfig)
do_translate=True,
target_lang=target_lang,
max_example_per_thread=max_example_per_thread,
large_chunks_threshold=large_chunks_threshold)
self.max_ctxs = 5

# The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
self.target_config = BaseConfig

# The data fields to be translated (The fields belong to BaseConfig)
self.target_fields = ['question_text', 'orig_answer_texts']
self.max_ctxs = 5

# Read function must assign data that has been read to self.data_read
def read(self) -> None:
Expand Down Expand Up @@ -131,7 +128,7 @@ def convert(self) -> None:
r"examples/ELI5",
max_example_per_thread=100,
large_chunks_threshold=1000,
target_lang="ko")
target_lang="ru")
eli5_val_parser.read()
eli5_val_parser.convert()
eli5_val_parser.save
10 changes: 3 additions & 7 deletions examples/OpenOrca/OpenOrca_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@ class OpenOrcaParser(DataParser):
def __init__(self, file_path: str, output_path: str):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=BaseConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
target_fields=['question_text', 'orig_answer_texts'], # The data fields to be translated (The fields belong to BaseConfig)
do_translate=True,
no_translated_code=True)

# The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
self.target_config = BaseConfig

# The data fields to be translated (The fields belong to BaseConfig)
self.target_fields = ['question_text', 'orig_answer_texts']

# Read function must assign data that has been read to self.data_read
def read(self) -> None:
# The read function must call the read function in DataParser class
Expand Down
77 changes: 77 additions & 0 deletions examples/ShareGPTV3/ShareGPTV3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json
import random
import sys
sys.path.insert(0,r'./')
from tqdm.auto import tqdm

from configs import DialogsConfig
from translator import DataParser


PARSER_NAME = "ShareGPT_V3"


class ShareGPTV3(DataParser):
    """Parser for the ShareGPT V3 dump: reads the raw conversation JSON and
    converts each record into the DialogsConfig field layout for translation.
    """

    def __init__(self, file_path: str, output_path: str, target_lang: str="vi",
                 max_example_per_thread=300, large_chunks_threshold=20000,
                 max_list_length_per_thread=3, max_examples: int=5000):
        # max_examples: cap on how many converted records are kept; the
        # default of 5000 preserves the previously hard-coded limit.
        super().__init__(file_path, output_path,
                         parser_name=PARSER_NAME,
                         do_translate=True,
                         target_config=DialogsConfig,
                         target_fields=['user_prompts', 'agent_responses'],
                         target_lang=target_lang,
                         max_example_per_thread=max_example_per_thread,
                         large_chunks_threshold=large_chunks_threshold,
                         max_list_length_per_thread=max_list_length_per_thread)
        self.max_examples = max_examples

    # Read function must assign data that has been read to self.data_read
    def read(self) -> None:
        """Load the source JSON file into self.data_read."""
        # The read function must call the read function in DataParser class:
        # it validates that the file path is correct.
        super(ShareGPTV3, self).read()

        with open(self.file_path, encoding='utf-8') as jfile:
            json_data = json.load(jfile)

        self.data_read = json_data
        return None

    def convert(self) -> None:
        """Map raw ShareGPT conversations onto DialogsConfig fields and
        assign the result to self.converted_data."""
        # The convert function must call the convert function in DataParser
        # class: it ensures read() has actually assigned self.data_read.
        super(ShareGPTV3, self).convert()

        data_converted = []
        for data in tqdm(self.data_read, desc="Converting data"):
            data_dict = {}
            data_dict['system_prompt'] = ""
            data_dict['qas_id'] = data['id']

            # Split the alternating conversation into user vs agent turns.
            user_prompts = []
            agent_responses = []
            for conversation in data['conversations']:
                if conversation["from"] == "human":
                    user_prompts.append(conversation['value'])
                elif conversation["from"] == "gpt":
                    agent_responses.append(conversation['value'])

            data_dict['user_prompts'] = user_prompts
            data_dict['agent_responses'] = agent_responses

            # Lengths are recomputed by DialogsConfig.__post_init__.
            data_dict['prompt_lengths'] = None
            data_dict['answer_lengths'] = None
            data_converted.append(data_dict)

        # Be sure to assign the final data list to self.converted_data
        self.converted_data = data_converted[:self.max_examples]

        return None


if __name__ == '__main__':
    # Run the full parse pipeline on the raw ShareGPT V3 dump.
    parser = ShareGPTV3(
        r"examples/ShareGPTV3/ShareGPT_V3_unfiltered_cleaned_split.json",
        r"examples/ShareGPTV3",
    )
    parser.read()
    parser.convert()
    parser.save
12 changes: 4 additions & 8 deletions examples/TIGER-Lab-MathInstruct/TigerLabMathInstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@ class MathInstruct(DataParser):
def __init__(self, file_path: str, output_path: str):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=BaseConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
target_fields=['question_text', 'orig_answer_texts'], # The data fields to be translated (The fields belong to BaseConfig)
do_translate=True)

# The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
self.target_config = BaseConfig

# The data fields to be translated (The fields belong to BaseConfig)
self.target_fields = ['question_text', 'orig_answer_texts']

# Read function must assign data that has been read to self.data_read
def read(self):
# The read function must call the read function in DataParser class
Expand Down Expand Up @@ -98,7 +94,7 @@ def convert(self):
data_converted.append(data_dict)

# Be sure to assign the final data list to self.converted_data
self.converted_data = data_converted[20000:120000]
self.converted_data = data_converted

pass

Expand Down
10 changes: 3 additions & 7 deletions examples/YahmaAlpaca/AlpacaCleaned_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@ class AlpacaCleaned(DataParser):
def __init__(self, file_path: str, output_path: str):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=BaseConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
target_fields=['question_text', 'orig_answer_texts'], # The data fields to be translated (The fields belong to BaseConfig)
do_translate=True,
no_translated_code=True,
target_lang="vi")

# The data config to be validated to check if self implement "convert" function is correct or not,
# you must map the data form to the correct fields of the @dataclass in the configs/base_config.py
self.target_config = BaseConfig

# The data fields to be translated (The fields belong to BaseConfig)
self.target_fields = ['question_text', 'orig_answer_texts']

# Read function must assign data that has been read to self.data_read
def read(self) -> None:
# The read function must call the read function in DataParser class
Expand Down
9 changes: 7 additions & 2 deletions tests/eli5_qaconfig_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import unittest
import warnings

import sys
sys.path.insert(0,r'./')
from datasets import load_dataset

from examples.ELI5.ELI5_10_docs_QAConfigParser import ELI5ValQAConfig
Expand All @@ -26,9 +27,11 @@ def step3(self):
def step4(self):
self.parser.save

self.output_path = os.path.join(self.output_dir, "ELI5_val_QAConfig_translated_ru.json")
self.output_path = os.path.join(self.output_dir, "ELI5_val_QAConfig.json")
self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_QAConfig_translated_ru.json")

self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist")
self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist")

def step5(self):
try:
Expand All @@ -44,6 +47,8 @@ def step6(self):
def step7(self):
if os.path.exists(self.output_path):
os.remove(self.output_path)
if os.path.exists(self.output_path_translated):
os.remove(self.output_path_translated)

def _steps(self):
for name in dir(self): # dir() result is implicitly sorted
Expand Down
8 changes: 7 additions & 1 deletion tests/eli5_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import unittest
import warnings
import sys
sys.path.insert(0,r'./')

from datasets import load_dataset

Expand All @@ -26,9 +28,11 @@ def step3(self):
def step4(self):
self.parser.save

self.output_path = os.path.join(self.output_dir, "ELI5_val_translated_de.json")
self.output_path = os.path.join(self.output_dir, "ELI5_val.json")
self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_translated_de.json")

self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist")
self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist")

def step5(self):
try:
Expand All @@ -44,6 +48,8 @@ def step6(self):
def step7(self):
if os.path.exists(self.output_path):
os.remove(self.output_path)
if os.path.exists(self.output_path_translated):
os.remove(self.output_path_translated)

def _steps(self):
for name in dir(self): # dir() result is implicitly sorted
Expand Down
Loading

0 comments on commit 1edb2de

Please sign in to comment.