diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 2f195b2cbe..03b8531bfd 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -885,14 +885,16 @@ def generate_suffixes(s): # Initialize a set for required suffixes required_suffixes = set() - for item, suffixes in target_modules_suffix_map.items(): + # We sort the target_modules_suffix_map simply to get deterministic behavior, since sets have no order. In theory + # the order should not matter but in case there is a bug, it's better for the bug to be deterministic. + for item, suffixes in sorted(target_modules_suffix_map.items(), key=lambda tup: tup[1]): # Go through target_modules items, shortest suffixes first for suffix in suffixes: # If the suffix is already in required_suffixes or matches other_module_names, skip it if suffix in required_suffixes or suffix in other_module_suffixes: continue # Check if adding this suffix covers the item - if not any(item.endswith(req_suffix) for req_suffix in required_suffixes): + if not any(item.endswith("." + req_suffix) for req_suffix in required_suffixes): required_suffixes.add(suffix) break diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 4072098493..90dbea8d70 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1282,3 +1282,48 @@ def test_get_peft_model_applies_find_target_modules(self): # check that the resulting model is still the same model_check_after = sum(p.sum() for p in model.parameters()) assert model_check_sum_before == model_check_after + + def test_suffix_is_substring_of_other_suffix(self): + # This test is based on a real world bug found in diffusers. The issue was that we needed the suffix + # 'time_emb_proj' in the minimal target modules. However, if there already was the suffix 'proj' in the + # required_suffixes, 'time_emb_proj' would not be added because the test was `endswith(suffix)` and + # 'time_emb_proj' ends with 'proj'. 
The correct logic is to test if `endswith("." + suffix)`. The module names + # chosen here are only a subset of the hundreds of actual module names but this subset is sufficient to + # replicate the bug. + target_modules = [ + "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", + "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj", + "up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj", + "mid_block.attentions.0.proj_out", + "up_blocks.0.attentions.0.proj_out", + "down_blocks.1.attentions.0.proj_out", + "up_blocks.0.resnets.0.time_emb_proj", + "down_blocks.0.resnets.0.time_emb_proj", + "mid_block.resnets.0.time_emb_proj", + ] + other_module_names = [ + "conv_in", + "time_proj", + "time_embedding", + "time_embedding.linear_1", + "add_time_proj", + "add_embedding", + "add_embedding.linear_1", + "add_embedding.linear_2", + "down_blocks", + "down_blocks.0", + "down_blocks.0.resnets", + "down_blocks.0.resnets.0", + "up_blocks", + "up_blocks.0", + "up_blocks.0.attentions", + "up_blocks.0.attentions.0", + "up_blocks.0.attentions.0.norm", + "up_blocks.0.attentions.0.transformer_blocks", + "up_blocks.0.attentions.0.transformer_blocks.0", + "up_blocks.0.attentions.0.transformer_blocks.0.norm1", + "up_blocks.0.attentions.0.transformer_blocks.0.attn1", + ] + expected = {"time_emb_proj", "proj", "proj_out"} + result = find_minimal_target_modules(target_modules, other_module_names) + assert result == expected