Refactoring, trying to load more datasets
jopokemine committed May 25, 2021
1 parent c928684 commit 48be053
Showing 1 changed file with 60 additions and 44 deletions.
chatbot/load.py: 60 additions & 44 deletions (104 changes)
@@ -13,21 +13,15 @@
 ############################################

 def load_files(*filepaths, open_func=open, line_eval_func=None):
     # open_func allows for different open functions, in case the built-in open() function is not enough
     # line_eval_func is optional, and allows some evaluation before yielding. Mostly used for JSON files, where json.loads() is needed
     for file in filepaths:
         print(f" Loading {file.split('/')[-1]}...")
         with open_func(file) as f:
             for line in f:
                 yield line if line_eval_func is None else line_eval_func(line)


-def load_tsv_files(*filepaths, delimiter=','):
-    for file in filepaths:
-        with open(file) as f:
-            read_csv = csv.reader(f, delimiter=delimiter)
-            for line in read_csv:
-                yield line
-
-
 def load_csv_files(*filepaths, delimiter=','):
     for file in filepaths:
         print(f" Loading {file.split('/')[-1]}...")
@@ -44,6 +38,14 @@ def load_csv_files(*filepaths, delimiter=','):
                 yield row


+def load_tsv_files(*filepaths, delimiter='\t'):
+    for file in filepaths:
+        with open(file) as f:
+            read_csv = csv.reader(f, delimiter=delimiter)
+            for line in read_csv:
+                yield line
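
For reference, a minimal sketch of how these loader generators compose, assuming load_files is in scope; the paths below are hypothetical:

    import gzip
    import json

    # Plain text: load_files yields one raw line at a time.
    for line in load_files("data/example.txt"):
        print(line.strip())

    # Gzipped JSON-lines: swap in gzip.open and parse each line with json.loads.
    for obj in load_files("data/example.jsonl.gz",
                          open_func=gzip.open,
                          line_eval_func=json.loads):
        print(obj.keys())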


 def write_pairs(datafile):
     def decorator(function):
         def wrapper(*args, **kwargs):
@@ -58,9 +60,10 @@ def wrapper(*args, **kwargs):
             while True:
                 try:
                     pair = next(pair_iter)
+                    pair = [s.strip().replace('\n', '').replace('\t', '') for s in pair]
                     writer.writerow(pair)
                 except UnicodeEncodeError:
-                    continue
+                    continue  # Skip pairs with characters that cannot be encoded
                 except StopIteration:
                     break  # Have reached end of iterator, stop.
         return wrapper
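
A minimal sketch of how write_pairs is meant to be used, assuming the hidden part of the wrapper opens datafile for writing and drains the wrapped generator; the dataset function and output file below are hypothetical:

    # The decorated function only has to yield [query, response] pairs;
    # the wrapper consumes the generator and writes each pair to datafile.
    @write_pairs(os.path.join(DATA_DIR, "formatted_lines_demo.txt"))
    def load_demo_dataset():
        yield ["How are you?", "Fine, thanks."]
        yield ["What time is it?", "Half past nine."]

    load_demo_dataset()  # writes both pairs to formatted_lines_demo.txt
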
@@ -121,61 +124,68 @@ def format_multiple_answer_amazon_data(line_it):
 #              Convai Dataset              #
 ############################################

+@write_pairs(os.path.join(DATA_DIR, "formatted_lines_convai.txt"))
 def load_convai_dataset():
-    # TODO: Finish
     print("Loading Convai dataset...")
     _, _, filenames = next(os.walk(data['convai']))
     filepaths = [os.path.join(data['convai'], f) for f in filenames]
-    datafiles = filter(lambda f: 'data' in f, filepaths)
-    lines = load_files(*datafiles, line_eval_func=json.load)
-    line = next(lines)
-    for i in range(500):
-        for j in range(len(line[i]['dialog'])):
-            print(f"CONVO {i}: {line[i]['dialog'][j]['sender']}: {line[i]['dialog'][j]['text']}")
-    # print(line[1])
-    # print(line[0]['dialog'][1])
-    # print(line[0]['dialog'][2])
-    # print(line[0]['dialog'][3])
-    # print(line[0]['dialog'][4])
-    # print(line[0]['dialog'][5])
-    # print(line[0]['dialog'][6])
-    # print(next(lines)[0]['dialog'][2])
-    # print(line[:20])
-    # print(eval(next(lines)[1:-2]))
+    datafiles = filter(lambda f: 'data_' in f, filepaths)
+    lines = load_files(*datafiles, line_eval_func=json.loads)
+    while True:
+        try:
+            line = next(lines)
+        except StopIteration:
+            break
+        previous = []
+        for i in range(len(line)):
+            previous = []  # Reset between conversations, so messages do not leak across dialogs
+            for j in range(len(line[i]['dialog'])):
+                if previous != []:
+                    if previous[0] == line[i]['dialog'][j]['sender']:
+                        previous.append(f"{previous[1]} {line[i]['dialog'][j]['text']}")
+                    else:
+                        for msg in previous[1:-1]:
+                            yield [msg, line[i]['dialog'][j]['text']]
+                        previous = []
+                else:
+                    previous = [line[i]['dialog'][j]['sender'], line[i]['dialog'][j]['text']]
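
The parsing above assumes each data_* file is a single JSON line that decodes to a list of conversations. A guessed, minimal stand-in (not taken from the real dataset):

    conversations = [
        {
            "dialog": [
                {"sender": "participant1", "text": "Hi!"},
                {"sender": "participant1", "text": "How are you?"},
                {"sender": "participant2", "text": "Fine, thanks."},
            ]
        }
    ]

Runs of messages from one sender accumulate in previous, and a message from the other sender turns them into query/response pairs.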


 ############################################
 #           Squad Train Dataset            #
 ############################################

+@write_pairs(os.path.join(DATA_DIR, "formatted_lines_squad.txt"))
 def load_squad_train_dataset():
-    # FIXME
     print("Loading Squad Train dataset")
     _, _, filenames = next(os.walk(data['squad']))
-    objs = load_files(*[os.path.join(data['squad'], f) for f in filenames])
-    print(next(objs))
+    objs = load_files(*[os.path.join(data['squad'], f) for f in filenames], line_eval_func=json.loads)
+    obj = next(objs)  # The whole file is one JSON line, so a single next() call is enough.
+    for dataobj in obj['data']:
+        for paragraph in dataobj['paragraphs']:
+            for qa in paragraph['qas']:
+                for ans in qa['answers']:
+                    yield [qa['question'], ans['text']]
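
For reference, the nesting those four loops walk, sketched with one hypothetical SQuAD-style entry:

    squad_like = {
        "data": [{
            "paragraphs": [{
                "qas": [{
                    "question": "What colour is the sky?",
                    "answers": [{"text": "blue"}],
                }],
            }],
        }],
    }

    for dataobj in squad_like["data"]:
        for paragraph in dataobj["paragraphs"]:
            for qa in paragraph["qas"]:
                for ans in qa["answers"]:
                    print([qa["question"], ans["text"]])  # ['What colour is the sky?', 'blue']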


 ############################################
 #          Opensubtitles Dataset           #
 ############################################

-# @write_pairs(os.path.join(DATA_DIR, "formatted_lines_opensubtitles.txt"))
+@write_pairs(os.path.join(DATA_DIR, "formatted_lines_opensubtitles.txt"))
 def load_opensubtitles_dataset():
-    # TODO
     print("Loading Opensubtitles dataset...")
     _, _, filenames = next(os.walk(data['opensubtitles']))
     filepaths = [os.path.join(data['opensubtitles'], f) for f in filenames]
     datafiles = filter(lambda f: '.gz' not in f, filepaths)
     lines = load_files(*datafiles)
-    i = 0
-    while i < 50:
+    while True:
         try:
-            line = next(lines)
+            line1 = next(lines)
+            line2 = next(lines)
         except StopIteration:
             break
-        print(line)
-        i += 1
+        yield [line1, line2]
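
One detail worth noting above: consuming two lines per iteration pairs subtitle lines disjointly, so each line lands in exactly one pair. A minimal, self-contained illustration:

    lines = iter(["A", "B", "C", "D"])
    pairs = []
    while True:
        try:
            first, second = next(lines), next(lines)
        except StopIteration:
            break
        pairs.append([first, second])
    print(pairs)  # [['A', 'B'], ['C', 'D']]; 'B' is never used as a prompt for 'C'

An overlapping scheme would roughly double the number of pairs, at the cost of reusing every response as the next query.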


 ############################################
@@ -305,23 +315,29 @@ def load_twitter_dataset():
 #              Reddit Dataset              #
 ############################################

-def load_reddit_dataset(path_to_datafiles):
+def load_reddit_dataset():
+    # TODO
     print("Loading Reddit dataset...")
-    _, _, filenames = next(os.walk(path_to_datafiles))
-    files = [os.path.join(path_to_datafiles, f) for f in filenames]
+    _, _, filenames = next(os.walk(data['reddit']))
+    files = [os.path.join(data['reddit'], f) for f in filenames]
     datafiles = filter(lambda f: '.gz' in f, files)
     lines = load_files(*datafiles, open_func=gzip.open, line_eval_func=json.loads)
-    for _ in range(2):
+    for _ in range(10):
         try:
             line = next(lines)
         except StopIteration:
-            None
-        print(line.keys(), end="\n\n\n")
+            break
+        print(line['title'])
+        print('=========================')
+        print(line['selftext'])
+        print('-------------------------')
+        # print(line.keys(), end="\n\n\n")


-def load_reddit_txt(path_to_datafiles):
+def load_reddit_txt():
+    # TODO
     print("Loading Reddit dataset...")
-    datafile = os.path.join(path_to_datafiles, "RS_2011-01")
+    datafile = os.path.join(data['reddit'], "RS_2011-01")
     lines = load_files(datafile, line_eval_func=json.loads)
     for _ in range(2):
         try:
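
For reference, a minimal, self-contained sketch of the gzip plus JSON-lines pattern the Reddit loaders rely on, assuming load_files is in scope; the file name and fields are hypothetical stand-ins for a submissions dump:

    import gzip
    import json

    # Write one fake submission so the example runs end to end.
    with gzip.open("RS_demo.gz", "wt") as f:
        f.write(json.dumps({"title": "TIL something", "selftext": "Body text here"}) + "\n")

    for post in load_files("RS_demo.gz", open_func=gzip.open, line_eval_func=json.loads):
        print(post["title"])
        print(post["selftext"])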
