Commit a850e4c

FThompsonAWS authored and awsjoshir committed
Merge pull request #3 from aws-neuron/release_2.21.0
Neuron 2.21 release
1 parent c346979 commit a850e4c

41 files changed: +5999 -133 lines changed

examples/dog.jpg

39.3 KB

examples/generation_mllama.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import torch
import os

from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import MultimodalVisionNeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.models.mllama.modeling_mllama import MllamaInferenceConfig, NeuronMllamaForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter
from neuronx_distributed_inference.models.mllama.model_wrapper_mllama import NUM_IMAGE_PER_PROMPT
from neuronx_distributed_inference.models.mllama.utils import create_vision_mask, get_image, get_image_tensors, add_instruct
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params
from neuronx_distributed_inference.utils.benchmark import benchmark_sampling

# TODO: Either read from os environment var or from arg_parser.
checkpoint = "meta"
model_variant = "11B"
model_path = f"/home/ubuntu/models/Llama-3.2-{model_variant}-Vision-Instruct-{checkpoint}/"
traced_model_path = f"/home/ubuntu/workplace/traced_models/Llama-3.2-{model_variant}-Vision-Instruct-{checkpoint}/"

torch.manual_seed(0)


def run_llama_generate():
    # Initialize configs and tokenizer.
    batch_size = 1
    num_img_per_prompt = 1
    max_context_length = 1024
    seq_len = 2048

    generation_config = GenerationConfig.from_pretrained(model_path)
    generation_config_kwargs = {
        "top_k": 1,
    }
    generation_config.update(**generation_config_kwargs)

    on_device_sampling_config = OnDeviceSamplingConfig(
        dynamic=True,
    )

    neuron_config = MultimodalVisionNeuronConfig(
        tp_degree=32,
        batch_size=batch_size,
        max_context_length=max_context_length,
        seq_len=seq_len,
        on_device_sampling_config=on_device_sampling_config,
        enable_bucketing=True,
        sequence_parallel_enabled=False,
        fused_qkv=False,
        async_mode=False,
    )
    config = MllamaInferenceConfig(
        neuron_config,
        load_config=load_pretrained_config(model_path),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token

    # Generate outputs.
    image = get_image("dog.jpg")
    batch_image = [[image] * num_img_per_prompt] * batch_size
    pixel_values, aspect_ratios, num_chunks, has_image = get_image_tensors(config, batch_image)

    prompt = add_instruct("What is in this image? Tell me a story", has_image)
    batch_prompt = [prompt] * batch_size

    if not os.path.exists(traced_model_path):
        # Compile and save model.
        print("\nCompiling and saving model...")
        model = NeuronMllamaForCausalLM(model_path, config)
        model.compile(traced_model_path)
        tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.
    print("\nLoading model from compiled checkpoint...")
    model = NeuronMllamaForCausalLM(traced_model_path)
    model.load(traced_model_path)
    tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

    print("\nGenerating outputs...")
    print(f"Prompts: {batch_prompt}")

    inputs = tokenizer(batch_prompt, padding=True, return_tensors="pt", add_special_tokens=False)

    vision_token_id = tokenizer("<|image|>", add_special_tokens=False).input_ids[0]
    vision_mask = create_vision_mask(inputs.input_ids, vision_token_id)

    generation_model = HuggingFaceGenerationAdapter(model)

    # Test sampling parameters.
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        inputs.input_ids,
        generation_config=generation_config,
        attention_mask=inputs.attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=pixel_values,
        aspect_ratios=aspect_ratios,
        vision_mask=vision_mask,
        num_chunks=num_chunks,
        has_image=has_image,
        max_new_tokens=512,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    print("Generated outputs:")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    # Test with text-only input.
    pixel_values, aspect_ratios, num_chunks, has_image = get_image_tensors(config, [[]] * batch_size)

    prompt = add_instruct("what is the recipe of mayonnaise in two sentences?", has_image)
    batch_prompt = [prompt] * batch_size
    inputs = tokenizer(batch_prompt, padding=True, return_tensors="pt")

    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        inputs.input_ids,
        generation_config=generation_config,
        attention_mask=inputs.attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=pixel_values,
        aspect_ratios=aspect_ratios,
        vision_mask=vision_mask,
        num_chunks=num_chunks,
        has_image=has_image,
        max_new_tokens=512,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    print("Generated outputs:")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    print("\nPerformance Benchmarking!")
    benchmark_sampling(model=model, draft_model=None, generation_config=generation_config, target="all", image=True)


if __name__ == "__main__":
    run_llama_generate()
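
A note on configuration: the in-file TODO flags that `checkpoint`, `model_variant`, and the two model paths are hardcoded and should eventually come from an environment variable or an argument parser. A minimal sketch of the environment-variable option mentioned in that TODO (the variable names below are hypothetical, not part of this commit):

import os

# Hypothetical environment variables; the script above hardcodes these values.
checkpoint = os.environ.get("MLLAMA_CHECKPOINT", "meta")
model_variant = os.environ.get("MLLAMA_MODEL_VARIANT", "11B")
model_path = os.environ.get(
    "MLLAMA_MODEL_PATH",
    f"/home/ubuntu/models/Llama-3.2-{model_variant}-Vision-Instruct-{checkpoint}/",
)
traced_model_path = os.environ.get(
    "MLLAMA_TRACED_MODEL_PATH",
    f"/home/ubuntu/workplace/traced_models/Llama-3.2-{model_variant}-Vision-Instruct-{checkpoint}/",
)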

examples/requirements.txt

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,7 @@
 transformers==4.45.*
 sentencepiece
 pillow
-pytest-forked
+pytest-forked
+tiktoken
+blobfile
+torchvision
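
Because the examples pin `transformers==4.45.*`, a small optional runtime check (an illustration only, not part of this commit) can surface a mismatched install before model loading fails with a less obvious error:

import transformers

# Optional sanity check against the 4.45.* pin in examples/requirements.txt.
if not transformers.__version__.startswith("4.45."):
    raise RuntimeError(f"Expected transformers 4.45.x, found {transformers.__version__}")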

neuron_test/unit_test/models/__init__.py

Whitespace-only changes.

neuron_test/unit_test/models/mllama/__init__.py

Whitespace-only changes.
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn

from neuronx_distributed_inference.models.config import InferenceConfig
from neuronx_distributed_inference.models.mllama.modeling_mllama_vision import VisionEncoder
from neuronx_distributed_inference.models.mllama.utils import META_CHECKPOINT, to_2tuple

from .test_utils import load_checkpoint, logger, save_checkpoint, setup_debug_env, trace_nxd_model

VISION_SEQ_LEN = 1601
VISION_HIDDEN_DIM = 1280
MAX_NUM_CHUNKS = 4
TORCH_DTYPE = torch.float32


class VisionEncoderPosEmbedOnly(VisionEncoder):
    def __init__(self, max_num_tiles, image_size, patch_size, width):
        nn.Module.__init__(self)
        self.config = InferenceConfig(neuron_config=None)
        self.config.checkpoint = META_CHECKPOINT
        self.max_num_tiles = max_num_tiles
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        )
        scale = width**-0.5
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width, dtype=TORCH_DTYPE)
        )
        self.gated_positional_embedding = nn.Parameter(
            scale
            * torch.randn(
                max_num_tiles,
                max_num_tiles,
                self.grid_size[0] * self.grid_size[1] + 1,
                width,
                dtype=TORCH_DTYPE,
            )
        )
        # Don't initialize to zero, otherwise the gated_positional_embedding has no effect on output.
        self.gated_positional_embedding_gate = nn.Parameter(torch.randn(1, dtype=TORCH_DTYPE))

    def forward(self, x, ar):
        return self.apply_positional_embedding(x, ar, ar_ids=None)


class VisionEncoderMeta(VisionEncoderPosEmbedOnly):
    def apply_positional_embedding(self, x, ar, ar_ids=None):
        # Apply regular position embedding.
        bsz, num_chunks, num_tokens, dim = x.shape
        x = x.view(bsz * num_chunks, num_tokens, dim)
        x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
        x = x.view(bsz, num_chunks, num_tokens, dim)
        for idx, arx in enumerate(ar):
            _pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]]
            _pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:])
            x[idx, : arx[0] * arx[1]] += _pos_embed * self.gated_positional_embedding_gate.tanh()
        return x


def get_example_inputs():
    x = torch.randn(1, MAX_NUM_CHUNKS, VISION_SEQ_LEN, VISION_HIDDEN_DIM, dtype=TORCH_DTYPE)
    ar = torch.tensor([1, 1], dtype=torch.int32).view(1, 2)
    return x, ar


def test_apply_pos_embed():
    setup_debug_env()

    init_args = dict(
        max_num_tiles=MAX_NUM_CHUNKS,
        image_size=560,
        patch_size=14,
        width=VISION_HIDDEN_DIM,
    )

    cpu_model_meta = VisionEncoderMeta(**init_args)
    save_checkpoint(cpu_model_meta)
    cpu_model = VisionEncoderPosEmbedOnly(**init_args)
    cpu_model.load_state_dict(load_checkpoint())

    # Trace to get the Neuron model.
    example_inputs = get_example_inputs()
    x, ar = example_inputs
    neuron_model = trace_nxd_model(
        VisionEncoderPosEmbedOnly, example_inputs, tp_degree=1, **init_args
    )

    # Test all possible aspect ratios (with max_num_chunks=4).
    aspect_ratios = [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1], [2, 2], [3, 1], [4, 1]]
    for aspect_ratio in aspect_ratios:
        print("Testing aspect ratio:", tuple(aspect_ratio))
        ar = torch.tensor(aspect_ratio, dtype=torch.int32).view(1, 2)

        # Compare Meta vs our implementation on CPU.
        x_out_meta = cpu_model_meta(x, ar)
        x_out_cpu = cpu_model(x, ar)
        assert torch.allclose(x_out_meta, x_out_cpu)
        logger.info("Correctness test passing on CPU.")

        x_out_xla = neuron_model(x, ar)
        assert torch.allclose(x_out_meta, x_out_xla)
        logger.info(
            f"{x_out_meta.shape}, {x.sum()}, {x_out_meta.sum()}, {x_out_cpu.sum()}, {x_out_xla.sum()}"
        )
        logger.info("Correctness test passing on device.\n")

    logger.info("ALL TESTS PASSING")


if __name__ == "__main__":
    test_apply_pos_embed()
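
This test carries a `__main__` guard, so it can be run directly as a script; it should also be collectable by pytest (pytest-forked is pinned in examples/requirements.txt). A hypothetical programmatic invocation of the mllama unit-test package shown above (the test file's own name is not visible in this diff view):

import pytest

# Hypothetical invocation; runs whatever tests pytest collects in this package.
pytest.main(["-s", "neuron_test/unit_test/models/mllama/"])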
