This example provides a minimal example of pruning large language models with a magnitude-based importance score. We use the transformers
library to load the model and the datasets
library to evaluate the Perplexity with Wikitext2
. For more comprehensive examples of Gradient-based pruning or finetuning, please refer to LLM-Pruner.
This script has been tested with the following models:
- meta-llama/Meta-Llama-3.1-8B
- meta-llama/Meta-Llama-3-8B
- meta-llama/Llama-2-7b-hf
- microsoft/Phi-3-mini-4k-instruct
- Qwen/Qwen2-7B
- Qwen/Qwen2.5-7B
pip install transformers datasets accelerate --upgrade
# if torch-pruning not installed
pip install torch-pruning --upgrade # >=1.5.0
python prune_llm.py --model MODEL_CARD --pruning_ratio PRUNING_RATIO --max_seq_len MAX_SEQ_LEN --save_model SAVE_HF_MODEL
Arguments:
MODEL_CARD
: The model card of the model to be pruned, such asmeta-llama/Meta-Llama-3-8B
.PRUNING_RATIO
: The ratio of the model width to be pruned, such as0.5
.MAX_SEQ_LEN
: The maximum sequence length of the model, such as 4096. If not provided, the script will use the maximum sequence length of the model.SAVE_HF_MODEL
: The path to save the pruned model. If not provided, the pruned model will not be saved. You can load a saved model usingAutoModelForCausalLM.from_pretrained
.
python prune_llm.py --model meta-llama/Llama-3.1-8B --pruning_ratio 0.5 --max_seq_len 4096
Output:
----------------- Before Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
----------------- After Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 2048)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=False)
(k_proj): Linear(in_features=2048, out_features=1024, bias=False)
(v_proj): Linear(in_features=2048, out_features=1024, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=2048, out_features=7168, bias=False)
(up_proj): Linear(in_features=2048, out_features=7168, bias=False)
(down_proj): Linear(in_features=7168, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((2048,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
LlamaConfig {
"_attn_implementation_autoset": true,
"_name_or_path": "meta-llama/Llama-3.1-8B",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 7168,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"vocab_size": 128256
}
num_params 2337409024
evaluating on wikitext2
Token indices sequence length is longer than the specified maximum sequence length for this model (2458791 > 131072). Running this sequence through the model will result in indexing errors
nsamples 70
sample 0
sample 50
wikitext perplexity 215807.609375
python prune_llm.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
Output:
----------------- Before Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
----------------- After Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 2048)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=False)
(k_proj): Linear(in_features=2048, out_features=1024, bias=False)
(v_proj): Linear(in_features=2048, out_features=1024, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=2048, out_features=7168, bias=False)
(up_proj): Linear(in_features=2048, out_features=7168, bias=False)
(down_proj): Linear(in_features=7168, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((2048,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
LlamaConfig {
"_attn_implementation_autoset": true,
"_name_or_path": "meta-llama/Meta-Llama-3-8B",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 7168,
"max_position_embeddings": 8192,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"vocab_size": 128256
}
num_params 2337409024
evaluating on wikitext2
nsamples 35
sample 0
wikitext perplexity 552648.25
python prune_llm.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
Output:
----------------- Before Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
----------------- After Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 2048)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=False)
(k_proj): Linear(in_features=2048, out_features=2048, bias=False)
(v_proj): Linear(in_features=2048, out_features=2048, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=2048, out_features=5504, bias=False)
(up_proj): Linear(in_features=2048, out_features=5504, bias=False)
(down_proj): Linear(in_features=5504, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((2048,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
LlamaConfig {
"_name_or_path": "meta-llama/Llama-2-7b-hf",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5504,
"max_position_embeddings": 4096,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 16,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.44.2",
"use_cache": true,
"vocab_size": 32000
}
num_params 1750206464
evaluating on wikitext2
nsamples 83
sample 0
sample 50
wikitext perplexity 8479.0673828125
python prune_llm.py --model microsoft/Phi-3-mini-4k-instruct --pruning_ratio 0.5
Output:
----------------- Before Pruning -----------------
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=8192, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
----------------- After Pruning -----------------
Token indices sequence length is longer than the specified maximum sequence length for this model (2824490 > 4096). Running this sequence through the model will result in indexing errors
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 1536, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(qkv_proj): Linear(in_features=1536, out_features=4608, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=1536, out_features=8192, bias=False)
(down_proj): Linear(in_features=4096, out_features=1536, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=32064, bias=False)
)
Phi3Config {
"_attn_implementation_autoset": true,
"_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM"
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 16,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"sliding_window": 2047,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.3",
"use_cache": true,
"vocab_size": 32064
}
num_params 1004570112
evaluating on wikitext2
nsamples 83
sample 0
sample 50
wikitext perplexity 92795.3984375
The Qwen2-7b model has 28 heads with num_key_value_heads=4
. This limits the pruning ratio to be [1/7, 2/7, 3/7, 4/7, 5/7, 6/7] if you want to save and load the pruned model using Huggingface Transformers since HF only supports in_features==out_features
in the q_proj
and o_proj
.
# 3/7 ~ 0.428571428, this script will craft a 2B model
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.428571428 --max_seq_len 4096
Output:
----------------- Before Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 3584)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((3584,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)
----------------- After Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 2048)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=512, bias=True)
(v_proj): Linear(in_features=2048, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=2048, out_features=10825, bias=False)
(up_proj): Linear(in_features=2048, out_features=10825, bias=False)
(down_proj): Linear(in_features=10825, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((2048,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=2048, out_features=152064, bias=False)
)
Qwen2Config {
"_attn_implementation_autoset": true,
"_name_or_path": "Qwen/Qwen2-7B",
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 10825,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}
num_params 2778904576
evaluating on wikitext2
nsamples 73
sample 0
sample 50
wikitext perplexity 44195.69140625
The Qwen2.5-7b model has 28 heads with num_key_value_heads=4
. This limits the pruning ratio to be [1/7, 2/7, 3/7, 4/7, 5/7, 6/7] if you want to save and load the pruned model using HuginggFace Transformers since HF only supports the same in_features and out_features for the q_proj
and o_proj
.
# 3/7 ~ 0.428571428, this script will craft a 2B model
python prune_llm.py --model Qwen/Qwen2.5-7B --pruning_ratio 0.428571428 --max_seq_len 4096
Output:
----------------- Before Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 3584)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((3584,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)
----------------- After Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 2048)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=512, bias=True)
(v_proj): Linear(in_features=2048, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=2048, out_features=10824, bias=False)
(up_proj): Linear(in_features=2048, out_features=10824, bias=False)
(down_proj): Linear(in_features=10824, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((2048,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=2048, out_features=152064, bias=False)
)
Qwen2Config {
"_attn_implementation_autoset": true,
"_name_or_path": "Qwen/Qwen2.5-7B",
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 10824,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"use_mrope": false,
"use_sliding_window": false,
"vocab_size": 152064
}
num_params 2778732544
evaluating on wikitext2
Token indices sequence length is longer than the specified maximum sequence length for this model (2541000 > 131072). Running this sequence through the model will result in indexing errors
nsamples 73
sample 0
sample 50
wikitext perplexity 100281.03125