From 9ef9da5382482c4dfb7fbc7e76767b87d1b14955 Mon Sep 17 00:00:00 2001
From: eric8607242
Date: Wed, 29 Jan 2025 20:32:28 +0800
Subject: [PATCH] Update training script

---
 README.md             |  46 +++++++++++-
 requirements.txt      |   6 ++
 src/arguments.py      | 121 ++++++++++++++++++++++++++++++
 src/criterion.py      |  33 ++++++++
 src/data_processor.py | 173 ++++++++++++++++++++++++++++++++++++++++++
 src/model.py          |  35 +++++++++
 src/trainer.py        | 135 +++++++++++++++++++++++++++++++++
 src/utils.py          | 139 ++++++++++++++++++++++++++++++++++
 train.py              |  71 ++++++++++++++++++
 9 files changed, 757 insertions(+), 2 deletions(-)
 create mode 100644 src/arguments.py
 create mode 100644 src/criterion.py
 create mode 100644 src/data_processor.py
 create mode 100644 src/model.py
 create mode 100644 src/trainer.py
 create mode 100644 train.py

diff --git a/README.md b/README.md
index f22ae45..0dc81de 100644
--- a/README.md
+++ b/README.md
@@ -91,8 +91,50 @@ python3 evaluate.py --model-name [MODEL_NAME] \
   - `"thenlper/gte-small"`
 - Adjust the `--batch-size` parameter if necessary to accommodate hardware constraints.
-## TODO
-- [ ] The training script of CmdCaliper.
+## Training Scripts of CmdCaliper
+We provide the training scripts, along with the configurations of CmdCaliper reported in our paper.
+
+### Training Command
+```
+python3 train.py \
+    --temperature 0.05 \
+    --lr 0.00002 \
+    --path-to-checkpoint-dir ./checkpoints \
+    --path-to-train-data-dir ./data/train_data \
+    --path-to-eval-data-dir ./data/eval_data \
+    --path-to-model-weight thenlper/gte-small \
+    --epochs 2
+```
+
+### Data Preparation
+You need to prepare a `data.json` file for both your training and evaluation datasets. Place these files in the directories specified by `--path-to-train-data-dir` and `--path-to-eval-data-dir`. In our paper, we extracted 1,000 command-line pairs from the training data to serve as the evaluation dataset.
+
+Please make sure the data in `data.json` follows this format:
+```
+[
+    [cmd1, positive_cmd1],
+    [cmd2, positive_cmd2],
+    [cmd3, positive_cmd3],
+    [cmd4, positive_cmd4],
+    ...
+]
+```
+
+#### Automatic Evaluation Split
+
+You can also automatically split your training data into training and evaluation datasets by using the `--train-percentage` argument. Note that this will result in a different evaluation dataset for each training session unless the same `--random-seed` is used.
+
+## Checkpoints
+
+During training, the following will be saved in the directory specified by `--path-to-checkpoint-dir`:
+
+- Model weights
+- Optimizer state
+- Learning rate scheduler state
+
+These files allow you to resume training if needed. Additionally, a `huggingface_model` directory will be created, containing the model weights in Transformers style.
+
+
 ## Citation
 ```
diff --git a/requirements.txt b/requirements.txt
index 1b26c08..658c187 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,9 @@ PyYAML==6.0.2
 sentence_transformers==3.1.1
 torch==2.5.1
 google-generativeai==0.8.3
+
+safetensors==0.5.2
+huggingface-hub==0.27.1
+transformers==4.48.1
+numpy==2.2.2
+
diff --git a/src/arguments.py b/src/arguments.py
new file mode 100644
index 0000000..41b8e28
--- /dev/null
+++ b/src/arguments.py
@@ -0,0 +1,121 @@
+import logging
+import pathlib
+from typing import Optional
+from dataclasses import dataclass, field
+
+@dataclass
+class CriterionArguments:
+    temperature: float = field(
+        default=0.05,
+        metadata={
+            'help': 'The temperature of the InfoNCE loss.'
+        }
+    )
+
+@dataclass
+class DataArguments:
+    path_to_train_data_dir: str = field(
+        metadata={
+            'aliases': '--path-to-train-data-dir',
+            'required': True,
+            'help': 'Path to the training data folder, which should contain a "data.json" file.'
+        }
+    )
+
+    path_to_eval_data_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            'aliases': '--path-to-eval-data-dir',
+            'help': 'Path to the evaluation data folder, which should contain a "data.json" file.'
+        }
+    )
+
+    train_percentage: float = field(
+        default=1.,
+        metadata={
+            'aliases': '--train-percentage',
+            'help': 'Percentage of the data used for training; the remainder is used for evaluation.'
+        }
+    )
+
+    tokenize_on_the_fly: bool = field(
+        default=False,
+        metadata={
+            'aliases': '--tokenize-on-the-fly',
+            'help': 'Whether to tokenize the sentences in each iteration instead of pre-tokenizing the whole dataset.'
+        }
+    )
+
+    def __post_init__(self):
+        assert 0 < self.train_percentage <= 1, 'train_percentage should be within the range (0, 1]'
+
+@dataclass
+class ModelArguments:
+    model_max_length: int = field(
+        default=512,
+        metadata={
+            'aliases': ['--max-sequence-len', '--max_sequence_len', '--model-max-length'],
+            'help': 'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
+        },
+    )
+    path_to_model_weight: Optional[str] = field(
+        default=None,
+        metadata={'aliases': '--path-to-model-weight'}
+    )
+    load_from_pretrained: bool = field(default=True, metadata={'aliases': '--load-from-pretrained'})
+    gradient_checkpointing: bool = field(default=True, metadata={'aliases': '--gradient-checkpointing'})
+
+@dataclass
+class TrainingArguments:
+    path_to_checkpoint_dir: pathlib.Path = field(
+        metadata={
+            'aliases': '--path-to-checkpoint-dir',
+            'required': True
+        }
+    )
+    device: str = field(default="cuda")
+
+    lr: float = field(default=0.00002)
+    epochs: int = field(default=2)
+
+    shuffle: bool = field(default=True)
+    per_device_train_batch_size: int = field(
+        default=64,
+        metadata={
+            'aliases': ['--batch-size', '--batch_size', '--per-device-train-batch-size'],
+            'help': 'The batch size per GPU/TPU core/CPU for training.'
+        }
+    )
+    per_device_eval_batch_size: int = field(
+        default=32,
+        metadata={
+            'aliases': '--per-device-eval-batch-size',
+            'help': 'The batch size per GPU/TPU core/CPU for evaluation.'
+        }
+    )
+    log_level: str = field(
+        default='INFO',
+        metadata={
+            'aliases': '--log-level',
+            'help': f'Set logging level. Choices=[{"|".join(logging._nameToLevel.keys())}]'
+        }
+    )
+    log_interval: int = field(
+        default=10,
+        metadata={'aliases': '--log-interval'},
+    )
+    eval_interval: int = field(
+        default=50,
+        metadata={
+            'aliases': '--eval-interval',
+            'help': 'Run evaluation every eval_interval steps.'
+        },
+    )
+
+    random_seed: int = field(
+        default=42,
+        metadata={'aliases': '--random-seed'}
+    )
+
+    def __post_init__(self):
+        self.log_level = logging._nameToLevel[self.log_level.upper()]
diff --git a/src/criterion.py b/src/criterion.py
new file mode 100644
index 0000000..21817ce
--- /dev/null
+++ b/src/criterion.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn.functional as F
+
+class InfoNCE:
+    def __init__(self, criterion_args, device="cuda"):
+        self.device = device
+
+        self.temperature = criterion_args.temperature
+
+    def __call__(self, x, auxiliary_data):
+        step_size = 3 if auxiliary_data["has_negative_sample"] else 2
+        query_x = x[0::step_size]
+        positive_x = x[1::step_size]
+        if auxiliary_data["has_negative_sample"]:
+            negative_x = x[2::step_size]
+
+        positive_similarity = F.cosine_similarity(query_x, positive_x).unsqueeze(-1)
+        positive_negative_similarity = F.cosine_similarity(
+            query_x.unsqueeze(0), positive_x.unsqueeze(1), -1
+        )
+        # Mask out the diagonal (the positive pairs) so only in-batch negatives remain.
+        label_mask = ~torch.eye(positive_negative_similarity.shape[0], device=self.device, dtype=torch.bool)
+        positive_negative_similarity = positive_negative_similarity[label_mask].reshape(query_x.size(0), -1)
+        if auxiliary_data["has_negative_sample"]:
+            negative_similarity = F.cosine_similarity(
+                query_x.unsqueeze(0), negative_x.unsqueeze(1), -1
+            )
+            positive_negative_similarity = torch.cat([positive_negative_similarity, negative_similarity], -1)
+        all_similarity = torch.cat([positive_similarity, positive_negative_similarity], -1)
+        labels = torch.zeros(all_similarity.size(0), dtype=torch.long, device=self.device)
+
+        loss = F.cross_entropy(all_similarity / self.temperature, labels)
+        return loss
diff --git a/src/data_processor.py b/src/data_processor.py
new file mode 100644
index 0000000..e7830bc
--- /dev/null
+++ b/src/data_processor.py
@@ -0,0 +1,173 @@
+import collections
+import os
+from typing import Dict
+
+import torch
+import torch.distributed as dist
+
+from .utils import load_json
+
+class ContrastDataset:
+    """
+    Data format:
+    ```
+    [
+        [sentence_1, similar_sentence_1],
+        [sentence_2, similar_sentence_2],
+        [sentence_3, similar_sentence_3],
+    ]
+    ```
+    or
+    ```
+    [
+        [sentence_1, similar_sentence_1, hard_negative_sentence_1],
+        [sentence_2, similar_sentence_2, hard_negative_sentence_2],
+        [sentence_3, similar_sentence_3, hard_negative_sentence_3],
+    ]
+    ```
+    """
+    def __init__(self, raw_data, tokenizer, device, tokenize_on_the_fly=False):
+        self.tokenizer = tokenizer
+
+        self.raw_data_length = len(raw_data)
+        self.device = device
+
+        self.has_negative_sample = len(raw_data[0]) == 3 if len(raw_data) > 0 else False
+        self.tokenize_on_the_fly = tokenize_on_the_fly
+
+        self.processed_data, self.total_sentences_map = self.preprocess(raw_data)
+
+    @classmethod
+    def initialize_dataset(cls, tokenizer, data_args, device="cuda"):
+        train_dataset = None
+        eval_dataset = None
+
+        if data_args.train_percentage == 1:
+            train_dataset = cls(
+                load_json(os.path.join(data_args.path_to_train_data_dir, "data.json")),
+                tokenizer, device, data_args.tokenize_on_the_fly
+            )
+            if data_args.path_to_eval_data_dir is not None:
+                eval_dataset = cls(
+                    load_json(os.path.join(data_args.path_to_eval_data_dir, "data.json")),
+                    tokenizer, device, data_args.tokenize_on_the_fly
+                )
+        else:
+            data = load_json(os.path.join(data_args.path_to_train_data_dir, "data.json"))
+
+            perm = torch.randperm(len(data)).tolist()
+            split = int(len(perm) * data_args.train_percentage)
+            train_indices = perm[:split]
+            eval_indices = perm[split:]
+
+            train_data = [data[i] for i in train_indices]
+            eval_data = [data[i] for i in eval_indices]
+
+            train_dataset = cls(train_data, tokenizer, device, data_args.tokenize_on_the_fly)
+            eval_dataset = cls(eval_data, tokenizer, device, data_args.tokenize_on_the_fly)
+
+        return train_dataset, eval_dataset
+
+    def preprocess(self, raw_data):
+        total_sentences_map = collections.defaultdict(list)
+
+        for d in raw_data:
+            total_sentences_map["query_sentence_list"].append(d[0])
+            total_sentences_map["positive_sentence_list"].append(d[1])
+            if self.has_negative_sample:
+                total_sentences_map["negative_sentence_list"].append(d[2])
+
+        total_tokens_map = {}
+        if not self.tokenize_on_the_fly:
+            for k in total_sentences_map:
+                k_tokens = self.tokenizer(
+                    total_sentences_map[k], padding="max_length",
+                    truncation=True, return_tensors="pt"
+                )
+
+                # Each entry holds the input_ids / attention_mask tensors for all sentences under key k.
+                total_tokens_map[k] = k_tokens
+        return total_tokens_map, total_sentences_map
+
+    def __len__(self):
+        return self.raw_data_length
+
+    def __getitem__(self, idx):
+        if self.tokenize_on_the_fly:
+            return {k: self.total_sentences_map[k][idx] for k in self.total_sentences_map}
+        return [{
+            "input_ids": self.processed_data[k]["input_ids"][idx],
+            "attention_mask": self.processed_data[k]["attention_mask"][idx]
+        } for k in self.processed_data]
+
+    def collate_fn(self, batch_pair_data):
+        """
+        Returns:
+            {
+                "input_ids": torch.tensor([
+                    [], query_sample
+                    [], positive_sample
+                    [], negative_sample if exist
+                    [], query_sample
+                    [], positive_sample
+                    [], negative_sample if exist
+                ]),
+                "attention_mask": torch.tensor([
+                    [], query_sample
+                    [], positive_sample
+                    [], negative_sample if exist
+                    [], query_sample
+                    [], positive_sample
+                    [], negative_sample if exist
+                ]),
+            }
+        """
+        if self.tokenize_on_the_fly:
+            flatten_sentence_list = []
+            for data in batch_pair_data:
+                for k in data:
+                    flatten_sentence_list.append(data[k])
+            merged_batch_tokens = self.tokenizer(
+                flatten_sentence_list, padding=True, max_length=self.tokenizer.model_max_length,
+                truncation=True, return_tensors="pt"
+            )
+
+            merged_batch_tokens = {
+                "input_ids": merged_batch_tokens["input_ids"],
+                "attention_mask": merged_batch_tokens["attention_mask"]
+            }
+        else:
+            flatten_batch_pair_data = []
+            for pd in batch_pair_data:
+                flatten_batch_pair_data.extend(pd)
+            merged_batch_tokens = dict(
+                input_ids=torch.stack([d["input_ids"] for d in flatten_batch_pair_data], 0),
+                attention_mask=torch.stack([d["attention_mask"] for d in flatten_batch_pair_data], 0),
+            )
+
+        merged_batch_tokens = self.truncate_redundant_tokens(merged_batch_tokens)
+        return merged_batch_tokens, {"has_negative_sample": self.has_negative_sample}
+
+    def truncate_redundant_tokens(self, batch_tokens: Dict[str, torch.Tensor]):
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            # If we use tensor parallelism, we must ensure the sequence lengths are the same on each process.
+            # Therefore, we all-reduce here to get the max length across all processes.
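+            # (e.g., if rank 0's longest sequence has 37 tokens and rank 1's has 52,
+            # both ranks keep 52 tokens before the multiple-of-4 padding below.)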
+            max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1)).to(self.device)
+            dist.all_reduce(
+                max_non_zero_index,
+                op=torch.distributed.ReduceOp.MAX
+            )
+            max_non_zero_index = max_non_zero_index.cpu()
+        else:
+            max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1))
+
+        # Pad the length up to a multiple of 4 to be compatible with flash attention.
+        max_non_zero_index += 4 - max_non_zero_index % 4
+
+        for k, v in batch_tokens.items():
+            v = v[:, :max_non_zero_index]
+            batch_tokens[k] = v.to(self.device)
+        return batch_tokens
+
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..e823fd4
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+
+class CSEBert(nn.Module):
+    def __init__(self, path_to_model_weight, gradient_checkpointing=True):
+        super().__init__()
+        self.path_to_model_weight = path_to_model_weight
+
+        self.transformer = AutoModel.from_pretrained(
+            path_to_model_weight, use_cache=False
+        )
+
+        if gradient_checkpointing:
+            self.transformer.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        y = self.transformer(
+            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
+        )
+
+        token_embeddings = y[0]
+        # Mean pooling over tokens, ignoring padded positions.
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
+            token_embeddings.size()
+        ).to(token_embeddings.dtype)
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_tokenizer(self, **kwargs):
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.path_to_model_weight, **kwargs
+        )
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token
+        return tokenizer
diff --git a/src/trainer.py b/src/trainer.py
new file mode 100644
index 0000000..ac97bc4
--- /dev/null
+++ b/src/trainer.py
@@ -0,0 +1,135 @@
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import wrap
+
+from .utils import rank0_print, calculate_param_nums, AverageMeter, save_checkpoint
+
+class Trainer:
+    def __init__(
+        self, model, optimizer, criterion, training_args,
+        train_dataloader, eval_dataloader=None, lr_scheduler=None
+    ):
+        self.model = model
+        self.optimizer = optimizer
+
+        self.lr_scheduler = lr_scheduler
+        self.criterion = criterion
+        self.training_args = training_args
+
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
+        self.do_eval = False
+        if self.eval_dataloader is not None:
+            self.do_eval = True
+
+        self.start_epoch = 0
+
+    def evaluation(self, ep: int, it: int):
+        self.model.eval()
+        losses = AverageMeter()
+
+        with torch.no_grad():
+            torch.cuda.synchronize()
+            for i, (batch_data, auxiliary_data) in enumerate(self.eval_dataloader):
+
+                torch.cuda.synchronize()
+
+                N = batch_data["input_ids"].size(0)  # batch size (batch_data is a dict of tensors)
+
+                output = self.model(**batch_data)
+
+                torch.cuda.synchronize()
+                loss = self.criterion(output, auxiliary_data)
+                torch.cuda.synchronize()
+
+                if np.isnan(loss.item()):
+                    rank0_print('Hit nan loss. Skip record!')
+                else:
+                    losses.update(loss.item(), N)
+
+        rank0_print(f'Epoch: {ep + 1}/{self.training_args.epochs}.'
+                    f' Iteration: {it + 1}/{len(self.train_dataloader)}.'
+                    f' Eval loss: {losses.get_avg()}.')
+
+        self.model.train()
+        return losses.get_avg()
+
+
+    def train(self):
+        rank0_print(f'The total param num of the model: {calculate_param_nums(self.model)}')
+
+        self.model = wrap(self.model)
+        self.model.train()
+
+        rank0_print('Start training!')
+
+        losses = AverageMeter()
+        for param_name, param in vars(self.training_args).items():
+            rank0_print(f'Param Name -- {param_name}: {param}')
+
+
+        steps = 0
+        best_eval_loss = float("inf")
+        self.optimizer.zero_grad()
+
+        for e in range(self.start_epoch, self.training_args.epochs):
+            start_time = time.time()
+            losses.reset()
+
+            for batch_idx, (batch_data, auxiliary_data) in enumerate(self.train_dataloader):
+                self.optimizer.zero_grad()
+
+
+                N = batch_data["input_ids"].size(0)  # batch size (batch_data is a dict of tensors)
+                output = self.model(**batch_data)
+
+                loss = self.criterion(output, auxiliary_data)
+                if np.isnan(loss.item()):
+                    rank0_print('Hit nan loss. Skip record!')
+                else:
+                    losses.update(loss.item(), N)
+
+                loss.backward()
+                self.optimizer.step()
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler.step()
+
+                torch.cuda.synchronize()
+
+                steps += 1
+
+                if (steps % self.training_args.log_interval == 0) or (batch_idx + 1 == len(self.train_dataloader)):
+
+                    rank0_print(f'Epoch: {e + 1}/{self.training_args.epochs}.'
+                                f' LR: {self.optimizer.param_groups[0]["lr"]:.9f}'
+                                f' Iteration: {batch_idx + 1}/{len(self.train_dataloader)}.'
+                                f' Train loss: {losses.get_avg()}.'
+                                f' Time: {time.time() - start_time:.2f}.')
+
+                    torch.cuda.synchronize()
+
+                if (steps % self.training_args.log_interval == 0) or (batch_idx + 1 == len(self.train_dataloader)):
+                    start_time = time.time()
+
+                if self.do_eval:
+                    eval_loss = None
+                    if steps % self.training_args.eval_interval == 0:
+                        eval_loss = self.evaluation(e, batch_idx)
+
+                    if eval_loss is not None:
+                        torch.cuda.synchronize()
+
+                        if eval_loss < best_eval_loss:
+                            rank0_print(f"New lowest eval loss: {eval_loss}. Saving checkpoint.")
+                            best_eval_loss = eval_loss
+                            torch.cuda.synchronize()
+
+                            save_checkpoint(
+                                self.training_args.path_to_checkpoint_dir,
+                                self.model, self.optimizer, self.lr_scheduler, e
+                            )
diff --git a/src/utils.py b/src/utils.py
index cf7c0f9..54e5d8f 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,6 +1,19 @@
+import sys
+import os
 import json
+import logging
+from pathlib import Path
 import yaml
+import torch
+import torch.distributed as dist
+import numpy as np
 
+HUGGINGFACE_DIRNAME = "./huggingface_model"
+MODEL_NAME = "model.pth"
+OPTIMIZER_NAME = "optimizer.pth"
+LR_SCHEDULER_NAME = "lr_scheduler.pth"
+
 def load_yaml(path_to_data):
     with open(path_to_data, "r") as f:
         data = yaml.safe_load(f)
@@ -35,3 +48,129 @@ def extract_cmds(response):
             continue
         new_generated_cmd_list.append(cmd.strip())
     return new_generated_cmd_list
+
+class AverageMeter(object):
+    def __init__(self, name=''):
+        self._name = name
+        self.avg = 0.0
+        self.sum = 0.0
+        self.cnt = 0.0
+
+    def reset(self):
+        self.avg = 0.0
+        self.sum = 0.0
+        self.cnt = 0.0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+
+    def __str__(self):
+        return "%s: %.5f" % (self._name, self.avg)
+
+    def get_avg(self):
+        return self.avg
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def set_random_seed(seed):
+    import random
+    logging.info("Set seed: {}".format(seed))
+    random.seed(seed)
+    np.random.seed(seed)
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+def rank0_print(*args, level: int = logging.INFO):
+    if dist.is_available() and dist.is_initialized():
+        if dist.get_rank() == 0:
+            logging.log(level, *args)
+    else:
+        logging.log(level, *args)
+
+
+def initialize_logging(path_to_logging_dir: Path, level: int):
+    # Clear any pre-existing logging configuration (e.g., from ColossalAI)
+    logger = logging.getLogger()
+    for handler in logger.handlers:
+        handler.close()
+    logger.handlers.clear()
+
+    log_format = "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s"
+    logging.basicConfig(
+        stream=sys.stdout,
+        #filename=os.path.join(path_to_logging_dir, "logger.log"),
+        level=level,
+        format=log_format,
+        datefmt="%m/%d %I:%M:%S %p"
+    )
+
+    path_to_logging_file = path_to_logging_dir / "logger.log"
+    path_to_logging_file.touch(exist_ok=True)
+
+    file_handler = logging.FileHandler(str(path_to_logging_file))
+    file_handler.setFormatter(logging.Formatter(log_format))
+    logging.getLogger().addHandler(file_handler)
+
+def calculate_param_nums(model):
+    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return total_params
+
+def resume_checkpoint(
+    path_to_checkpoint_dir,
+    model,
+    criterion=None,
+    optimizer=None,
+    lr_scheduler=None
+):
+    path_to_model_checkpoint = os.path.join(path_to_checkpoint_dir, MODEL_NAME)
+    path_to_optimizer = os.path.join(path_to_checkpoint_dir, OPTIMIZER_NAME)
+    path_to_lr_scheduler = os.path.join(path_to_checkpoint_dir, LR_SCHEDULER_NAME)
+
+    resume_epoch = None
+    if os.path.isfile(path_to_model_checkpoint):
+        model_checkpoint = torch.load(path_to_model_checkpoint)
+        model.load_state_dict(model_checkpoint)
+
+    if os.path.isfile(path_to_optimizer):
+        optimizer_checkpoint = torch.load(path_to_optimizer)
+        optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
+        resume_epoch = optimizer_checkpoint.get("resume_epoch", None)
+
+    if os.path.isfile(path_to_lr_scheduler):
+        lr_scheduler_checkpoint = torch.load(path_to_lr_scheduler)
+        lr_scheduler.load_state_dict(lr_scheduler_checkpoint)
+
+    return resume_epoch
+
+def save_checkpoint(
+    path_to_checkpoint_dir,
+    model,
+    optimizer=None,
+    lr_scheduler=None,
+    resume_epoch=None
+):
+    path_to_model_checkpoint = os.path.join(path_to_checkpoint_dir, MODEL_NAME)
+    path_to_huggingface_model_checkpoint = os.path.join(path_to_checkpoint_dir, HUGGINGFACE_DIRNAME)
+    os.makedirs(path_to_huggingface_model_checkpoint, exist_ok=True)
+    path_to_optimizer = os.path.join(path_to_checkpoint_dir, OPTIMIZER_NAME)
+    path_to_lr_scheduler = os.path.join(path_to_checkpoint_dir, LR_SCHEDULER_NAME)
+
+    model_checkpoint = model.state_dict()
+    model.transformer.save_pretrained(path_to_huggingface_model_checkpoint)
+
+    torch.save(model_checkpoint, path_to_model_checkpoint)
+
+    if optimizer is not None:
+        torch.save({
+            "optimizer": optimizer.state_dict(),
+            "resume_epoch": resume_epoch
+        }, path_to_optimizer)
+
+    if lr_scheduler is not None:
+        torch.save(lr_scheduler.state_dict(), path_to_lr_scheduler)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..07c0fb2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,71 @@
+import torch
+import transformers
+
+from src.arguments import CriterionArguments, DataArguments, TrainingArguments, ModelArguments
+from src.criterion import InfoNCE
+from src.data_processor import ContrastDataset
+from src.model import CSEBert
+from src.utils import initialize_logging, set_random_seed
+from src.trainer import Trainer
+
+def main(
+    criterion_args: CriterionArguments,
+    training_args: TrainingArguments,
+    data_args: DataArguments,
+    model_args: ModelArguments
+):
+
+    device = training_args.device
+
+    set_random_seed(training_args.random_seed)
+    training_args.path_to_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    initialize_logging(
+        path_to_logging_dir=training_args.path_to_checkpoint_dir,
+        level=training_args.log_level
+    )
+
+    criterion = InfoNCE(criterion_args, device=device)
+    model = CSEBert(
+        model_args.path_to_model_weight,
+        model_args.gradient_checkpointing
+    )
+    model.to(device)
+
+    tokenizer = model.get_tokenizer(
+        padding_side="right", model_max_length=model_args.model_max_length
+    )
+    train_dataset, eval_dataset = ContrastDataset.initialize_dataset(tokenizer, data_args, device=device)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, collate_fn=train_dataset.collate_fn,
+        batch_size=training_args.per_device_train_batch_size, shuffle=training_args.shuffle
+    )
+    eval_dataloader = None
+    if eval_dataset is not None:
+        eval_dataloader = torch.utils.data.DataLoader(
+            eval_dataset, collate_fn=eval_dataset.collate_fn,
+            batch_size=training_args.per_device_eval_batch_size
+        )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.lr)
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader)*training_args.epochs)
+
+    trainer = Trainer(
+        model=model, optimizer=optimizer,
+        criterion=criterion, training_args=training_args,
+        train_dataloader=train_dataloader, eval_dataloader=eval_dataloader,
+        lr_scheduler=lr_scheduler
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    parser = transformers.HfArgumentParser((
+        CriterionArguments, TrainingArguments, DataArguments, ModelArguments
+    ))
+
+    (criterion_args, training_args, data_args, model_args, _) \
+        = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    main(criterion_args, training_args, data_args, model_args)
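After training, the weights under `huggingface_model` can be loaded directly with `transformers`. Below is a minimal sketch of embedding command lines with the exported encoder, assuming the training command above (`--path-to-checkpoint-dir ./checkpoints`, base model `thenlper/gte-small`); note that `save_checkpoint` exports only the model weights, so the tokenizer is loaded from the base model, and the pooling mirrors `CSEBert.forward`.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# save_checkpoint() writes Transformers-style weights here, but no tokenizer files,
# so the tokenizer comes from the base model used for training.
model = AutoModel.from_pretrained("./checkpoints/huggingface_model").eval()
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")

cmds = ["netstat -ano", "ipconfig /all"]  # hypothetical example commands
batch = tokenizer(cmds, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    token_embeddings = model(**batch)[0]
    # Mean pooling over non-padded tokens, mirroring CSEBert.forward.
    mask = batch["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    embeddings = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

print(F.cosine_similarity(embeddings[0], embeddings[1], dim=0).item())
```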