From 01275b4cb3d70daa7ce52e261cf1c21e3adc38eb Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 2 Sep 2024 12:59:51 +0200 Subject: [PATCH] ENH: Faster adapter loading if there are a lot of target modules (#2045) This is an optimization to reduce the number of entries in the target_modules list. The reason is that in some circumstances, target_modules can contain hundreds of entries. Since each target module is checked against each module of the net (which can be thousands), this can become quite expensive when many adapters are being added. Often, the target_modules can be condensed in such a case, which speeds up the process. A context in which this can happen is when diffusers loads non-PEFT LoRAs. As there is no meta info on target_modules in that case, they are just inferred by listing all keys from the state_dict, which can be quite a lot. See: https://github.com/huggingface/diffusers/issues/9297 As shown there the speed improvements for loading many diffusers LoRAs can be substantial. When loading 30 adapters, the time would go up from 0.6 sec per adapter to 3 sec per adapter. With this fix, the time goes up from 0.6 sec per adapter to 1 sec per adapter. As there is a small chance for undiscovered bugs, we apply this optimization only if the list of target_modules is sufficiently big. --- src/peft/tuners/tuners_utils.py | 108 ++++++++++++++++++++++++- src/peft/utils/constants.py | 6 ++ tests/test_tuners_utils.py | 135 +++++++++++++++++++++++++++++++- 3 files changed, 247 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 285b5c9410..f185236287 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -31,7 +31,13 @@ from transformers.pytorch_utils import Conv1D from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND -from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES, SEQ_CLS_HEAD_NAMES +from peft.utils.constants import ( + DUMMY_MODEL_CONFIG, + DUMMY_TARGET_MODULES, + EMBEDDING_LAYER_NAMES, + MIN_TARGET_MODULES_FOR_OPTIMIZATION, + SEQ_CLS_HEAD_NAMES, +) from peft.utils.peft_types import PeftType, TaskType from ..config import PeftConfig @@ -433,6 +439,26 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d # update peft_config.target_modules if required peft_config = _maybe_include_all_linear_layers(peft_config, model) + # This is an optimization to reduce the number of entries in the target_modules list. The reason is that in some + # circumstances, target_modules can contain hundreds of entries. Since each target module is checked against + # each module of the net (which can be thousands), this can become quite expensive when many adapters are being + # added. Often, the target_modules can be condensed in such a case, which speeds up the process. + # A context in which this can happen is when diffusers loads non-PEFT LoRAs. As there is no meta info on + # target_modules in that case, they are just inferred by listing all keys from the state_dict, which can be + # quite a lot. See: https://github.com/huggingface/diffusers/issues/9297 + # As there is a small chance for undiscovered bugs, we apply this optimization only if the list of + # target_modules is sufficiently big. 
+ if ( + isinstance(peft_config.target_modules, (list, set)) + and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION + ): + names_no_target = [ + name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules) + ] + new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target) + if len(new_target_modules) < len(peft_config.target_modules): + peft_config.target_modules = new_target_modules + for key in key_list: # Check for modules_to_save in case if _check_for_modules_to_save and any( @@ -781,6 +807,86 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device) +def _find_minimal_target_modules( + target_modules: list[str] | set[str], other_module_names: list[str] | set[str] +) -> set[str]: + """Find the minimal set of target modules that is sufficient to separate them from the other modules. + + Sometimes, a very large list of target_modules could be passed, which can slow down loading of adapters (e.g. when + loaded from diffusers). It may be possible to condense this list from hundreds of items to just a handful of + suffixes that are sufficient to distinguish the target modules from the other modules. + + Example: + ```py + >>> from peft.tuners.tuners_utils import _find_minimal_target_modules + + >>> target_modules = [f"model.decoder.layers.{i}.self_attn.q_proj" for i in range(100)] + >>> target_modules += [f"model.decoder.layers.{i}.self_attn.v_proj" for i in range(100)] + >>> other_module_names = [f"model.encoder.layers.{i}.self_attn.k_proj" for i in range(100)] + >>> _find_minimal_target_modules(target_modules, other_module_names) + {"q_proj", "v_proj"} + ``` + + Args: + target_modules (`list[str]` | `set[str]`): + The list of target modules. + other_module_names (`list[str]` | `set[str]`): + The list of other module names. They must not overlap with the target modules. + + Returns: + `set[str]`: + The minimal set of target modules that is sufficient to separate them from the other modules. + + Raises: + ValueError: + If `target_modules` is not a list or set of strings or if it contains an empty string. Also raises an error + if `target_modules` and `other_module_names` contain common elements. + """ + if isinstance(target_modules, str) or not target_modules: + raise ValueError("target_modules should be a list or set of strings.") + + target_modules = set(target_modules) + if "" in target_modules: + raise ValueError("target_modules should not contain an empty string.") + + other_module_names = set(other_module_names) + if not target_modules.isdisjoint(other_module_names): + msg = ( + "target_modules and other_module_names contain common elements, this should not happen, please " + "open a GitHub issue at https://github.com/huggingface/peft/issues with the code to reproduce this issue" + ) + raise ValueError(msg) + + # it is assumed that module name parts are separated by a "." 
+ def generate_suffixes(s): + parts = s.split(".") + return [".".join(parts[i:]) for i in range(len(parts))][::-1] + + # Create a reverse lookup for other_module_names to quickly check suffix matches + other_module_suffixes = {suffix for item in other_module_names for suffix in generate_suffixes(item)} + + # Find all potential suffixes from target_modules + target_modules_suffix_map = {item: generate_suffixes(item) for item in target_modules} + + # Initialize a set for required suffixes + required_suffixes = set() + + for item, suffixes in target_modules_suffix_map.items(): + # Go through target_modules items, shortest suffixes first + for suffix in suffixes: + # If the suffix is already in required_suffixes or matches other_module_names, skip it + if suffix in required_suffixes or suffix in other_module_suffixes: + continue + # Check if adding this suffix covers the item + if not any(item.endswith(req_suffix) for req_suffix in required_suffixes): + required_suffixes.add(suffix) + break + + if not required_suffixes: + return set(target_modules) + return required_suffixes + + def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: """A helper method to check if the passed module's key name matches any of the target modules in the adapter_config. diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index b071503d8d..281ca5ea79 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -262,3 +262,9 @@ def starcoder_model_postprocess_past_key_value(past_key_values): TOKENIZER_CONFIG_NAME = "tokenizer_config.json" DUMMY_TARGET_MODULES = "dummy-target-modules" DUMMY_MODEL_CONFIG = {"model_type": "custom"} + +# If users specify more than this number of target modules, we apply an optimization to try to reduce the target modules +# to a minimal set of suffixes, which makes loading faster. We only apply this when exceeding a certain size since +# otherwise there is no point in optimizing and there is a small chance of bugs in the optimization algorithm, so no +# point in taking unnecessary risks. See #2045 for more context. 
+MIN_TARGET_MODULES_FOR_OPTIMIZATION = 20
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index e2766ec17a..4072098493 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -42,6 +42,7 @@
     get_model_status,
     get_peft_model,
 )
+from peft.tuners.lora.layer import LoraLayer
 from peft.tuners.tuners_utils import (
     BaseTuner,
     BaseTunerLayer,
@@ -49,8 +50,11 @@
     check_target_module_exists,
     inspect_matched_modules,
 )
+from peft.tuners.tuners_utils import (
+    _find_minimal_target_modules as find_minimal_target_modules,
+)
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
-from peft.utils.constants import DUMMY_MODEL_CONFIG
+from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION
 
 from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
 
@@ -1149,3 +1153,132 @@ def test_no_warn_for_no_target_module_merge(self, recwarn):
         model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
         model_no_target_module.merge_and_unload()
         assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)
+
+
+class TestFindMinimalTargetModules:
+    @pytest.mark.parametrize(
+        "target_modules, other_module_names, expected",
+        [
+            (["bar"], [], {"bar"}),
+            (["foo"], ["bar"], {"foo"}),
+            (["1.foo", "2.foo"], ["3.foo", "4.foo"], {"1.foo", "2.foo"}),
+            # Could also return "bar.baz" but we want the shorter one
+            (["bar.baz"], ["foo.bar"], {"baz"}),
+            (["1.foo", "2.foo", "bar.baz"], ["3.foo", "bar.bla"], {"1.foo", "2.foo", "baz"}),
+            # Case with longer suffix chains and nested suffixes
+            (["a.b.c", "d.e.f", "g.h.i"], ["j.k.l", "m.n.o"], {"c", "f", "i"}),
+            (["a.b.c", "d.e.f", "g.h.i"], ["a.b.x", "d.x.f", "x.h.i"], {"c", "e.f", "g.h.i"}),
+            # Case with multiple items that can be covered by a single suffix
+            (["foo.bar.baz", "qux.bar.baz"], ["baz.bar.foo"], {"baz"}),
+            # Realistic examples
+            # Only match k_proj
+            (
+                [f"model.decoder.layers.{i}.self_attn.k_proj" for i in range(12)],
+                (
+                    [f"model.decoder.layers.{i}.self_attn" for i in range(12)]
+                    + [f"model.decoder.layers.{i}.self_attn.v_proj" for i in range(12)]
+                    + [f"model.decoder.layers.{i}.self_attn.q_proj" for i in range(12)]
+                ),
+                {"k_proj"},
+            ),
+            # Match all k_proj except the one in layer 5 => no common suffix
+            (
+                [f"model.decoder.layers.{i}.self_attn.k_proj" for i in range(12) if i != 5],
+                (
+                    ["model.decoder.layers.5.self_attn.k_proj"]
+                    + [f"model.decoder.layers.{i}.self_attn" for i in range(12)]
+                    + [f"model.decoder.layers.{i}.self_attn.v_proj" for i in range(12)]
+                    + [f"model.decoder.layers.{i}.self_attn.q_proj" for i in range(12)]
+                ),
+                {f"{i}.self_attn.k_proj" for i in range(12) if i != 5},
+            ),
+        ],
+    )
+    def test_find_minimal_target_modules(self, target_modules, other_module_names, expected):
+        # check all possible combinations of list and set
+        result = find_minimal_target_modules(target_modules, other_module_names)
+        assert result == expected
+
+        result = find_minimal_target_modules(set(target_modules), other_module_names)
+        assert result == expected
+
+        result = find_minimal_target_modules(target_modules, set(other_module_names))
+        assert result == expected
+
+        result = find_minimal_target_modules(set(target_modules), set(other_module_names))
+        assert result == expected
+
+    def test_find_minimal_target_modules_empty_raises(self):
+        with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
+            find_minimal_target_modules([], ["foo"])
+
+        with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
+            find_minimal_target_modules(set(), ["foo"])
+
+    def test_find_minimal_target_modules_contains_empty_string_raises(self):
+        target_modules = ["", "foo", "bar.baz"]
+        other_module_names = ["bar"]
+        with pytest.raises(ValueError, match="target_modules should not contain an empty string"):
+            find_minimal_target_modules(target_modules, other_module_names)
+
+    def test_find_minimal_target_modules_string_raises(self):
+        target_modules = "foo"
+        other_module_names = ["bar"]
+        with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
+            find_minimal_target_modules(target_modules, other_module_names)
+
+    @pytest.mark.parametrize(
+        "target_modules, other_module_names",
+        [
+            (["foo"], ["foo"]),
+            (["foo.bar"], ["foo.bar"]),
+            (["foo.bar", "spam", "eggs"], ["foo.bar"]),
+            (["foo.bar", "spam"], ["foo.bar", "eggs"]),
+            (["foo.bar"], ["foo.bar", "spam", "eggs"]),
+        ],
+    )
+    def test_find_minimal_target_modules_not_disjoint_raises(self, target_modules, other_module_names):
+        msg = (
+            "target_modules and other_module_names contain common elements, this should not happen, please "
+            "open a GitHub issue at https://github.com/huggingface/peft/issues with the code to reproduce this issue"
+        )
+        with pytest.raises(ValueError, match=msg):
+            find_minimal_target_modules(target_modules, other_module_names)
+
+    def test_get_peft_model_applies_find_target_modules(self):
+        # Check that when calling get_peft_model, the target_module optimization is indeed applied if the length of
+        # target_modules is big enough. The resulting model itself should be unaffected.
+        torch.manual_seed(0)
+        model_id = "facebook/opt-125m"  # must be big enough for optimization to trigger
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        # base case: specify target_modules in a minimal fashion
+        config = LoraConfig(init_lora_weights=False, target_modules=["q_proj", "v_proj"])
+        model = get_peft_model(model, config)
+
+        # this list contains all targeted modules listed separately
+        big_target_modules = [name for name, module in model.named_modules() if isinstance(module, LoraLayer)]
+        # sanity check
+        assert len(big_target_modules) > MIN_TARGET_MODULES_FOR_OPTIMIZATION
+
+        # make a "checksum" of the model for comparison
+        model_check_sum_before = sum(p.sum() for p in model.parameters())
+
+        # strip the prefix so that the names can be used as new target_modules
+        prefix_to_strip = "base_model.model.model."
+        big_target_modules = [name[len(prefix_to_strip) :] for name in big_target_modules]
+
+        del model
+
+        torch.manual_seed(0)
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        # pass the big target_modules to config
+        config = LoraConfig(init_lora_weights=False, target_modules=big_target_modules)
+        model = get_peft_model(model, config)
+
+        # check that target modules have been condensed
+        assert model.peft_config["default"].target_modules == {"q_proj", "v_proj"}
+
+        # check that the resulting model is still the same
+        model_check_sum_after = sum(p.sum() for p in model.parameters())
+        assert model_check_sum_before == model_check_sum_after
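
For reviewers who want to see the effect outside of a full model, below is a small standalone sketch (not part of the patch). It assumes an environment with this patch applied, so that `_find_minimal_target_modules` is importable; the module names are made up and the timings are only illustrative.

```py
import time

# Assumes a PEFT install that already includes this patch; the names below are
# invented purely for illustration.
from peft.tuners.tuners_utils import _find_minimal_target_modules

# A made-up list of module names, similar to what model.named_modules() would yield
module_names = [
    f"model.decoder.layers.{i}.{block}.{proj}"
    for i in range(100)
    for block in ("self_attn", "encoder_attn")
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj")
]

# "Diffusers-style" target_modules: every targeted module spelled out in full
target_modules = [n for n in module_names if n.endswith(("q_proj", "v_proj"))]


def match_all(targets):
    # the kind of suffix check that is performed for every module of the model
    return [n for n in module_names if any(n.endswith(t) for t in targets)]


tic = time.perf_counter()
matched_full = match_all(target_modules)
t_full = time.perf_counter() - tic

# Condense the long list to a minimal set of suffixes
names_no_target = [n for n in module_names if not any(n.endswith(t) for t in target_modules)]
condensed = _find_minimal_target_modules(target_modules, names_no_target)
print(condensed)  # {'q_proj', 'v_proj'}

tic = time.perf_counter()
matched_condensed = match_all(condensed)
t_condensed = time.perf_counter() - tic

# The same modules are matched, but with far fewer suffix comparisons
assert matched_full == matched_condensed
print(f"full list: {t_full:.4f}s, condensed: {t_condensed:.4f}s")
```

This is the same condensation that `inject_adapter` now performs automatically once `target_modules` has at least `MIN_TARGET_MODULES_FOR_OPTIMIZATION` entries.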