attn.attn --> attn.qkv
Andrei-Aksionov committed Nov 18, 2024
1 parent 4371c06 commit 465a9f7
Showing 12 changed files with 89 additions and 107 deletions.
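The change itself is mechanical: the fused query/key/value projection inside `CausalSelfAttention` used to live in an attribute named `attn`, which produced the doubled `attn.attn` in state-dict keys, and is now named `qkv`. A minimal sketch of the affected attribute, with simplified constructor arguments that are not the real litgpt signature:

```python
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Sketch only: the real litgpt module is configured through a `Config` object."""

    def __init__(self, n_embd: int, n_head: int, n_query_groups: int, head_size: int) -> None:
        super().__init__()
        # One fused projection computes queries, keys and values in a single matmul.
        fused_dim = (n_head + 2 * n_query_groups) * head_size
        # Previously this layer was `self.attn`, so checkpoints contained keys like
        # "transformer.h.{i}.attn.attn.weight"; after the rename they read "transformer.h.{i}.attn.qkv.weight".
        self.qkv = nn.Linear(n_embd, fused_dim, bias=False)
        self.proj = nn.Linear(n_head * head_size, n_embd, bias=False)
```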
2 changes: 1 addition & 1 deletion litgpt/adapter.py
@@ -151,7 +151,7 @@ def scaled_dot_product_attention(
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
aqkv = self.attn(prefix)
aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
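The adapter prefix runs through the same fused projection, so only the attribute name changes here. As a reference for the reshape in the surrounding context lines, here is a rough shape walk-through with made-up sizes (illustrative only):

```python
import torch

# Hypothetical sizes, chosen only to make the shapes concrete.
aT, n_embd = 10, 64                       # adapter prompt length, embedding size
n_head, n_query_groups, head_size = 8, 4, 8
q_per_kv = n_head // n_query_groups       # 2 queries share each key/value group

# self.qkv(prefix) yields (1, aT, (n_head + 2 * n_query_groups) * head_size)
aqkv = torch.randn(1, aT, (n_head + 2 * n_query_groups) * head_size)

# Reshape so each query group carries its q_per_kv queries plus one key and one value ...
aqkv = aqkv.view(1, aT, n_query_groups, q_per_kv + 2, head_size)
# ... and move the group and slot axes in front of the sequence axis.
aqkv = aqkv.permute(0, 2, 3, 1, 4)
print(aqkv.shape)  # torch.Size([1, 4, 4, 10, 8])
```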
4 changes: 2 additions & 2 deletions litgpt/adapter_v2.py
@@ -189,8 +189,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
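The `_load_from_state_dict` mapping exists so that plain base checkpoints, whose linear layers sit directly under `qkv`/`proj`, still load into the adapter wrappers, where the underlying layer is nested as `.linear`. A rough, standalone sketch of such a suffix remap (not the exact hook logic in litgpt):

```python
from typing import Dict

import torch


def remap_base_keys(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    """Rename base-checkpoint keys to the wrapped adapter layout (illustrative only)."""
    mapping = {
        "qkv.weight": "qkv.linear.weight",
        "qkv.bias": "qkv.linear.bias",
        "proj.weight": "proj.linear.weight",
        "proj.bias": "proj.linear.bias",
    }
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            suffix = key[len(prefix):]
            key = prefix + mapping.get(suffix, suffix)
        remapped[key] = value
    return remapped


sd = {"transformer.h.0.attn.qkv.weight": torch.zeros(2, 2)}
out = remap_base_keys(sd, "transformer.h.0.attn.")
# The key gains the extra ".linear" segment: "transformer.h.0.attn.qkv.linear.weight"
```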
13 changes: 6 additions & 7 deletions litgpt/generate/tp.py
@@ -3,31 +3,30 @@
import logging
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional, Union
import warnings

import lightning as L
from lightning_utilities.core.imports import RequirementCache
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.utilities import rank_zero_only
from lightning_utilities.core.imports import RequirementCache

import litgpt.generate.base as generate_base
from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_nvlink_connectivity,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision
get_default_supported_precision,
)


@@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA


def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:
tensor_parallel_linear(fabric, attn.attn, "colwise")
tensor_parallel_linear(fabric, attn.qkv, "colwise")
tensor_parallel_linear(fabric, attn.proj, "rowwise")
attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))

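For tensor parallelism, the fused `qkv` projection is sharded column-wise (each rank keeps a slice of the output features), the output `proj` is sharded row-wise (each rank keeps a slice of the input features), and the registered forward hook all-reduces the partial attention outputs. A minimal sketch of the sharding rule, assuming even divisibility, ignoring bias handling, grouped-query group boundaries, and litgpt's real helper:

```python
import torch
import torch.nn as nn


def shard_linear(linear: nn.Linear, style: str, rank: int, world_size: int) -> None:
    """Keep only this rank's shard of a linear layer (illustrative sketch, weights only)."""
    if style == "colwise":    # split the output dimension, e.g. the fused qkv projection
        shard = torch.chunk(linear.weight, world_size, dim=0)[rank]
    elif style == "rowwise":  # split the input dimension, e.g. the attention output projection
        shard = torch.chunk(linear.weight, world_size, dim=1)[rank]
    else:
        raise ValueError(f"unknown style: {style}")
    linear.weight = nn.Parameter(shard)

# colwise(qkv) followed by rowwise(proj) gives each rank a partial attention output;
# an all-reduce over ranks (the registered forward hook) restores the full result.
```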
4 changes: 2 additions & 2 deletions litgpt/lora.py
@@ -616,8 +616,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
2 changes: 1 addition & 1 deletion litgpt/model.py
@@ -288,7 +288,7 @@ def forward(

# Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value`
# instead of individually multiplying the input `x` with the respective weight matrices.
qkv = self.attn(x) # (B, T, 3xC*)
qkv = self.qkv(x) # (B, T, 3xC*)

# Define query, key and value sizes.
# If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`).
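Immediately after this projection, the fused tensor is separated into `query`, `key` and `value`; with grouped- or multi-query attention the three slices have different widths, since several query heads share one key/value group. A small illustration of those split sizes (made-up numbers, not the exact code that follows in `model.py`):

```python
import torch

# Hypothetical: 8 query heads sharing 2 key/value groups.
B, T, n_head, n_query_groups, head_size = 2, 5, 8, 2, 16

q_size = n_head * head_size             # 128 features for queries
kv_size = n_query_groups * head_size    # 32 features each for keys and values

qkv = torch.randn(B, T, q_size + 2 * kv_size)            # what self.qkv(x) would return
q, k, v = qkv.split((q_size, kv_size, kv_size), dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 5, 128]) torch.Size([2, 5, 32]) torch.Size([2, 5, 32])
```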
31 changes: 7 additions & 24 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -32,8 +32,8 @@ def copy_weights_gpt_neox(
"gpt_neox.embed_in.weight": "transformer.wte.weight",
"gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias",
"gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
"gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias",
"gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight",
"gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias",
"gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight",
"gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None,
@@ -83,7 +83,7 @@ def copy_weights_falcon(
) -> None:
weight_map = {
"transformer.word_embeddings.weight": "transformer.wte.weight",
"transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
"transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight",
"transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight",
"transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
"transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
@@ -209,7 +209,7 @@ def copy_weights_hf_llama(
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
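Hugging Face Llama checkpoints keep `q_proj`, `k_proj` and `v_proj` as separate tensors, so the converter concatenates them, queries first, then keys, then values, into the single fused `attn.qkv` parameter. A toy illustration with invented shapes:

```python
import torch

n_embd = 16
q = torch.randn(4 * 2, n_embd)  # q_proj: n_head * head_size rows (here 4 heads of size 2)
k = torch.randn(2 * 2, n_embd)  # k_proj: n_query_groups * head_size rows (2 groups)
v = torch.randn(2 * 2, n_embd)  # v_proj: same shape as k_proj

qkv = torch.cat((q, k, v))       # grouped layout: all Q rows, then K, then V
# stored as state_dict[f"transformer.h.{i}.attn.qkv.weight"] in the converted checkpoint
print(qkv.shape)                 # torch.Size([16, 16])
```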
@@ -277,7 +277,7 @@ def copy_weights_gemma_2(
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
@@ -325,7 +325,7 @@ def copy_weights_phi(
if config.name.startswith("Phi-3"):
weight_map.update(
{
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight",
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
@@ -370,30 +370,13 @@ def copy_weights_phi(
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
pbar.update(progress_per_file)


# def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
# """Reassemble from a normal to an interleaved placement in a QKV matrix.
# [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
# """
# q, k, v = param.split(
# (
# config.n_head * config.head_size,
# config.n_query_groups * config.head_size,
# config.n_query_groups * config.head_size,
# )
# )
# qs = q.split(config.n_head // config.n_query_groups * config.head_size)
# ks = k.split(config.head_size)
# vs = v.split(config.head_size)
# interleaved = [t for group in zip(qs, ks, vs) for t in group]
# return torch.cat(interleaved)

def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
18 changes: 9 additions & 9 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -23,7 +23,7 @@ def copy_weights_falcon(
) -> None:
weight_map = {
"transformer.wte.weight": "transformer.word_embeddings.weight",
"transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight",
"transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight",
"transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight",
"transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight",
"transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight",
@@ -55,7 +55,7 @@ def copy_weights_falcon(
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template].format(layer_idx)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")):
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
# Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
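Going the other way, litgpt's fused weight is laid out as all Q rows, then all K, then all V, while the Falcon and GPT-NeoX `query_key_value` tensors interleave the heads group by group, hence the reassembly step. A sketch of that reordering, based on the inline comment (my own re-implementation, not necessarily the project's `qkv_reassemble`):

```python
import torch


def grouped_to_interleaved(param: torch.Tensor, n_head: int, n_query_groups: int, head_size: int) -> torch.Tensor:
    """[q, q, ..., k, k, ..., v, v, ...] -> [q, k, v, q, k, v, ...] along dim 0 (illustrative)."""
    q, k, v = param.split(
        (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)
    )
    qs = q.split(n_head // n_query_groups * head_size)  # one block of query rows per group
    ks = k.split(head_size)
    vs = v.split(head_size)
    return torch.cat([t for group in zip(qs, ks, vs) for t in group])
```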
@@ -73,8 +73,8 @@ def copy_weights_gpt_neox(
"transformer.wte.weight": "gpt_neox.embed_in.weight",
"transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias",
"transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight",
"transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias",
"transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight",
"transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias",
"transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight",
"transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias",
"transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight",
"transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias",
@@ -92,7 +92,7 @@ def copy_weights_gpt_neox(
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template].format(layer_idx)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")):
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
# Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
@@ -143,7 +143,7 @@ def copy_weights_llama(
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.attn.weight"):
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
@@ -192,7 +192,7 @@ def copy_weights_gemma_2(
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.attn.weight"):
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
@@ -239,7 +239,7 @@ def copy_weights_phi(
if config.name.startswith("Phi-3"):
weight_map.update(
{
"transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight",
"transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
@@ -251,7 +251,7 @@
for from_name, param in lit_weights.items():
name_template, layer_idx = layer_template(from_name)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")):
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
if config.name.startswith("Phi-3"):
to_names = (weight_map[name_template].format(layer_idx),)
params = (param,)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_adapter.py
@@ -194,7 +194,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
"transformer.h.0.norm_1.weight",
"transformer.h.0.norm_1.bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.0.attn.attn.bias",
"transformer.h.0.attn.qkv.bias",
"transformer.h.0.attn.proj.bias",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.0.norm_2.weight",
@@ -204,7 +204,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
"transformer.h.1.norm_1.weight",
"transformer.h.1.norm_1.bias",
"transformer.h.1.attn.gating_factor",
"transformer.h.1.attn.attn.bias",
"transformer.h.1.attn.qkv.bias",
"transformer.h.1.attn.proj.bias",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.norm_2.weight",
@@ -216,11 +216,11 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
},
"torch.uint8": {
"lm_head.weight",
"transformer.h.0.attn.attn.weight",
"transformer.h.0.attn.qkv.weight",
"transformer.h.0.attn.proj.weight",
"transformer.h.0.mlp.fc.weight",
"transformer.h.0.mlp.proj.weight",
"transformer.h.1.attn.attn.weight",
"transformer.h.1.attn.qkv.weight",
"transformer.h.1.attn.proj.weight",
"transformer.h.1.mlp.fc.weight",
"transformer.h.1.mlp.proj.weight",
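The expected key sets in this test are grouped by dtype because, under bitsandbytes 4-bit quantization, the quantized `Linear` weights (including the renamed `attn.qkv.weight`) end up stored as packed `torch.uint8` buffers, while biases, norms and adapter parameters stay in `torch.float16`. A small sketch of how such a grouping could be computed (not the test's actual helper):

```python
from collections import defaultdict
from typing import Dict, Set

import torch


def group_keys_by_dtype(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Set[str]]:
    """Map dtype name -> set of parameter names, mirroring the structure asserted in the test."""
    groups: Dict[str, Set[str]] = defaultdict(set)
    for name, tensor in state_dict.items():
        groups[str(tensor.dtype)].add(name)
    return dict(groups)
```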
36 changes: 18 additions & 18 deletions tests/test_adapter_v2.py
@@ -35,10 +35,10 @@ def test_config_identical():
base_model = BaseGPT.from_name(name)
adapter_model = AdapterV2GPT.from_name(name)

assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_bias")
assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_scale")
assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_bias")
assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_scale")
assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias")
assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale")
assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias")
assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale")


def test_adapter_v2_filter(tmp_path):
@@ -58,8 +58,8 @@ def test_adapter_v2_filter(tmp_path):
}
for layer in range(3):
for param in (
"attn.attn.adapter_bias",
"attn.attn.adapter_scale",
"attn.qkv.adapter_bias",
"attn.qkv.adapter_scale",
"attn.proj.adapter_bias",
"attn.proj.adapter_scale",
"mlp.fc.adapter_bias",
@@ -366,27 +366,27 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
"torch.uint8": {
"transformer.h.0.mlp.fc.linear.weight",
"transformer.h.1.mlp.proj.linear.weight",
"transformer.h.1.attn.attn.linear.weight",
"transformer.h.1.attn.qkv.linear.weight",
"transformer.h.0.attn.proj.linear.weight",
"lm_head.linear.weight",
"transformer.h.1.attn.proj.linear.weight",
"transformer.h.0.mlp.proj.linear.weight",
"transformer.h.0.attn.attn.linear.weight",
"transformer.h.0.attn.qkv.linear.weight",
"transformer.h.1.mlp.fc.linear.weight",
},
"torch.float16": {
"transformer.h.1.attn.attn.adapter_bias",
"transformer.h.1.attn.qkv.adapter_bias",
"transformer.h.1.mlp.proj.adapter_bias",
"transformer.h.0.attn.attn.adapter_bias",
"transformer.h.0.attn.qkv.adapter_bias",
"transformer.h.0.norm_1.bias",
"transformer.h.0.attn.attn.linear.bias",
"transformer.h.0.attn.qkv.linear.bias",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.ln_f.weight",
"transformer.h.0.mlp.fc.linear.bias",
"transformer.h.0.mlp.proj.linear.bias",
"transformer.h.1.mlp.fc.linear.bias",
"transformer.h.0.attn.proj.adapter_scale",
"transformer.h.0.attn.attn.adapter_scale",
"transformer.h.0.attn.qkv.adapter_scale",
"transformer.h.1.norm_2.bias",
"transformer.h.1.attn.proj.adapter_scale",
"transformer.h.0.norm_2.bias",
@@ -408,9 +408,9 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
"lm_head.adapter_bias",
"transformer.h.1.norm_2.weight",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.1.attn.attn.adapter_scale",
"transformer.h.1.attn.qkv.adapter_scale",
"transformer.h.1.mlp.fc.adapter_scale",
"transformer.h.1.attn.attn.linear.bias",
"transformer.h.1.attn.qkv.linear.bias",
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
@@ -437,20 +437,20 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
"transformer.ln_f.bias",
"lm_head.adapter_scale",
"transformer.h.1.norm_2.weight",
"transformer.h.0.attn.attn.adapter_scale",
"transformer.h.0.attn.qkv.adapter_scale",
"transformer.h.0.mlp.proj.adapter_bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.1.norm_1.bias",
"transformer.h.1.mlp.fc.adapter_bias",
"transformer.h.1.mlp.proj.adapter_scale",
"transformer.h.0.mlp.fc.adapter_scale",
"transformer.h.1.attn.attn.adapter_bias",
"transformer.h.1.attn.qkv.adapter_bias",
"transformer.h.0.norm_2.weight",
"transformer.h.1.norm_2.bias",
"transformer.h.0.norm_1.weight",
"transformer.h.0.attn.proj.adapter_scale",
"transformer.h.1.mlp.proj.adapter_bias",
"transformer.h.0.attn.attn.adapter_bias",
"transformer.h.0.attn.qkv.adapter_bias",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.ln_f.weight",
"transformer.h.1.attn.gating_factor",
@@ -460,7 +460,7 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
"transformer.h.0.norm_1.bias",
"transformer.h.0.norm_2.bias",
"transformer.h.1.norm_1.weight",
"transformer.h.1.attn.attn.adapter_scale",
"transformer.h.1.attn.qkv.adapter_scale",
}
}

2 changes: 1 addition & 1 deletion tests/test_convert_hf_checkpoint.py
@@ -63,7 +63,7 @@ def test_llama2_70b_conversion():
# the shapes are correct
holder = {k: tuple(t.shape) for k, t in holder.items()}
assert holder == {
"transformer.h.0.attn.attn.weight": (10240, 8192),
"transformer.h.0.attn.qkv.weight": (10240, 8192),
"transformer.h.0.attn.proj.weight": (8192, 8192),
"transformer.h.0.mlp.fc_1.weight": (28672, 8192),
"transformer.h.0.mlp.fc_2.weight": (28672, 8192),
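The expected `(10240, 8192)` shape for `attn.qkv.weight` follows from Llama 2 70B's grouped-query layout; the quick arithmetic below uses configuration values quoted from memory, so treat them as an assumption:

```python
# Assumed Llama 2 70B configuration: 64 query heads, 8 key/value groups, head size 128, n_embd 8192.
n_head, n_query_groups, head_size, n_embd = 64, 8, 128, 8192

qkv_rows = (n_head + 2 * n_query_groups) * head_size
print(qkv_rows, n_embd)  # 10240 8192 -> matches the expected (10240, 8192) weight shape
```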