From f87a78f246e16f7e204fa8d5ea1aad7db1eb5122 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 3 Feb 2025 23:24:16 +0000
Subject: [PATCH] add tests

Signed-off-by: Kyle Sayers
---
 examples/quantization_w4a16/llama3_example.py |  3 +-
 .../finetune/data/data_helpers.py             |  8 ++--
 .../finetune/data/test_dataset_helpers.py     | 37 +++++++++++++++++--
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py
index f8ef395bd..9a606ce53 100644
--- a/examples/quantization_w4a16/llama3_example.py
+++ b/examples/quantization_w4a16/llama3_example.py
@@ -5,8 +5,7 @@
 from llmcompressor.transformers import oneshot
 
 # Select model and load it.
-# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index 97f217050..065ffc512 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,10 +1,9 @@
-import logging
 import os
-import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data.data_collator import (
     DataCollatorWithPadding,
@@ -13,7 +12,6 @@
 
 from llmcompressor.typing import Processor
 
-LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
 
 __all__ = [
@@ -56,7 +54,7 @@ def format_calibration_data(
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
         if safe_calibration_samples != num_calibration_samples:
-            LOGGER.warn(
+            logger.warning(
                 f"Requested {num_calibration_samples} calibration samples but "
                 f"the provided dataset only has {safe_calibration_samples}. "
             )
@@ -68,7 +66,7 @@ def format_calibration_data(
         if hasattr(tokenizer, "pad"):
             collate_fn = DataCollatorWithPadding(tokenizer)
         else:
-            warnings.warn(
+            logger.warning(
                 "Could not find processor, attempting to collate with without padding "
                 "(may fail for batch_size > 1)"
             )
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 2c936a363..98e98c00f 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -2,6 +2,7 @@
 import pytest
 import torch
 from datasets import Dataset
+from transformers import AutoTokenizer
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
@@ -9,6 +10,7 @@
     get_raw_dataset,
     make_dataset_splits,
 )
+from llmcompressor.transformers.finetune.text_generation import configure_processor
 
 
 @pytest.mark.unit
@@ -60,15 +62,42 @@ def test_separate_datasets():
 
 
 @pytest.mark.unit
-def test_format_calibration_data():
-    tokenized_dataset = Dataset.from_dict(
-        {"input_ids": torch.randint(0, 512, (8, 2048))}
+def test_format_calibration_data_padded_tokenized():
+    vocab_size = 512
+    seq_len = 2048
+    ds_size = 16
+    padded_tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))}
     )
 
     calibration_dataloader = format_calibration_data(
-        tokenized_dataset, num_calibration_samples=4, batch_size=2
+        padded_tokenized_dataset, num_calibration_samples=8, batch_size=4
     )
 
     batch = next(iter(calibration_dataloader))
+    assert batch["input_ids"].size(0) == 4
+
+@pytest.mark.unit
+def test_format_calibration_data_unpadded_tokenized():
+    vocab_size = 512
+    ds_size = 16
+    unpadded_tokenized_dataset = Dataset.from_dict(
+        {
+            "input_ids": [
+                torch.randint(0, vocab_size, (seq_len,)) for seq_len in range(ds_size)
+            ]
+        }
+    )
+    processor = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    configure_processor(processor)
+
+    calibration_dataloader = format_calibration_data(
+        unpadded_tokenized_dataset,
+        num_calibration_samples=8,
+        batch_size=4,
+        processor=processor,
+    )
+
+    batch = next(iter(calibration_dataloader))
 
     assert batch["input_ids"].size(0) == 2