Commit cee7692

Add Hugging Face from_pretrained / save_pretrained tests
Adds integration tests to ensure that models containing TransformerLayer objects can be saved and loaded using the from_pretrained and save_pretrained methods.

Signed-off-by: Peter St. John <[email protected]>

1 parent d2973cb · commit cee7692

File tree

2 files changed: +45 −1 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
         )
         # Blackwell is not supported as of Triton 3.2.0, need custom internal build
         # install_reqs.append("triton")
-        test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
+        test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML", "transformers"])
     if "jax" in frameworks:
         setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
         install_reqs.extend(["jax", "flax>=0.7.1"])

tests/pytorch/test_hf_integration.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import pytest
+import torch
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+
+from transformer_engine.pytorch.transformer import TransformerLayer
+from transformer_engine.pytorch.utils import is_bf16_compatible
+
+param_types = [torch.float32, torch.float16]
+if is_bf16_compatible():  # bf16 requires sm_80 or higher
+    param_types.append(torch.bfloat16)
+
+
+all_activations = ["gelu", "relu"]
+all_normalizations = ["LayerNorm", "RMSNorm"]
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("normalization", all_normalizations)
+def test_save_and_load_hf_model(tmp_path, dtype, activation, normalization):
+    class SimpleTEModel(PreTrainedModel):
+        config_class = PretrainedConfig
+
+        def __init__(self, config: PretrainedConfig):
+            super().__init__(config)
+            self.my_layer = TransformerLayer(
+                hidden_size=320,
+                num_attention_heads=16,
+                ffn_hidden_size=1024,
+                layer_number=None,
+                params_dtype=dtype,
+                activation=activation,
+                normalization=normalization,
+            )
+
+        def forward(self, hidden_states, attention_mask):
+            return self.my_layer(hidden_states, attention_mask)
+
+    model = SimpleTEModel(PretrainedConfig())
+
+    model.save_pretrained(tmp_path / "simple_te_model")
+    del model
+    SimpleTEModel.from_pretrained(tmp_path / "simple_te_model")
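
Note (not part of this commit): the test above only checks that the save/load round trip completes without raising. A minimal sketch of how the round trip could additionally be verified is shown below; the helper name check_roundtrip and the state-dict comparison are illustrative assumptions, assuming a model instance such as SimpleTEModel(PretrainedConfig()) from the test.

# Illustrative sketch, not part of this commit: compare parameters before and
# after the save_pretrained / from_pretrained round trip.
import torch


def check_roundtrip(model, save_dir):
    # Clone the floating-point tensors of the original state dict; Transformer
    # Engine layers may also carry non-tensor "_extra_state" entries, which are
    # skipped here.
    original = {
        name: value.detach().clone()
        for name, value in model.state_dict().items()
        if torch.is_tensor(value) and torch.is_floating_point(value)
    }
    model.save_pretrained(save_dir)
    reloaded = type(model).from_pretrained(save_dir)
    for name, tensor in reloaded.state_dict().items():
        if name not in original:
            continue
        # from_pretrained may materialize weights in a different default dtype,
        # so cast before comparing.
        torch.testing.assert_close(tensor.to(original[name].dtype), original[name])

Comparing only floating-point tensors sidesteps the extra-state buffers, whose serialization format may vary between Transformer Engine versions.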
