Update training script
eric8607242 committed Jan 29, 2025
1 parent e98f64c commit 9ef9da5
Showing 9 changed files with 754 additions and 2 deletions.
46 changes: 44 additions & 2 deletions README.md
@@ -91,8 +91,50 @@ python3 evaluate.py --model-name [MODEL_NAME] \
- `"thenlper/gte-small"`
- Adjust the `--batch-size` parameter if necessary to accommodate hardware constraints.

## TODO
- [ ] The training script of CmdCaliper.
## Training Scripts of CmdCaliper
We provide the training scripts, along with the configurations used for CmdCaliper as reported in our paper.

### Training Command
```
python3 train.py \
    --temperature 0.05 \
    --lr 0.00002 \
    --path-to-checkpoint-dir ./checkpoints \
    --path-to-train-data-dir ./data/train_data \
    --path-to-eval-data-dir ./data/eval_data \
    --path-to-model-weight thenlper/gte-small \
    --epochs 2
```

### Data Preparation
You need to prepare a `data.json` file for both your training and evaluation datasets. Place these files in the directories specified by `--path-to-train-data-dir` and `--path-to-eval-data-dir`. In our paper, we extracted 1,000 command line pairs from the training data to serve as the evaluation dataset.

Please make sure the data in `data.json` follows this format:
```
[
    ["cmd1", "positive_cmd1"],
    ["cmd2", "positive_cmd2"],
    ["cmd3", "positive_cmd3"],
    ["cmd4", "positive_cmd4"],
    ...
]
```
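
For example, a minimal `data.json` could be generated as follows (a sketch; the command-line pairs and output path are illustrative placeholders):
```
import json

# Each entry pairs a command line with a semantically similar one.
pairs = [
    ["ls -al /tmp", "ls -la /tmp"],
    ["cat /etc/passwd", "less /etc/passwd"],
]

with open("data/train_data/data.json", "w") as f:
    json.dump(pairs, f, indent=2)
```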

#### Automatic Evaluation Split

You can also automatically split your training data into training and evaluation datasets by using the `--train-percentage` argument. Note that this will result in a different evaluation dataset for each training session.
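
For example, the following trains on a random 95% of the pairs and evaluates on the remaining 5% (the percentage is illustrative):
```
python3 train.py \
    --path-to-checkpoint-dir ./checkpoints \
    --path-to-train-data-dir ./data/train_data \
    --path-to-model-weight thenlper/gte-small \
    --train-percentage 0.95
```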

## Checkpoints

During training, the following will be saved in the directory specified by `--path-to-checkpoint-dir`:

- Model weights
- Optimizer state
- Learning rate scheduler state

These files allow you to resume training if needed. Additionally, a `huggingface_model` directory will be created, containing the model weights in the Hugging Face Transformers format.
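
For example, the exported weights should be loadable with `sentence_transformers` (a sketch; the exact location of `huggingface_model` under `--path-to-checkpoint-dir` is an assumption):
```
from sentence_transformers import SentenceTransformer

# Load the exported Transformers-style weights (path is illustrative).
model = SentenceTransformer("./checkpoints/huggingface_model")
embeddings = model.encode(["cat /etc/passwd", "less /etc/passwd"])
print(embeddings.shape)
```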



## Citation
6 changes: 6 additions & 0 deletions requirements.txt
@@ -6,3 +6,9 @@ PyYAML==6.0.2
sentence_transformers==3.1.1
torch==2.5.1
google-generativeai==0.8.3

safetensors==0.5.2
huggingface-hub==0.27.1
transformers==4.48.1
numpy==2.2.2

121 changes: 121 additions & 0 deletions src/arguments.py
@@ -0,0 +1,121 @@
import logging
import pathlib
from typing import List, Union, Optional, Literal
from dataclasses import dataclass, field, fields

@dataclass
class CriterionArguments:
temperature: float = field(
default=0.05,
metadata={
'help': 'The temperature of InfoNCE loss.'
}
)

@dataclass
class DataArguments:
path_to_train_data_dir: str = field(
metadata={
'aliases': '--path-to-train-data-dir',
'required': True,
            'help': 'Path to the training data folder, which should contain a data.json file.'
}
)

path_to_eval_data_dir: Optional[str] = field(
default=None,
metadata={
'aliases': '--path-to-eval-data-dir',
            'help': 'Path to the evaluation data folder, which should contain a data.json file.'
}
)

train_percentage: float = field(
default=1.,
metadata={
'aliases': '--train-percentage',
            'help': 'Fraction of the data used for train_dataset; the remainder becomes eval_dataset.'
}
)

tokenize_on_the_fly: bool = field(
default=False,
metadata={
'aliases': '--tokenize-on-the-fly',
'help': 'Whether to tokenize the sentences in each iteration.'
}
)

def __post_init__(self):
        assert 0 < self.train_percentage <= 1, 'train_percentage should be within the range (0, 1]'

@dataclass
class ModelArguments:
model_max_length: int = field(
default=512,
metadata={
'aliases': ['--max-sequence-len', '--max_sequence_len', '--model-max-length'],
'help': 'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
},
)
    path_to_model_weight: Optional[str] = field(
default=None,
metadata={'aliases': '--path-to-model-weight'}
)
load_from_pretrained: bool = field(default=True, metadata={'aliases': '--load-from-pretrained'})
gradient_checkpointing: bool = field(default=True, metadata={'aliases': '--gradient-checkpointing'})

@dataclass
class TrainingArguments:
path_to_checkpoint_dir: pathlib.Path = field(
metadata={
'aliases': '--path-to-checkpoint-dir',
'required': True
}
)
device: str = field(default="cuda")

lr: float = field(default=0.00002)
epochs: int = field(default=2)

shuffle: bool = field(default=True)
per_device_train_batch_size: int = field(
default=64,
metadata={
'aliases': ['--batch-size', '--batch_size', '--per-device-train-batch-size'],
'help': 'The batch size per GPU/TPU core/CPU for training.'
}
)
per_device_eval_batch_size: int = field(
default=32,
metadata={
'aliases': '--per-device-eval-batch-size',
'help': 'The batch size per GPU/TPU core/CPU for evaluation.'
}
)
log_level: str = field(
default='INFO',
metadata={
'aliases': '--log-level',
'help': f'Set logging level. Choices=[{"|".join(logging._nameToLevel.keys())}]'
}
)
log_interval: int = field(
default=10,
metadata={'aliases': '--log-interval'},
)
eval_interval: int = field(
default=50,
metadata={
'aliases': '--eval-interval',
'help': 'Do evaluation every eval_interval steps if eval_strategy is steps.'
},
)

random_seed: int = field(
default=42,
metadata={'aliases': '--random-seed'}
)

def __post_init__(self):
self.log_level = logging._nameToLevel[self.log_level.upper()]
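
These dataclasses appear designed for `transformers.HfArgumentParser`, which honors the `aliases` and `help` metadata; a sketch of how an entry point might instantiate them (the actual `train.py` wiring may differ, and the import path assumes `src/` is importable as a package):
```
from transformers import HfArgumentParser

from src.arguments import (
    CriterionArguments, DataArguments, ModelArguments, TrainingArguments
)

# Parse CLI flags (including the hyphenated aliases) into the four dataclasses.
parser = HfArgumentParser(
    (CriterionArguments, DataArguments, ModelArguments, TrainingArguments)
)
criterion_args, data_args, model_args, training_args = parser.parse_args_into_dataclasses()
```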
33 changes: 33 additions & 0 deletions src/criterion.py
@@ -0,0 +1,33 @@
import torch
import torch.nn.functional as F

class InfoNCE:
def __init__(self, criterion_args, device="cuda"):
self.device = device

self.temperature = criterion_args.temperature

    def __call__(self, x, auxiliary_data):
        # x interleaves the encoded sentences row-wise: query, positive, and
        # optionally hard negative, repeating every `step_size` rows.
        step_size = 3 if auxiliary_data["has_negative_sample"] else 2
        query_x = x[0::step_size]
        positive_x = x[1::step_size]
        if auxiliary_data["has_negative_sample"]:
            negative_x = x[2::step_size]

        # Similarity between each query and its paired positive: the "correct
        # class" logit for InfoNCE.
        positive_similarity = F.cosine_similarity(query_x, positive_x).unsqueeze(-1)
        # Pairwise query/positive similarities; masking out the diagonal keeps
        # only the in-batch negatives for each query.
        positive_negative_similarity = F.cosine_similarity(
            query_x.unsqueeze(0), positive_x.unsqueeze(1), -1
        )
        label_mask = ~torch.eye(positive_negative_similarity.shape[0], device=self.device, dtype=torch.bool)
        positive_negative_similarity = positive_negative_similarity[label_mask].reshape(query_x.size(0), -1)
if auxiliary_data["has_negative_sample"]:
negative_similarity = F.cosine_similarity(
query_x.unsqueeze(0), negative_x.unsqueeze(1), -1
)
positive_negative_similarity = torch.cat([positive_negative_similarity, negative_similarity], -1)
        all_similarity = torch.cat([positive_similarity, positive_negative_similarity], -1)
        # Column 0 holds the positive logit, so the target label is 0 for every row.
        labels = torch.zeros(all_similarity.size(0), dtype=torch.long, device=self.device)

        # F.cross_entropy already averages over the batch (reduction="mean"),
        # so no extra .mean() is needed.
        loss = F.cross_entropy(all_similarity / self.temperature, labels)
        return loss
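
A minimal smoke test for the loss (illustrative; `SimpleNamespace` stands in for the parsed `CriterionArguments`, and the import path assumes `src/` is importable as a package):
```
import torch
from types import SimpleNamespace

from src.criterion import InfoNCE

criterion = InfoNCE(SimpleNamespace(temperature=0.05), device="cpu")
# Four query/positive pairs, interleaved row-wise: q1, p1, q2, p2, ...
x = torch.randn(8, 384)
loss = criterion(x, {"has_negative_sample": False})
print(loss.item())
```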
171 changes: 171 additions & 0 deletions src/data_processor.py
@@ -0,0 +1,171 @@
import collections
import os
from typing import Dict

import torch
import torch.distributed as dist

from .utils import load_json

class ContrastDataset:
"""
Data format:
```
[
[sentence_1, similar_sentence_1],
[sentence_2, similar_sentence_2],
[sentence_3, similar_sentence_3],
]
```
or
```
[
[sentence_1, similar_sentence_1, hard_negative_sentence_1],
[sentence_2, similar_sentence_2, hard_negative_sentence_2],
[sentence_3, similar_sentence_3, hard_negative_sentence_3],
]
```
"""
def __init__(self, raw_data, tokenizer, device, tokenize_on_the_fly=False):
self.tokenizer = tokenizer

self.raw_data_length = len(raw_data)
self.device = device

self.has_negative_sample = len(raw_data[0]) == 3 if len(raw_data) > 0 else False
self.tokenize_on_the_fly = tokenize_on_the_fly

self.processed_data, self.total_sentences_map = self.preprocess(raw_data)

@classmethod
def initialize_dataset(cls, tokenizer, data_args, device="cuda"):
train_dataset = None
eval_dataset = None

if data_args.train_percentage == 1:
train_dataset = cls(
load_json(os.path.join(data_args.path_to_train_data_dir, "data.json")),
tokenizer, device, data_args.tokenize_on_the_fly
)
if data_args.path_to_eval_data_dir is not None:
eval_dataset = cls(
load_json(os.path.join(data_args.path_to_eval_data_dir, "data.json")),
tokenizer, device, data_args.tokenize_on_the_fly
)
else:
data = load_json(os.path.join(data_args.path_to_train_data_dir, "data.json"))

perm = torch.randperm(len(data)).tolist()
split = int(len(perm) * data_args.train_percentage)
train_indices = perm[:split]
eval_indices = perm[split:]

train_data = [data[i] for i in train_indices]
eval_data = [data[i] for i in eval_indices]

train_dataset = cls(train_data, tokenizer, device, data_args.tokenize_on_the_fly)
eval_dataset = cls(eval_data, tokenizer, device, data_args.tokenize_on_the_fly)

return train_dataset, eval_dataset

def preprocess(self, raw_data):
total_sentences_map = collections.defaultdict(list)

for d in raw_data:
total_sentences_map["query_sentence_list"].append(d[0])
total_sentences_map["positive_sentence_list"].append(d[1])
if self.has_negative_sample:
total_sentences_map["negative_sentence_list"].append(d[2])

total_tokens_map = {}
if not self.tokenize_on_the_fly:
for k in total_sentences_map:
k_tokens = self.tokenizer(
total_sentences_map[k], padding="max_length",
truncation=True, return_tensors="pt"
)

                total_tokens_map[k] = k_tokens
return total_tokens_map, total_sentences_map

def __len__(self):
return self.raw_data_length

def __getitem__(self, idx):
if self.tokenize_on_the_fly:
return {k: self.total_sentences_map[k][idx] for k in self.total_sentences_map}
return [{
"input_ids": self.processed_data[k]["input_ids"][idx],
"attention_mask": self.processed_data[k]["attention_mask"][idx]
} for k in self.processed_data]

def collate_fn(self, batch_pair_data):
"""
Returns:
        {
            "input_ids": torch.tensor([
                [],  # query_sample
                [],  # positive_sample
                [],  # negative_sample (if present)
                [],  # query_sample
                [],  # positive_sample
                [],  # negative_sample (if present)
            ]),
            "attention_mask": torch.tensor([
                [],  # query_sample
                [],  # positive_sample
                [],  # negative_sample (if present)
                [],  # query_sample
                [],  # positive_sample
                [],  # negative_sample (if present)
            ]),
        }
"""
if self.tokenize_on_the_fly:
flatten_sentence_list = []
for data in batch_pair_data:
for k in data:
flatten_sentence_list.append(data[k])
merged_batch_tokens = self.tokenizer(
flatten_sentence_list, padding=True, max_length=self.tokenizer.model_max_length,
truncation=True, return_tensors="pt"
)

merged_batch_tokens = {
"input_ids": merged_batch_tokens["input_ids"],
"attention_mask": merged_batch_tokens["attention_mask"]
}
else:
flatten_batch_pair_data = []
for pd in batch_pair_data:
flatten_batch_pair_data.extend(pd)
merged_batch_tokens = dict(
input_ids=torch.stack([d["input_ids"] for d in flatten_batch_pair_data], 0),
attention_mask=torch.stack([d["attention_mask"] for d in flatten_batch_pair_data], 0),
)

merged_batch_tokens = self.truncate_redundant_tokens(merged_batch_tokens)
return merged_batch_tokens, {"has_negative_sample": self.has_negative_sample}

    def truncate_redundant_tokens(self, batch_tokens: Dict[str, torch.Tensor]):
if dist.is_initialized() and dist.get_world_size() > 1:
            # With tensor parallelism, the sequence lengths must match on every
            # process, so we all-reduce to take the max length across processes.
max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1)).to(self.device)
dist.all_reduce(
max_non_zero_index,
op=torch.distributed.ReduceOp.MAX
)
max_non_zero_index = max_non_zero_index.cpu()
else:
max_non_zero_index = torch.max(torch.sum(batch_tokens["attention_mask"], 1))

        # Pad the sequence length up to a multiple of 4 for FlashAttention compatibility.
max_non_zero_index += 4 - max_non_zero_index % 4

for k, v in batch_tokens.items():
v = v[:, :max_non_zero_index]
batch_tokens[k] = v.to(self.device)
return batch_tokens
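
A hedged sketch of wiring the dataset into a `DataLoader` (the command-line pairs are illustrative, and the import path assumes `src/` is importable as a package):
```
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.data_processor import ContrastDataset

tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small", model_max_length=512)
raw_data = [
    ["ls -al /tmp", "ls -la /tmp"],
    ["whoami", "id -un"],
]
dataset = ContrastDataset(raw_data, tokenizer, device="cpu", tokenize_on_the_fly=True)
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

batch_tokens, auxiliary_data = next(iter(loader))
print(batch_tokens["input_ids"].shape)  # query and positive rows interleaved
```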
