Commit 9ef9da5 (1 parent: e98f64c)
Showing 9 changed files with 754 additions and 2 deletions.
@@ -0,0 +1,121 @@
import logging
import pathlib
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class CriterionArguments:
    temperature: float = field(
        default=0.05,
        metadata={
            'help': 'The temperature of the InfoNCE loss.'
        }
    )


@dataclass
class DataArguments:
    path_to_train_data_dir: str = field(
        metadata={
            'aliases': '--path-to-train-data-dir',
            'required': True,
            'help': 'Path to the data folder, which should contain "train" as a child folder.'
        }
    )

    path_to_eval_data_dir: Optional[str] = field(
        default=None,
        metadata={
            'aliases': '--path-to-eval-data-dir',
            'help': 'Path to the data folder, which should contain "eval" as a child folder.'
        }
    )

    train_percentage: float = field(
        default=1.,
        metadata={
            'aliases': '--train-percentage',
            'help': 'Fraction of the data placed into train_dataset; the remainder goes to eval_dataset.'
        }
    )

    tokenize_on_the_fly: bool = field(
        default=False,
        metadata={
            'aliases': '--tokenize-on-the-fly',
            'help': 'Whether to tokenize the sentences in each iteration rather than ahead of time.'
        }
    )

    def __post_init__(self):
        assert 0 < self.train_percentage <= 1, 'train_percentage should be within the range (0, 1]'


@dataclass
class ModelArguments:
    model_max_length: int = field(
        default=512,
        metadata={
            'aliases': ['--max-sequence-len', '--max_sequence_len', '--model-max-length'],
            'help': 'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    path_to_model_weight: Optional[str] = field(
        default=None,
        metadata={'aliases': '--path-to-model-weight'}
    )
    load_from_pretrained: bool = field(default=True, metadata={'aliases': '--load-from-pretrained'})
    gradient_checkpointing: bool = field(default=True, metadata={'aliases': '--gradient-checkpointing'})


@dataclass
class TrainingArguments:
    path_to_checkpoint_dir: pathlib.Path = field(
        metadata={
            'aliases': '--path-to-checkpoint-dir',
            'required': True
        }
    )
    device: str = field(default="cuda")

    lr: float = field(default=2e-5)
    epochs: int = field(default=2)

    shuffle: bool = field(default=True)
    per_device_train_batch_size: int = field(
        default=64,
        metadata={
            'aliases': ['--batch-size', '--batch_size', '--per-device-train-batch-size'],
            'help': 'The batch size per GPU/TPU core/CPU for training.'
        }
    )
    per_device_eval_batch_size: int = field(
        default=32,
        metadata={
            'aliases': '--per-device-eval-batch-size',
            'help': 'The batch size per GPU/TPU core/CPU for evaluation.'
        }
    )
    log_level: str = field(
        default='INFO',
        metadata={
            'aliases': '--log-level',
            'help': f'Set the logging level. Choices=[{"|".join(logging._nameToLevel.keys())}]'
        }
    )
    log_interval: int = field(
        default=10,
        metadata={'aliases': '--log-interval'},
    )
    eval_interval: int = field(
        default=50,
        metadata={
            'aliases': '--eval-interval',
            'help': 'Run evaluation every eval_interval steps if eval_strategy is "steps".'
        },
    )

    random_seed: int = field(
        default=42,
        metadata={'aliases': '--random-seed'}
    )

    def __post_init__(self):
        # Map the level name (e.g. 'INFO') to its numeric value for the logging module.
        self.log_level = logging._nameToLevel[self.log_level.upper()]
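
These dataclasses carry argparse-style metadata ('aliases', 'required', 'help'), which suggests they are consumed by an argument parser. Below is a minimal parsing sketch, assuming transformers.HfArgumentParser (which forwards these metadata keys to argparse); the parser choice and the module name `arguments` are assumptions, not shown in this commit.

# Hypothetical parsing sketch -- HfArgumentParser and the `arguments` module
# name are assumptions; the dataclasses themselves come from the file above.
from transformers import HfArgumentParser
from arguments import CriterionArguments, DataArguments, ModelArguments, TrainingArguments

parser = HfArgumentParser((CriterionArguments, DataArguments, ModelArguments, TrainingArguments))
criterion_args, data_args, model_args, training_args = parser.parse_args_into_dataclasses()
print(training_args.lr, training_args.per_device_train_batch_size)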
@@ -0,0 +1,33 @@
import torch
import torch.nn.functional as F


class InfoNCE:
    def __init__(self, criterion_args, device="cuda"):
        self.device = device
        self.temperature = criterion_args.temperature

    def __call__(self, x, auxiliary_data):
        # Rows of x arrive in (query, positive[, negative]) order, so strided
        # slicing recovers the embeddings for each role.
        step_size = 3 if auxiliary_data["has_negative_sample"] else 2
        query_x = x[0::step_size]
        positive_x = x[1::step_size]
        if auxiliary_data["has_negative_sample"]:
            negative_x = x[2::step_size]

        # Similarity of each query with its own positive: shape (batch, 1).
        positive_similarity = F.cosine_similarity(query_x, positive_x).unsqueeze(-1)
        # Pairwise query/positive similarities via broadcasting: shape (batch, batch).
        positive_negative_similarity = F.cosine_similarity(
            query_x.unsqueeze(0), positive_x.unsqueeze(1), -1
        )
        # Mask out the diagonal (each query's own positive) so only in-batch
        # negatives remain: shape (batch, batch - 1).
        label_mask = ~torch.eye(positive_negative_similarity.shape[0], device=self.device, dtype=torch.bool)
        positive_negative_similarity = positive_negative_similarity[label_mask].reshape(query_x.size(0), -1)
        if auxiliary_data["has_negative_sample"]:
            # Similarities against the hard negatives: shape (batch, batch).
            negative_similarity = F.cosine_similarity(
                query_x.unsqueeze(0), negative_x.unsqueeze(1), -1
            )
            positive_negative_similarity = torch.cat([positive_negative_similarity, negative_similarity], -1)
        # Column 0 holds the positive pair, so the target label for every row is 0.
        all_similarity = torch.cat([positive_similarity, positive_negative_similarity], -1)
        labels = torch.zeros(all_similarity.size(0), dtype=torch.long, device=self.device)

        # cross_entropy already averages over the batch, so the result is a scalar.
        loss = F.cross_entropy(all_similarity / self.temperature, labels)
        return loss
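
The loss expects embeddings whose rows interleave query, positive, and (optionally) hard-negative samples, matching the layout produced by ContrastDataset.collate_fn below. A minimal sanity-check sketch on random embeddings; the SimpleNamespace is a stand-in for CriterionArguments, since only .temperature is read.

# Sanity-check sketch: 6 rows = 2 triples of (query, positive, negative).
import torch
from types import SimpleNamespace

loss_fn = InfoNCE(SimpleNamespace(temperature=0.05), device="cpu")
x = torch.randn(6, 128)  # row order: q1, p1, n1, q2, p2, n2
loss = loss_fn(x, {"has_negative_sample": True})
print(loss.item())  # scalar InfoNCE loss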
@@ -0,0 +1,171 @@
import collections
import os
from typing import Dict

import torch
import torch.distributed as dist

from .utils import load_json


class ContrastDataset:
    """
    Data format:
    ```
    [
        [sentence_1, similar_sentence_1],
        [sentence_2, similar_sentence_2],
        [sentence_3, similar_sentence_3],
    ]
    ```
    or
    ```
    [
        [sentence_1, similar_sentence_1, hard_negative_sentence_1],
        [sentence_2, similar_sentence_2, hard_negative_sentence_2],
        [sentence_3, similar_sentence_3, hard_negative_sentence_3],
    ]
    ```
    """
    def __init__(self, raw_data, tokenizer, device, tokenize_on_the_fly=False):
        self.tokenizer = tokenizer

        self.raw_data_length = len(raw_data)
        self.device = device

        # Triples carry a hard negative; pairs rely on in-batch negatives only.
        self.has_negative_sample = len(raw_data[0]) == 3 if len(raw_data) > 0 else False
        self.tokenize_on_the_fly = tokenize_on_the_fly

        self.processed_data, self.total_sentences_map = self.preprocess(raw_data)

    @classmethod
    def initialize_dataset(cls, tokenizer, data_args, device="cuda"):
        train_dataset = None
        eval_dataset = None

        if data_args.train_percentage == 1:
            train_dataset = cls(
                load_json(os.path.join(data_args.path_to_train_data_dir, "data.json")),
                tokenizer, device, data_args.tokenize_on_the_fly
            )
            if data_args.path_to_eval_data_dir is not None:
                eval_dataset = cls(
                    load_json(os.path.join(data_args.path_to_eval_data_dir, "data.json")),
                    tokenizer, device, data_args.tokenize_on_the_fly
                )
        else:
            # Randomly split the training data into train and eval subsets.
            data = load_json(os.path.join(data_args.path_to_train_data_dir, "data.json"))

            perm = torch.randperm(len(data)).tolist()
            split = int(len(perm) * data_args.train_percentage)
            train_indices = perm[:split]
            eval_indices = perm[split:]

            train_data = [data[i] for i in train_indices]
            eval_data = [data[i] for i in eval_indices]

            train_dataset = cls(train_data, tokenizer, device, data_args.tokenize_on_the_fly)
            eval_dataset = cls(eval_data, tokenizer, device, data_args.tokenize_on_the_fly)

        return train_dataset, eval_dataset

    def preprocess(self, raw_data):
        total_sentences_map = collections.defaultdict(list)

        for d in raw_data:
            total_sentences_map["query_sentence_list"].append(d[0])
            total_sentences_map["positive_sentence_list"].append(d[1])
            if self.has_negative_sample:
                total_sentences_map["negative_sentence_list"].append(d[2])

        # Tokenize everything up front unless tokenization is deferred to collate_fn.
        total_tokens_map = {}
        if not self.tokenize_on_the_fly:
            for k in total_sentences_map:
                total_tokens_map[k] = self.tokenizer(
                    total_sentences_map[k], padding="max_length",
                    truncation=True, return_tensors="pt"
                )
        return total_tokens_map, total_sentences_map

    def __len__(self):
        return self.raw_data_length

    def __getitem__(self, idx):
        if self.tokenize_on_the_fly:
            return {k: self.total_sentences_map[k][idx] for k in self.total_sentences_map}
        return [{
            "input_ids": self.processed_data[k]["input_ids"][idx],
            "attention_mask": self.processed_data[k]["attention_mask"][idx]
        } for k in self.processed_data]

    def collate_fn(self, batch_pair_data):
        """
        Returns:
            {
                "input_ids": torch.tensor([
                    [],  # query_sample
                    [],  # positive_sample
                    [],  # negative_sample, if it exists
                    [],  # query_sample
                    [],  # positive_sample
                    [],  # negative_sample, if it exists
                ]),
                "attention_mask": torch.tensor([
                    (same row layout as "input_ids")
                ]),
            }
        """
        if self.tokenize_on_the_fly:
            # Flatten the batch into one sentence list and tokenize it in a single call.
            flatten_sentence_list = []
            for data in batch_pair_data:
                for k in data:
                    flatten_sentence_list.append(data[k])
            merged_batch_tokens = self.tokenizer(
                flatten_sentence_list, padding=True, max_length=self.tokenizer.model_max_length,
                truncation=True, return_tensors="pt"
            )

            merged_batch_tokens = {
                "input_ids": merged_batch_tokens["input_ids"],
                "attention_mask": merged_batch_tokens["attention_mask"]
            }
        else:
            # Pre-tokenized samples are already tensors; just flatten and stack them.
            flatten_batch_pair_data = []
            for pd in batch_pair_data:
                flatten_batch_pair_data.extend(pd)
            merged_batch_tokens = dict(
                input_ids=torch.stack([d["input_ids"] for d in flatten_batch_pair_data], 0),
                attention_mask=torch.stack([d["attention_mask"] for d in flatten_batch_pair_data], 0),
            )

        merged_batch_tokens = self.truncate_redundant_tokens(merged_batch_tokens)
        return merged_batch_tokens, {"has_negative_sample": self.has_negative_sample}

    def truncate_redundant_tokens(self, batch_tokens: Dict[str, torch.Tensor]):
        if dist.is_initialized() and dist.get_world_size() > 1:
            # With tensor parallelism, every process must see the same sequence
            # length, so all-reduce to take the maximum length across processes.
            max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1)).to(self.device)
            dist.all_reduce(
                max_non_zero_index,
                op=dist.ReduceOp.MAX
            )
            max_non_zero_index = max_non_zero_index.cpu()
        else:
            max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1))

        # Round the length up to the next multiple of 4 for flash attention compatibility.
        max_non_zero_index += 4 - max_non_zero_index % 4

        for k, v in batch_tokens.items():
            batch_tokens[k] = v[:, :max_non_zero_index].to(self.device)
        return batch_tokens
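
Since ContrastDataset implements __len__, __getitem__, and a custom collate_fn, it plugs into a standard torch DataLoader. A hypothetical wiring sketch follows; the checkpoint name is a placeholder, and data_args/training_args are assumed to come from the argument dataclasses above.

# Hypothetical wiring sketch -- "bert-base-uncased" is a placeholder checkpoint;
# data_args and training_args are assumed to be parsed as in the earlier sketch.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=512)
train_dataset, eval_dataset = ContrastDataset.initialize_dataset(tokenizer, data_args, device="cpu")
train_loader = DataLoader(
    train_dataset,
    batch_size=training_args.per_device_train_batch_size,
    shuffle=training_args.shuffle,
    collate_fn=train_dataset.collate_fn,  # yields (batch_tokens, auxiliary_data)
)
for batch_tokens, auxiliary_data in train_loader:
    # Forward batch_tokens through the encoder, then compute
    # InfoNCE(criterion_args)(embeddings, auxiliary_data).
    break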