add tests
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Feb 3, 2025
1 parent 447b5ad commit f87a78f
Showing 3 changed files with 37 additions and 11 deletions.
3 changes: 1 addition & 2 deletions examples/quantization_w4a16/llama3_example.py
@@ -5,8 +5,7 @@
 from llmcompressor.transformers import oneshot
 
 # Select model and load it.
-# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
8 changes: 3 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,10 +1,9 @@
-import logging
 import os
-import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data.data_collator import (
     DataCollatorWithPadding,
@@ -13,7 +12,6 @@
 
 from llmcompressor.typing import Processor
 
-LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
 
 __all__ = [
@@ -56,7 +54,7 @@ def format_calibration_data(
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
         if safe_calibration_samples != num_calibration_samples:
-            LOGGER.warn(
+            logger.warning(
                 f"Requested {num_calibration_samples} calibration samples but "
                 f"the provided dataset only has {safe_calibration_samples}. "
             )
@@ -68,7 +66,7 @@
     if hasattr(tokenizer, "pad"):
         collate_fn = DataCollatorWithPadding(tokenizer)
     else:
-        warnings.warn(
+        logger.warning(
             "Could not find processor, attempting to collate with without padding "
             "(may fail for batch_size > 1)"
         )
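
For context on the data_helpers.py change above: it consolidates two separate warning mechanisms (a module-level stdlib LOGGER and warnings.warn) into a single loguru logger, and replaces the deprecated stdlib .warn() alias with .warning(). A minimal standalone sketch of the resulting pattern; clamp_calibration_samples is a hypothetical helper used only for illustration, not the repository's function:

from loguru import logger


def clamp_calibration_samples(requested: int, available: int) -> int:
    # Clamp the requested calibration sample count to what the dataset provides,
    # warning (rather than raising) when the request cannot be fully satisfied.
    safe = min(requested, available)
    if safe != requested:
        # loguru ships a ready-to-use shared logger: no getLogger()/handler setup,
        # and logger.warning() replaces both LOGGER.warn() and warnings.warn().
        logger.warning(f"Requested {requested} calibration samples, using {safe}.")
    return safe
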
37 changes: 33 additions & 4 deletions (test file)
@@ -2,13 +2,15 @@
 import pytest
 import torch
 from datasets import Dataset
+from transformers import AutoTokenizer
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     format_calibration_data,
     get_raw_dataset,
     make_dataset_splits,
 )
+from llmcompressor.transformers.finetune.text_generation import configure_processor
 
 
 @pytest.mark.unit
@@ -60,15 +62,42 @@ def test_separate_datasets():
 
 
 @pytest.mark.unit
-def test_format_calibration_data():
-    tokenized_dataset = Dataset.from_dict(
-        {"input_ids": torch.randint(0, 512, (8, 2048))}
+def test_format_calibration_data_padded_tokenized():
+    vocab_size = 512
+    seq_len = 2048
+    ds_size = 16
+    padded_tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))}
     )
 
     calibration_dataloader = format_calibration_data(
-        tokenized_dataset, num_calibration_samples=4, batch_size=2
+        padded_tokenized_dataset, num_calibration_samples=8, batch_size=4
     )
 
+    batch = next(iter(calibration_dataloader))
+    assert batch["input_ids"].size(0) == 4
+
+
+@pytest.mark.unit
+def test_format_calibration_data_unpaddded_tokenized():
+    vocab_size = 512
+    ds_size = 16
+    unpadded_tokenized_dataset = Dataset.from_dict(
+        {
+            "input_ids": [
+                torch.randint(0, vocab_size, (seq_len,)) for seq_len in range(ds_size)
+            ]
+        }
+    )
+    processor = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    configure_processor(processor)
+
+    calibration_dataloader = format_calibration_data(
+        unpadded_tokenized_dataset,
+        num_calibration_samples=8,
+        batch_size=4,
+        processor=processor,
+    )
 
     batch = next(iter(calibration_dataloader))
     assert batch["input_ids"].size(0) == 2

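The new unpaddded test exercises the DataCollatorWithPadding path from the data_helpers.py change above, where variable-length samples are padded per batch when the processor's tokenizer supports padding. A rough standalone sketch of that collation behaviour; the gpt2 checkpoint and toy samples below are illustrative stand-ins, not what the test itself uses:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

# Variable-length "tokenized" samples, analogous to the unpadded dataset in the new test.
samples = [{"input_ids": list(range(1, length + 1))} for length in (3, 5, 2, 7)]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

# DataCollatorWithPadding pads each batch to its longest sequence, so batches can be
# stacked into a single tensor even though the raw samples differ in length.
collate_fn = DataCollatorWithPadding(tokenizer)
loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch["input_ids"].shape)    # first batch holds lengths 3 and 5, padded to (2, 5)
print(batch["attention_mask"][0])  # 1s mark real tokens, 0s mark padding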