Skip to content

Commit

Permalink
update mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
boy-hack committed Apr 16, 2024
1 parent 5f76b18 commit 62270d1
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions dataset/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, tokenizer, micro_batch_size, max_length, checkpoint_step=0, d
self.tokenizer = tokenizer
self.index = checkpoint_step
self.data = []
self.data2 = []
stop_token = "<|end_of_turn|>"
self.stop_token_ids = tokenizer.encode(stop_token, add_special_tokens=False)

Expand All @@ -62,8 +63,10 @@ def __init__(self, tokenizer, micro_batch_size, max_length, checkpoint_step=0, d
"input_ids": input_ids,
"labels": labels,
})
for item in self.get_data2():
self.data2.append(item)

def get_data(self):
def get_data2(self):
max_length = self.max_length * self.micro_batch_size
g_input_ids = []
g_labels_ids = []
Expand All @@ -90,9 +93,13 @@ def get_data(self):
a2 = torch.LongTensor(np.asarray(g_labels_ids).reshape(self.micro_batch_size, self.max_length))
yield dict(input_ids=a1, labels=a2)

def get_data(self):
for item in self.data2:
yield item

def __len__(self):
# 只训练前xx条数据
return len(self.data)
return len(self.data2)


if __name__ == '__main__':
Expand All @@ -115,3 +122,4 @@ def __len__(self):
engine = DataEngine(tokenizer, 1, 8192, 0, "D:\数据集\data-index\jibei-清洗\sft-ai\cc.json")
for item in engine.get_data():
print(item)
print(engine.__len__())

0 comments on commit 62270d1

Please sign in to comment.