Skip to content

Commit ef2f8c2

Browse files
authored
Merge pull request #344 from Modalities/333-add-torch-compile-uni-test
333 add torch compile uni test
2 parents 0035ff0 + ccef95b commit ef2f8c2

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

src/modalities/models/model_factory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,15 @@ def get_parent_module_and_child_name(child_module: nn.Module, model: nn.Module)
266266
return parent_candidate, child_name
267267
raise ModelStateError("No valid parent candidate")
268268

269-
block_types = tuple([get_module_class_from_name(model, b) for b in block_names])
269+
block_types = []
270+
for name in block_names:
271+
module_class = get_module_class_from_name(model, name)
272+
if module_class is not None:
273+
block_types.append(module_class)
274+
else:
275+
raise ValueError("None of the provided block_names match any modules in the model")
276+
277+
block_types = tuple(block_types)
270278

271279
for _, module in model.named_modules():
272280
if isinstance(module, block_types):

tests/test_torch_compile.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pytest
2+
import torch.nn as nn
3+
4+
from modalities.models.gpt2.gpt2_model import (
5+
GPT2LLM,
6+
ActivationType,
7+
AttentionConfig,
8+
AttentionImplementation,
9+
LayerNorms,
10+
LayerNormWrapperConfig,
11+
PositionTypes,
12+
QueryKeyValueTransformType,
13+
)
14+
from modalities.models.model_factory import ModelFactory
15+
16+
17+
def create_gpt2_configs():
18+
attention_config = AttentionConfig(
19+
qkv_transforms=[
20+
AttentionConfig.QueryKeyValueTransformConfig(
21+
type_hint=QueryKeyValueTransformType.RotaryTransform.name,
22+
config=AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig(
23+
n_embd=512, n_head=8, seq_length_dim=-2, base_freq=10000
24+
),
25+
)
26+
]
27+
)
28+
norm_config = LayerNormWrapperConfig(norm_type=LayerNorms.layer_norm, config={"normalized_shape": 512})
29+
return attention_config, norm_config
30+
31+
32+
@pytest.fixture
33+
def gpt2_model():
34+
attention_config, norm_config = create_gpt2_configs()
35+
model = GPT2LLM(
36+
sample_key="input_ids",
37+
prediction_key="logits",
38+
poe_type=PositionTypes.NOPE,
39+
sequence_length=256,
40+
vocab_size=1024,
41+
n_layer=4,
42+
n_head_q=8,
43+
n_head_kv=4,
44+
n_embd=512,
45+
ffn_hidden=2048,
46+
dropout=0.1,
47+
bias=True,
48+
activation_type=ActivationType.SWIGLU,
49+
attention_implementation=AttentionImplementation.PYTORCH_FLASH,
50+
attention_config=attention_config,
51+
attention_norm_config=norm_config,
52+
ffn_norm_config=norm_config,
53+
lm_head_norm_config=norm_config,
54+
use_weight_tying=True,
55+
)
56+
return model
57+
58+
59+
def test_get_compiled_model_compiles_blocks(gpt2_model):
60+
original_blocks = list(gpt2_model.transformer.h)
61+
original_wte = gpt2_model.transformer.wte
62+
original_lm_head = gpt2_model.transformer.lm_head
63+
64+
block_names = ["GPT2Block"]
65+
result_model = ModelFactory.get_compiled_model(gpt2_model, block_names)
66+
67+
assert len(result_model.transformer.h) == 4, "Should still have four blocks"
68+
for i, (original_block, new_block) in enumerate(zip(original_blocks, result_model.transformer.h)):
69+
assert new_block is not original_block, f"Block {i} should be a compiled version"
70+
assert isinstance(new_block, nn.Module), f"Block {i} should be an nn.Module"
71+
assert result_model.transformer.wte is original_wte, "Embedding layer should remain unchanged"
72+
assert result_model.transformer.lm_head is original_lm_head, "LM head should remain unchanged"
73+
assert result_model is gpt2_model, "Should return the same model instance"
74+
75+
76+
def test_get_compiled_model_no_matching_blocks(gpt2_model):
77+
"""
78+
Test that get_compiled_model raises a ValueError if no blocks match the specified types.
79+
"""
80+
with pytest.raises(ValueError, match="None of the provided block_names match any modules in the model"):
81+
ModelFactory.get_compiled_model(gpt2_model, block_names=["Conv2d"])
82+
83+
84+
def test_get_compiled_model_empty_block_names(gpt2_model):
85+
original_model_dict = dict(gpt2_model.named_modules())
86+
result_model = ModelFactory.get_compiled_model(gpt2_model, block_names=[])
87+
88+
new_model_dict = dict(result_model.named_modules())
89+
assert new_model_dict == original_model_dict, "Model should remain unchanged with empty block_names"
90+
assert result_model is gpt2_model, "Should return the same model instance"

0 commit comments

Comments
 (0)