diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 5d558c2b451f..306484331fb7 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -36,6 +36,7 @@ from nemo.collections.llm.gpt.data import ( AlpacaDataModule, ChatDataModule, + CustomRetrievalDataModule, DollyDataModule, FineTuningDataModule, HFDatasetDataModule, @@ -91,7 +92,9 @@ Llama31Config405B, Llama32Config1B, Llama32Config3B, + Llama32EmbeddingConfig1B, LlamaConfig, + LlamaEmbeddingModel, LlamaModel, MaskedTokenLossReduction, MistralConfig7B, @@ -150,6 +153,7 @@ __all__ = [ "MockDataModule", "T5MockDataModule", + "CustomRetrievalDataModule", "GPTModel", "GPTConfig", "gpt_data_step", @@ -185,6 +189,8 @@ "Nemotron4Config15B", "Nemotron4Config340B", "NemotronConfig", + "LlamaEmbeddingModel", + "Llama32EmbeddingConfig1B", "Phi3Config", "Phi3ConfigMini", "Phi3Model", diff --git a/nemo/collections/llm/bert/loss.py b/nemo/collections/llm/bert/loss.py index 6fd34a4d3fa3..3bbbdfbd8e49 100644 --- a/nemo/collections/llm/bert/loss.py +++ b/nemo/collections/llm/bert/loss.py @@ -99,6 +99,89 @@ def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: return torch.tensor(0.0, device=torch.cuda.current_device()) +class HardNegativeRankingLoss(MegatronLossReduction): + """ + This loss uses hard-negative samples. + The difference of this loss to the default MultipleNegativesRankingLoss + from Sentence Transformers is that the latter shares the hard negatives + as negatives for all examples, whereas this loss uses hard negatives + exclusively for the example they are associated. + """ + + def __init__( + self, + validation_step: bool = False, + val_drop_last: bool = True, + num_hard_negatives: int = 1, + scale: float = 50, + label_smoothing: float = 0.0, + ) -> None: + super().__init__() + self.validation_step = validation_step + self.val_drop_last = val_drop_last + self.num_hard_negatives = num_hard_negatives + self.scale = scale + self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing) + + def forward( + self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + from megatron.core import parallel_state + + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size != 1: + raise NotImplementedError(f'CP is not supported for {self.__class__} yet.') + + num_tensors_per_example = 2 + self.num_hard_negatives # 1 query, 1 pos, num_hard_negatives negs + current_train_n_passages = 1 + self.num_hard_negatives + batch_size = forward_out.shape[0] // num_tensors_per_example + # Get Query, Key (Positives, Negatives) + # forward_out was chunked [(q1, k1), (q2, k2), ...] 
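+        # Illustrative layout (assumed): with num_hard_negatives=1 and batch_size=2, forward_out
+        # rows are [q1, pos1, neg1, q2, pos2, neg2]; chunk(batch_size) yields
+        # [(q1, pos1, neg1), (q2, pos2, neg2)], so item[0] is the query embedding and
+        # item[1:] are the passages scored against it (positive first, matching labels=0 below).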
+ chunks = forward_out.chunk(batch_size) + query = torch.stack([item[0] for item in chunks]) + key = torch.cat([item[1:] for item in chunks]) + + assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0]) + assert key.shape[0] / query.shape[0] == current_train_n_passages, '{} / {} != {}'.format( + key.shape[0], query.shape[0], current_train_n_passages + ) + query_shape = query.shape + repeated_query = query.repeat(1, 1, current_train_n_passages).reshape( + query_shape[0] * current_train_n_passages, query_shape[1] + ) + scores = torch.sum(repeated_query * key, dim=-1).reshape(query_shape[0], current_train_n_passages) + labels = torch.zeros(query_shape[0], dtype=torch.long, device=query.device) + scores *= self.scale + ce_loss = self.cross_entropy_loss(scores, labels) + reduced_loss = average_losses_across_data_parallel_group([ce_loss]) + return ce_loss, {"avg": reduced_loss} + + def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: + """Taken from: https://github.com/NVIDIA/NeMo/blob/main + /nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" + if losses_reduced_per_micro_batch: + if "avg" in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + + return loss_tensor.mean() + + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list: List[torch.Tensor] = [ + loss_sum["loss_sum_and_ub_size"] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum["loss_sum_and_ub_size"][1] > 0 + ] + loss_sum = ( + torch.vstack(loss_sum_tensors_list).sum(dim=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) + ) + return loss_sum + + return torch.tensor(0.0, device=torch.cuda.current_device()) + + class BERTInBatchExclusiveHardNegativesRankingLoss(MegatronLossReduction): """ This loss uses in-batch negative samples + hard-negative samples. diff --git a/nemo/collections/llm/gpt/data/__init__.py b/nemo/collections/llm/gpt/data/__init__.py index 89b5a3dc4b54..fd8935d9c11a 100644 --- a/nemo/collections/llm/gpt/data/__init__.py +++ b/nemo/collections/llm/gpt/data/__init__.py @@ -19,6 +19,7 @@ from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule +from nemo.collections.llm.gpt.data.retrieval import CustomRetrievalDataModule from nemo.collections.llm.gpt.data.squad import SquadDataModule __all__ = [ @@ -31,4 +32,5 @@ "PreTrainingDataModule", "build_pretraining_datamodule", "SquadDataModule", + "CustomRetrievalDataModule", ] diff --git a/nemo/collections/llm/gpt/data/retrieval.py b/nemo/collections/llm/gpt/data/retrieval.py new file mode 100644 index 000000000000..058068e811e0 --- /dev/null +++ b/nemo/collections/llm/gpt/data/retrieval.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os.path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from datasets import Dataset
+
+from nemo.collections.llm.bert.data.fine_tuning import FineTuningDataModule
+from nemo.collections.llm.gpt.data.core import get_dataset_root
+from nemo.utils import logging
+
+if TYPE_CHECKING:
+    from nemo.collections.common.tokenizers import TokenizerSpec
+    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
+
+
+# Custom retrieval data module loaded from a JSON file
+class CustomRetrievalDataModule(FineTuningDataModule):
+    """Fine-tuning data module for retrieval/embedding training from a custom JSON file of
+    (query, positive document, negative documents) records.
+    """
+
+    def __init__(
+        self,
+        data_root: str,
+        dataset_identifier: str = "custom_retrieval_dataset",
+        seq_length: int = 2048,
+        tokenizer: Optional["TokenizerSpec"] = None,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        rampup_batch_size: Optional[List[int]] = None,
+        force_redownload: bool = False,
+        delete_raw: bool = True,
+        seed: int = 1234,
+        memmap_workers: int = 1,
+        num_workers: int = 8,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
+        query_key: str = "question",
+        pos_doc_key: str = "pos_doc",
+        neg_doc_key: str = "neg_doc",
+        dataset_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        self.force_redownload = force_redownload
+        self.delete_raw = delete_raw
+
+        assert packed_sequence_specs is None, "RetrievalDataModule does not support packed sequences."
+        assert os.path.exists(data_root), "Data root does not exist."
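+        # Expected input (based on the default keys, shown here for illustration): `data_root`
+        # points to a JSON file containing a list of records such as
+        #   {"question": "...", "pos_doc": ["..."], "neg_doc": ["...", "..."]},
+        # which _preprocess_and_split_data() converts into train/validation/test .jsonl splits.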
+        self.query_key = query_key
+        self.pos_doc_key = pos_doc_key
+        self.neg_doc_key = neg_doc_key
+        self.unprocessed_root = data_root
+        super().__init__(
+            dataset_root=get_dataset_root(dataset_identifier),
+            seq_length=seq_length,
+            tokenizer=tokenizer,
+            micro_batch_size=micro_batch_size,
+            global_batch_size=global_batch_size,
+            rampup_batch_size=rampup_batch_size,
+            seed=seed,
+            memmap_workers=memmap_workers,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            persistent_workers=persistent_workers,
+            dataset_kwargs=dataset_kwargs,
+        )
+
+    def prepare_data(self) -> None:
+        """Prepare the data if it has not been split already."""
+        if not self.train_path.exists() or self.force_redownload:
+            self._preprocess_and_split_data()
+        super().prepare_data()
+
+    def _preprocess_and_split_data(self, train_ratio: float = 0.95, val_ratio: float = 0.04):
+        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
+
+        test_ratio = 1 - train_ratio - val_ratio
+        save_splits = {}
+        dataset = Dataset.from_list(json.load(open(self.unprocessed_root, 'r')))
+        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
+        split_dataset2 = split_dataset['test'].train_test_split(
+            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
+        )
+        save_splits['training'] = split_dataset['train']
+        save_splits['validation'] = split_dataset2['train']
+        save_splits['test'] = split_dataset2['test']
+
+        for split_name, dataset in save_splits.items():
+            output_file = self.dataset_root / f"{split_name}.jsonl"
+            with output_file.open("w", encoding="utf-8") as f:
+                for o in dataset:
+                    # We only write one positive document for now;
+                    # all negative documents are written.
+                    pos_doc = o[self.pos_doc_key][0] if isinstance(o[self.pos_doc_key], list) else o[self.pos_doc_key]
+                    neg_doc = o[self.neg_doc_key] if isinstance(o[self.neg_doc_key], list) else [o[self.neg_doc_key]]
+                    f.write(json.dumps({"query": o[self.query_key], "pos_doc": pos_doc, "neg_doc": neg_doc}) + "\n")
+
+            logging.info(f"{split_name} split saved to {output_file}")
diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py
index 4e9448eaef2c..d9ab48e0ea51 100644
--- a/nemo/collections/llm/gpt/model/__init__.py
+++ b/nemo/collections/llm/gpt/model/__init__.py
@@ -64,6 +64,7 @@
     LlamaConfig,
     LlamaModel,
 )
+from nemo.collections.llm.gpt.model.llama_embedding import Llama32EmbeddingConfig1B, LlamaEmbeddingModel
 from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
 from nemo.collections.llm.gpt.model.mixtral import (
     MixtralConfig,
@@ -145,6 +146,8 @@
     "Nemotron3Config22B",
     "Nemotron4Config340B",
     "NemotronModel",
+    "LlamaEmbeddingModel",
+    "Llama32EmbeddingConfig1B",
     "Phi3Config",
     "Phi3ConfigMini",
     "Phi3Model",
diff --git a/nemo/collections/llm/gpt/model/hf_llama_embedding.py b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
new file mode 100644
index 000000000000..ba89626ff45f
--- /dev/null
+++ b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import SequenceClassifierOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForSequenceClassification, LlamaModel +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor: + """Pooling on last_hidden_states without pad tokens.""" + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + + if pool_type == "avg": + emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pool_type == "weighted_avg": + emb = last_hidden.sum(dim=1) + elif pool_type == "cls": + emb = last_hidden[:, 0] + elif pool_type == "last": + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + emb = last_hidden[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden.shape[0] + emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] + else: + raise ValueError(f"pool_type {pool_type} not supported") + + return emb + + +class LlamaBidirectionalConfig(LlamaConfig): + """LLamaBidirectionalConfig for LlamaBidirectionalModel.""" + + model_type = "llama_bidirec" + + def __init__( + self, + pooling="avg", + temperature=1.0, + **kwargs, + ): + self.pooling = pooling + self.temperature = temperature + super().__init__( + **kwargs, + ) + + +class LlamaBidirectionalModel(LlamaModel): + """LlamaBidirectionalModel. + Attention has been adjusted to bidirectional. + """ + + config_class = LlamaBidirectionalConfig + + def __init__(self, config: LlamaConfig): + super().__init__(config) + for layer in self.layers: + layer.self_attn.is_causal = False + self.config._attn_implementation = "eager" + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # Generates bi-directional attention. 
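+        # `_prepare_4d_attention_mask` only expands the [bsz, seq_len] padding mask to
+        # [bsz, 1, tgt_len, src_len]; unlike the parent implementation it adds no causal
+        # (lower-triangular) constraint, so every non-pad token attends to all tokens.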
+        causal_mask = _prepare_4d_attention_mask(attention_mask, input_tensor.dtype)
+        return causal_mask
+
+
+class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
+    """The LLaMa Model transformer with a sequence classification head on top (linear layer)."""
+
+    config_class = LlamaBidirectionalConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        # Release the parameters of the LlamaModel
+        # created by the parent LlamaForSequenceClassification
+        del self.model
+
+        self.model = LlamaBidirectionalModel(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        pooled_hidden_states = pool(
+            last_hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            pool_type=self.config.pooling,
+        )
+
+        pooled_logits = self.score(pooled_hidden_states)
+        pooled_logits = pooled_logits / self.config.temperature
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(pooled_logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
diff --git
a/nemo/collections/llm/gpt/model/llama_embedding.py b/nemo/collections/llm/gpt/model/llama_embedding.py new file mode 100644 index 000000000000..3d8edcc5121a --- /dev/null +++ b/nemo/collections/llm/gpt/model/llama_embedding.py @@ -0,0 +1,401 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Dict, Literal, Optional, Union + +import einops +import lightning.pytorch as L +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.spec_utils import ModuleSpec +from torch import Tensor, nn + +import nemo.collections.llm.gpt.model.base as GPTBase +from nemo.collections.llm.bert.loss import BERTInBatchExclusiveHardNegativesRankingLoss, HardNegativeRankingLoss +from nemo.collections.llm.gpt.model import GPTConfig +from nemo.collections.llm.gpt.model.llama import HFLlamaImporter, Llama32Config1B, LlamaConfig, LlamaModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io +from nemo.lightning.pytorch.utils import dtype_from_hf +from nemo.utils.import_utils import safe_import + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel + + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +_, HAVE_TE = safe_import("transformer_engine") + + +def _local_layer_spec(config: "GPTConfig") -> ModuleSpec: + gpt_layer_spec = GPTBase.local_layer_spec(config) + gpt_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding + return gpt_layer_spec + + +def _transformer_engine_layer_spec(config: "GPTConfig") -> ModuleSpec: + gpt_layer_spec = GPTBase.transformer_engine_layer_spec(config) + gpt_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding + return gpt_layer_spec + + +def get_nv_embedding_layer_spec(config): + """Customized Layer Spec for NV Embedding Llama Model. + Bidirectional attention is enabled instead of causal masking. 
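+    Uses the Transformer Engine layer spec when TE is available and falls back to the local
+    (non-TE) spec otherwise; in both cases the self-attention mask type is switched from
+    causal to padding.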
+    """
+    if HAVE_TE:
+        return _transformer_engine_layer_spec(config)
+    else:
+        return _local_layer_spec(config)
+
+
+def nv_embedding_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
+    """Set up the NV Embedding Llama model dataloader batch."""
+    batch = next(dataloader_iter)
+
+    _batch: dict
+    if isinstance(batch, tuple) and len(batch) == 3:
+        _batch = batch[0]
+    else:
+        _batch = batch
+
+    required_keys = set()
+    required_keys.add("attention_mask")
+
+    if parallel_state.is_pipeline_first_stage():
+        required_keys.add("input_ids")
+        required_keys.add("position_ids")
+
+    _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
+    # slice batch along sequence dimension for context parallelism
+    output = GPTBase.get_batch_on_this_context_parallel_rank(_batch)
+
+    return output
+
+
+def nv_embedding_forward_step(model: L.LightningModule, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Subsets the batch to the keys actually used by the model's forward pass (input_ids,
+    attention_mask, position_ids) and returns the pooled, normalized embeddings from
+    ``model.encode``.
+    """
+    forward_args = {
+        "input_ids": batch["input_ids"],
+        "attention_mask": batch["attention_mask"],
+        "position_ids": batch["position_ids"],
+    }
+    emb = model.encode(**forward_args)
+    return emb
+
+
+@dataclass
+class Llama32EmbeddingConfig1B(Llama32Config1B):
+    """Llama3.2 Embedding 1B Config"""
+
+    transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = get_nv_embedding_layer_spec
+    forward_step_fn: Callable = nv_embedding_forward_step
+    data_step_fn: Callable = nv_embedding_data_step
+
+    # Training Configs
+    truncation_method: Literal["left", "right"] = 'right'
+    num_hard_negatives: int = 4
+    ce_loss_scale: float = 50
+    label_smoothing: float = 0.0
+    in_batch_negatives: bool = False
+    negative_sample_strategy: Literal["random", "first"] = 'first'
+    add_bos: bool = True
+    add_eos: bool = False
+
+    def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel":
+        """Configure the NV Embedding Llama3.2 1B Model"""
+        model = super().configure_model(tokenizer, pre_process, post_process)
+        # post_process needs to be overwritten to False after model init because
+        # final_layernorm is still needed and it is only initialized when post_process is True in MCore.
+        # For forward(), we do not want to run through the output_layer, so post_process is set to False.
+ model.post_process = False + return model + + +def _average_pool(last_hidden_states: Tensor, attention_mask: Tensor): + """Average the hidden states on the non-masking tokens.""" + # [sq, b, h] -> [b, sq, h] + last_hidden_states = einops.rearrange(last_hidden_states, 's b h -> b s h') + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +class LlamaEmbeddingModel(LlamaModel): + """NV Embedding Llama Model""" + + def __init__( + self, + config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + + @property + def dataset_kwargs(self): + """Getter for dataset_kwargs from model config""" + return { + 'num_hard_negatives': self.config.num_hard_negatives, + 'negative_sample_strategy': self.config.negative_sample_strategy, + 'add_bos': self.config.add_bos, + 'add_eos': self.config.add_eos, + } + + def encode( + self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + decoder_input: Optional[torch.Tensor] = None, + ): + """Generate the embedding for the inputs. + It runs the forward and apply average pooling on the last hidden states of the model. + """ + if attention_mask.ndim == 2: + # extend attention mask to [b, 1, 1, sq] + # Also convert attention mask to binary + extended_mask = attention_mask.unsqueeze(1).unsqueeze(1) < 0.5 + elif attention_mask.ndim == 4: + assert attention_mask.shape[1] == 1 and attention_mask.shape[2] == 1, "Attention mask shape incorrect" + extended_mask = attention_mask + # Squeeze attention mask to [b, sq] for averaging pooling later + + attention_mask = extended_mask.squeeze() < 0.5 + else: + raise ValueError("Attention_mask shape incorrect") + + output = self.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=extended_mask, + decoder_input=decoder_input, + ) + embeddings = _average_pool(output, attention_mask) + embeddings = F.normalize(embeddings, p=2, dim=1) + return embeddings + + @property + def training_loss_reduction(self) -> BERTInBatchExclusiveHardNegativesRankingLoss: # pylint: disable=C0115,C0116 + if not self._training_loss_reduction: + if self.config.in_batch_negatives: + loss_func = BERTInBatchExclusiveHardNegativesRankingLoss + else: + loss_func = HardNegativeRankingLoss + self._training_loss_reduction = loss_func( + validation_step=False, + num_hard_negatives=self.config.num_hard_negatives, + scale=self.config.ce_loss_scale, + label_smoothing=self.config.label_smoothing, + ) + + return self._training_loss_reduction + + @property + def validation_loss_reduction(self) -> BERTInBatchExclusiveHardNegativesRankingLoss: # pylint: disable=C0115,C0116 + if not self._validation_loss_reduction: + if self.config.in_batch_negatives: + loss_func = BERTInBatchExclusiveHardNegativesRankingLoss + else: + loss_func = HardNegativeRankingLoss + self._validation_loss_reduction = loss_func( + validation_step=True, + num_hard_negatives=self.config.num_hard_negatives, + scale=self.config.ce_loss_scale, + label_smoothing=self.config.label_smoothing, + ) + + return self._validation_loss_reduction + + +@io.model_importer(LlamaEmbeddingModel, "hf") +class LlamaEmbeddingImporter(HFLlamaImporter): + """HF 
Importer for Llama Embedding Model""" + + def init(self) -> LlamaEmbeddingModel: + return LlamaEmbeddingModel(self.config, tokenizer=self.tokenizer) + + @property + def config(self) -> Llama32Config1B: + # pylint : disable=C0116 + from transformers import LlamaConfig as HFLlamaConfig + + source = HFLlamaConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = Llama32EmbeddingConfig1B( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=getattr(source, "tie_word_embeddings", False), + fp16=(dtype_from_hf(source) == torch.float16), + bf16=(dtype_from_hf(source) == torch.bfloat16), + params_dtype=dtype_from_hf(source), + ) + + return output + + +@io.model_exporter(LlamaEmbeddingModel, "hf") +class LlamaEmbeddingExporter(io.ModelConnector[LlamaEmbeddingModel, "LlamaBidirectionalModel"]): + """HF Exporter for NV Embedding Llama Model. + Note that NV Embedding LLama uses customized LlamaBidirectionalConfig config. + """ + + def init(self, dtype=torch.bfloat16) -> "LlamaForCausalLM": + from transformers.modeling_utils import no_init_weights + + from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalModel + + LlamaBidirectionalModel.register_for_auto_class("AutoModel") + with no_init_weights(True): + return LlamaBidirectionalModel._from_config(self.config, torch_dtype=dtype) + + def apply(self, output_path: Path) -> Path: + source, _ = self.nemo_load(str(self)) + source_dtype = source.module.embedding.word_embeddings.weight.dtype + target = self.init(source_dtype) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + try: + tokenizer = self.tokenizer.tokenizer + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = source.config.truncation_method + + tokenizer.save_pretrained(output_path) + except Exception: + logging.warning("Failed to save tokenizer") + + return output_path + + @property + def config(self): + """Get HF NV Embedding Llama Config.""" + source: LlamaConfig = io.load_context(str(self), subpath="model.config") + + from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalConfig + + LlamaBidirectionalConfig.register_for_auto_class("AutoConfig") + return LlamaBidirectionalConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + tie_word_embeddings=source.share_embeddings_and_output_weights, + ) + + def convert_state(self, source, target): + """Convert NeMo State dict to HF.""" + mapping = { + "decoder.layers.*.self_attention.linear_proj.weight": "layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": 
"layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "norm.weight", + } + transforms = [_export_qkv, _export_linear_fc1, _export_embedding] + + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=transforms, + ) + + @property + def tokenizer(self) -> "TokenizerSpec": + """Get NeMo Tokenizer""" + return io.load_context(str(self), subpath="model").tokenizer + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "layers.*.self_attn.q_proj.weight", + "layers.*.self_attn.k_proj.weight", + "layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key="embedding.word_embeddings.weight", + target_key="embed_tokens.weight", +) +def _export_embedding(ctx: io.TransformCTX, embedding): + megatron_config = ctx.target.config + # prune padding. + return embedding[: megatron_config.vocab_size, :] + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("layers.*.mlp.gate_proj.weight", "layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 09291e4165be..0892bb10f16b 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -38,6 +38,7 @@ llama31_405b, llama32_1b, llama32_3b, + llama_embedding_1b, mamba2_1_3b, mamba2_2_7b, mamba2_8b, diff --git a/nemo/collections/llm/recipes/llama_embedding_1b.py b/nemo/collections/llm/recipes/llama_embedding_1b.py new file mode 100644 index 000000000000..4a26fcc563d3 --- /dev/null +++ b/nemo/collections/llm/recipes/llama_embedding_1b.py @@ -0,0 +1,286 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm import Llama32EmbeddingConfig1B, LlamaEmbeddingModel +from nemo.collections.llm.api import finetune +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "nvembed_llama_1b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a NVEmbed Llama3.2 1B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the NVEmbed Llama3.2 1B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=nvembed_llama_1b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(LlamaEmbeddingModel, config=run.Config(Llama32EmbeddingConfig1B)) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for NVEmbed Llama3.2 1B model. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=nvembed_llama_1b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. 
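+        The defaults (tensor/pipeline parallelism of 1, context parallelism of 2, one node with
+        8 GPUs) target a single-node run; adjust the parallelism arguments when scaling out.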
+ """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + resume_path: str = "meta-llama/Llama-3.2-1B", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + micro_batch_size: int = 4, + global_batch_size: int = 64, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, +) -> run.Partial: + """ + Create a fine-tuning recipe for NVEmbed Llama3.2 1B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + resume_path (str): Path to the Huggingface model or pretrained distributed checkpoint for resume + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + micro_batch_size (int): Size of micro batch. + global_batch_size (int): Size of global batch. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. pack sequence is not supported for embedding model training. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory nvembed_llama_1b + + Python API usage: + >>> recipe = finetune_recipe(name="nvembed_llama_1b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SPECTER dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + if seq_length is None: + seq_length = 512 + + assert packed_sequence is None, 'pack_sequence is not supported for Embedding model finetuning.' 
+ recipe = default_finetune_recipe(model(), resume_path, dir, name, num_nodes, num_gpus_per_node, packed_sequence) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + # Use Specter Dataset as the default for finetuning + recipe.data = run.Config( + llm.SpecterDataModule, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + dataset_kwargs={ + 'num_hard_negatives': recipe.model.config.num_hard_negatives, + 'negative_sample_strategy': recipe.model.config.negative_sample_strategy, + 'add_bos': recipe.model.config.add_bos, + 'add_eos': recipe.model.config.add_eos, + }, + ) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + else: + recipe.peft.target_modules = ['linear_qkv'] + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py b/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py index 8bca618dce3d..0da7af6ed96d 100644 --- a/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py +++ b/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. 
from random import choices, sample -from typing import Mapping, Optional +from typing import Literal, Mapping, Optional import datasets import numpy as np @@ -32,6 +32,10 @@ class BertEmbeddingDataset(Dataset): + """ + Embedding Dataset Class. + """ + def __init__( self, file_path: str, @@ -49,19 +53,28 @@ def __init__( special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} data_type: str = 'train', # train, query or doc num_hard_negatives: int = 4, + negative_sample_strategy: Literal["random", "first"] = 'first', ): """ file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. - tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). - max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. - min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec + (ex: YTTM, SentencePiece). + max_seq_length (int): maximum sequence length for each dataset examples. + Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. + Data examples will be dropped if they do not meet the min length requirements. add_bos (bool): Whether to add a beginning of sentence token to each data example add_eos (bool): Whether to add an end of sentence token to each data example seed: Random seed for data shuffling. - max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. - index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + max_num_samples: Maximum number of samples to load. This can be > dataset length + if you want to oversample data. If None, all samples will be loaded. + index_mapping_dir: Directory to save the index mapping to. + If None, will write to the same folder as the dataset. truncation_method: Truncation from which position. Options: ['left', 'right'] - special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. + Default: {'system_turn_start': '', 'turn_start': '', + 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + negative_sample_strategy: Strategy for negative samples. 
Options: ['random', 'first'] """ # TODO: lot of copy-paste from GPTSFDDataset, should refactor both to use a common base class (@adithyare) self.tokenizer = tokenizer @@ -75,6 +88,14 @@ def __init__( self.index_mapping_dir = index_mapping_dir self.virtual_tokens = virtual_tokens self.truncation_method = truncation_method + self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id else self.tokenizer.eos_id + self.negative_sample_strategy = negative_sample_strategy + assert ( + truncation_method == 'left' or truncation_method == 'right' + ), 'truncation_method must be either "left" or "right"' + assert ( + negative_sample_strategy == 'random' or negative_sample_strategy == 'first' + ), 'negative_sample_strategy must be either "random" or "first"' if special_tokens is None: self.special_tokens = { "system_turn_start": "", @@ -98,6 +119,13 @@ def __init__( # Will be None after this call if `max_num_samples` is None self.samples_mapping = None self._build_samples_mapping() + logging.info( + f"Creating EmbeddingDataset with seed={self.seed},\n" + f"add_bos={self.add_bos}, add_eos={self.add_eos},\n" + f"max_seq_length={self.max_seq_length}, min_seq_length={self.min_seq_length},\n" + f"pad_token_id={self.pad_token_id}, negative_sample_strategy={self.negative_sample_strategy},\n" + f"num_hard_negatives={self.num_hard_negatives}." + ) def _build_samples_mapping(self): if self.max_num_samples is not None: @@ -169,8 +197,13 @@ def _process_example(self, example): # sample rest with replacement nd = nd + choices(example['neg_doc'], k=self.num_hard_negatives - len(example['neg_doc'])) else: - # sample without replacement - nd = sample(example['neg_doc'], k=self.num_hard_negatives) + if self.negative_sample_strategy == 'random': + # sample without replacement + # Choose the first self.num_hard_negatives + nd = sample(example['neg_doc'], k=self.num_hard_negatives) + else: + # Choose the first self.num_hard_negatives samples + nd = example['neg_doc'][: self.num_hard_negatives] assert len(nd) == self.num_hard_negatives, "Error in sampling required number of hard negatives" nd = [self.tokenizer.text_to_ids("passage: " + ex.strip()) for ex in nd] @@ -228,27 +261,17 @@ def _maybe_cast_to_list(self, x): def _ceil_to_nearest(self, n, m): return (n + m - 1) // m * m - def _collate_item(self, item, max_length, pad_id): + def _collate_item(self, item, max_length): item = self._maybe_cast_to_list(item) - # max_length = max([len(x) for x in item]) if item else 0 - # here [0] should be tokenizer.pad_id - item = [x + [pad_id] * (max_length - len(x)) for x in item] + pad_id = self.pad_token_id + if self.truncation_method == 'left': + item = [[pad_id] * (max_length - len(x)) + x for x in item] + else: + item = [x + [pad_id] * (max_length - len(x)) for x in item] return item @torch.no_grad() - def _create_attention_mask(self, max_length): - """Create `attention_mask`. - Args: - input_ids: A 1D tensor that holds the indices of tokens. - """ - # seq_length = len(input_ids) - # `attention_mask` has the shape of [1, seq_length, seq_length] - attention_mask = torch.tril(torch.ones((max_length, max_length))).unsqueeze(0) - attention_mask = attention_mask < 0.5 - return attention_mask - - @torch.no_grad() - def _create_attention_mask2(self, max_length, item_lengh): + def _create_attention_mask2(self, max_length, item_length): """Create `attention_mask`. Args: input_ids: A 1D tensor that holds the indices of tokens. 
@@ -256,10 +279,20 @@ def _create_attention_mask2(self, max_length, item_lengh): # seq_length = len(input_ids) # `attention_mask` has the shape of [1, seq_length, seq_length] attention_mask = torch.zeros(max_length) - attention_mask[:item_lengh] = 1 + if self.truncation_method == 'left': + # input ids: [pad] [pad] token token | + # attention mask: 0 0 1 1 + attention_mask[max_length - item_length :] = 1 + else: + # input ids: token token [pad] [pad] | + # attention mask: 1 1 0 0 + attention_mask[:item_length] = 1 return attention_mask - def collate_fn(self, batch): + def _collate_fn(self, batch): + """ + Collate query passage together + """ input_ids = [] metadata = [] lengths = [] @@ -295,7 +328,7 @@ def collate_fn(self, batch): attention_mask = torch.stack(attention_mask) position_ids = [list(range(max_length)) for _ in batch] position_ids = torch.LongTensor(position_ids) - input_ids = torch.LongTensor(self._collate_item(input_ids, max_length=max_length, pad_id=0)) + input_ids = torch.LongTensor(self._collate_item(input_ids, max_length=max_length)) lengths = torch.LongTensor(lengths) - 1 # subtract 1 to account for the eos token processed_batch = { @@ -303,6 +336,7 @@ def collate_fn(self, batch): 'token_type_ids': torch.zeros_like(input_ids), 'attention_mask': attention_mask, 'metadata': metadata, + 'position_ids': position_ids, } return processed_batch
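
Example usage (editor's sketch, not part of the diff): how the pieces added above fit together, assuming a local `retrieval_data.json` file; the data path and run name are hypothetical.

    import nemo_run as run
    from nemo.collections import llm
    from nemo.collections.llm.recipes import llama_embedding_1b

    # Embedding fine-tuning recipe added in llama_embedding_1b.py (LoRA by default).
    recipe = llama_embedding_1b.finetune_recipe(
        name="nvembed_llama_1b_finetune",
        resume_path="meta-llama/Llama-3.2-1B",
        num_nodes=1,
        num_gpus_per_node=8,
        peft_scheme="lora",
    )

    # Swap the default SPECTER data for a custom (query, pos_doc, neg_doc) JSON file.
    recipe.data = run.Config(
        llm.CustomRetrievalDataModule,
        data_root="retrieval_data.json",  # hypothetical path
        seq_length=512,
        micro_batch_size=4,
        global_batch_size=64,
        dataset_kwargs={"num_hard_negatives": recipe.model.config.num_hard_negatives},
    )

    run.run(recipe)

As in the SPECTER branch of `finetune_recipe`, `dataset_kwargs` should stay consistent with the model config (hard negatives, BOS/EOS handling) so the ranking loss sees the batch layout it expects.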