docstrings added
Signed-off-by: Onur Yilmaz <[email protected]>
oyilmaz-nvidia committed Feb 21, 2025
1 parent fa68758 commit 78453ea
Showing 3 changed files with 76 additions and 21 deletions.
11 changes: 11 additions & 0 deletions nemo/collections/llm/gpt/model/hf_llama_embedding.py
@@ -16,6 +16,7 @@
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoModel, AutoTokenizer
@@ -208,6 +209,8 @@ def __init__(

@property
def device(self) -> torch.device:
"""Returns the device"""

return self.model.device

def forward(
@@ -217,6 +220,8 @@
token_type_ids: Optional[torch.Tensor] = None,
dimensions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Inference for the adapted Llama model"""

inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
@@ -251,11 +256,15 @@


class Pooling(torch.nn.Module):
"""Pooling layer for the adapter."""

def __init__(self, pooling_mode: str):
super().__init__()
self.pooling_mode = pooling_mode

def forward(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Forward function of the Pooling layer."""

last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

pool_type = self.pooling_mode
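For reference, a minimal editorial sketch of what the "avg" branch of this pooling layer typically computes (a masked mean over the sequence dimension, reusing the masked hidden states built above); the actual branch handling is truncated in this view:

import torch

def masked_mean_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padded positions so they do not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Divide each sequence's sum by its number of real tokens.
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]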
@@ -288,6 +297,8 @@ def get_llama_bidirectional_hf_model(
torch_dtype: Optional[Union[torch.dtype, str]] = None,
trust_remote_code: bool = False,
):
"""Returns the adapter for the Llama bidirectional HF model."""

# check that the tokenizer matches the requirements of the pooling mode
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
pooling_mode = pooling_mode or "avg"
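A hedged usage sketch for the loader above; the model_name_or_path argument and the returned (model, tokenizer) pair are assumptions based on the signature fragment shown, not confirmed by this diff:

import torch
from nemo.collections.llm.gpt.model.hf_llama_embedding import get_llama_bidirectional_hf_model

# Hypothetical call; the checkpoint path is a placeholder.
model, tokenizer = get_llama_bidirectional_hf_model(
    model_name_or_path="/path/to/llama_checkpoint",
    torch_dtype=torch.float16,
    trust_remote_code=False,
)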
85 changes: 64 additions & 21 deletions nemo/export/onnx_llm_exporter.py
@@ -60,32 +60,43 @@ def wrapper(*args, **kwargs):
# pylint: disable=line-too-long
class OnnxLLMExporter(ITritonDeployable):
"""
Exports nemo checkpoints to TensorRT-LLM and run fast inference.
Exports models to ONNX and runs fast inference.
Example:
from nemo.export.onnx_llm_exporter import OnnxLLMExporter
trt_llm_exporter = OnnxLLMExporter(model_dir="/path/for/model/files")
trt_llm_exporter.export(
nemo_checkpoint_path="/path/for/nemo/checkpoint",
onnx_llm_exporter = OnnxLLMExporter(
onnx_model_dir="/path/for/onnx_model/files",
model_name_or_path="/path/for/model/files",
)
onnx_llm_exporter.export(
input_names=["input_ids", "attention_mask", "dimensions"],
output_names=["embeddings"],
)
output = onnx_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])
print("output: ", output)
"""

def __init__(
self,
onnx_model_dir: str,
model=None,
model: Optional[torch.nn.Module] = None,
tokenizer=None,
model_name_or_path: str = None,
load_runtime: bool = True,
):
"""
Initializes the ONNX Exporter.
Args:
model_dir (str): path for storing the ONNX model files.
onnx_model_dir (str): path for storing the ONNX model files.
model (Optional[torch.nn.Module]): torch model.
tokenizer (HF or NeMo tokenizer): tokenizer class.
model_name_or_path (str): a checkpoint path or a Hugging Face model ID.
load_runtime (bool): load the ONNX runtime if an exported model is available in
the onnx_model_dir folder.
"""
self.onnx_model_dir = onnx_model_dir
self.model_name_or_path = model_name_or_path
@@ -141,7 +152,19 @@ def export(
export_dtype: str = "fp32",
verbose: bool = False,
):
"""Performs ONNX conversion from a PyTorch model."""
"""
Performs ONNX conversion from a PyTorch model.
Args:
input_names (list): input parameter names that the exported ONNX model will use.
output_names (list): output parameter names that the exported ONNX model will use.
example_inputs (dict): example inputs used to trace the model during export.
opset (int): ONNX opset version. Default is 20.
dynamic_axes_input (dict): variable-length axes for the inputs.
dynamic_axes_output (dict): variable-length axes for the outputs.
export_dtype (str): export dtype, "fp16" or "fp32".
verbose (bool): enable verbose logging.
"""

self._export_to_onnx(
input_names=input_names,
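For context, a hedged usage sketch of the export API documented above; paths and axis names are illustrative placeholders, not values from this commit:

from nemo.export.onnx_llm_exporter import OnnxLLMExporter

exporter = OnnxLLMExporter(
    onnx_model_dir="/tmp/onnx_model",        # placeholder path
    model_name_or_path="/path/to/model",     # placeholder path
)
exporter.export(
    input_names=["input_ids", "attention_mask", "dimensions"],
    output_names=["embeddings"],
    opset=20,
    dynamic_axes_input={
        "input_ids": {0: "batch", 1: "sequence"},        # assumed axis names
        "attention_mask": {0: "batch", 1: "sequence"},
    },
    export_dtype="fp32",
)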
@@ -197,17 +220,22 @@ def _export_to_onnx(

def export_onnx_to_trt(
self,
trt_model_path,
trt_model_path: str,
profiles=None,
override_layernorm_precision_to_fp32=False,
override_layers_to_fp32=None,
trt_dtype="fp16",
profiling_verbosity="layer_names_only",
override_layernorm_precision_to_fp32: bool = False,
override_layers_to_fp32: List = None,
trt_dtype: str = "fp16",
profiling_verbosity: str = "layer_names_only",
) -> None:
"""Performs TensorRT conversion from an ONNX model.
Raises:
SerializationError: if the TensorRT engine fails to serialize.
Args:
trt_model_path: path to store the TensorRT model.
profiles: TensorRT profiles.
override_layernorm_precision_to_fp32 (bool): whether to force LayerNorm layers to run in fp32.
override_layers_to_fp32 (List): Layer names to be converted to fp32.
trt_dtype (str): "fp16" or "fp32".
profiling_verbosity (str): Profiling verbosity. Default is "layer_names_only".
"""
logging.info(f"Building TRT engine from ONNX model ({self.onnx_model_path})")
trt_logger = trt.Logger(trt.Logger.WARNING)
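Continuing the sketch above, a hedged call that matches the documented arguments (the engine path is a placeholder):

exporter.export_onnx_to_trt(
    trt_model_path="/tmp/trt_model",    # placeholder path
    trt_dtype="fp16",
    override_layernorm_precision_to_fp32=False,
    profiling_verbosity="layer_names_only",
)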
@@ -277,11 +305,6 @@ def export_onnx_to_trt(
logging.info(f"Successfully exported ONNX model ({self.onnx_model_path}) " f"to TRT engine ({trt_model_path})")

def _override_layer_precision_to_fp32(self, layer: trt.ILayer) -> None:
"""Set TensorRT layer precision and output type to FP32.
Args:
layer: tensorrt.ILayer.
"""
layer.precision = trt.float32
layer.set_output_type(0, trt.float32)

@@ -376,7 +399,14 @@ def ptq(
quantization_type="fp8",
max_batch_size=32,
) -> None:
"""Runs a calibration loop on the model using a calibration dataset."""
"""
Runs a calibration loop on the model using a calibration dataset.
Args:
calibration_data (str): path to the calibration data for PTQ.
quantization_type (str): choices are int8, int8_sq, fp8, int4_awq, w4a8_awq.
max_batch_size (int): maximum batch size for the calibration loop.
"""

if Path(self.model_name_or_path).is_dir():
if is_nemo2_checkpoint(self.model_name_or_path):
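A hedged sketch of invoking the calibration loop documented above, continuing the same exporter instance (the dataset path is a placeholder):

exporter.ptq(
    calibration_data="/path/to/calibration_data",    # placeholder path
    quantization_type="fp8",
    max_batch_size=32,
)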
@@ -432,6 +462,8 @@ def chunk(iterable: List, n: int) -> Generator:
pbar.refresh()

def forward(self, prompt: Union[List, Dict]):
"""Run inference for a given prompt."""

if self.onnx_runtime_session is None:
warnings.warn("ONNX Runtime is not available. Please install " "the onnxruntime-gpu and try again.")
return None
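Once an exported model and a runtime session are available, inference follows the class-level example:

output = exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])
print("output:", output)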
@@ -445,25 +477,36 @@ def forward(self, prompt: Union[List, Dict]):

@property
def get_model(self):
"""Returns the model"""

return self.model

@property
def get_tokenizer(self):
"""Returns the tokenizer"""

return self.tokenizer

@property
def get_model_input_names(self):
"""Returns the model input names"""

return self.model_input_names

@property
def get_triton_input(self):
"""Get triton input"""

raise NotImplementedError("This function will be implemented later.")

@property
def get_triton_output(self):
"""Get triton output"""

raise NotImplementedError("This function will be implemented later.")

@batch
def triton_infer_fn(self, **inputs: np.ndarray):
"""PyTriton inference function"""

raise NotImplementedError("This function will be implemented later.")
1 change: 1 addition & 0 deletions nemo/export/utils/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter
from pathlib import Path
from typing import Dict, Union

