Skip to content

Commit 067eab1

Browse files
Luvatasayakpaul
andauthored
Faster set_adapters (#10777)
* Update peft_utils.py * Update peft_utils.py * Update peft_utils.py --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 57ac673 commit 067eab1

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

src/diffusers/utils/peft_utils.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -257,26 +257,18 @@ def get_module_weight(weight_for_adapter, module_name):
257257

258258
return block_weight
259259

260-
# iterate over each adapter, make it active and set the corresponding scaling weight
261-
for adapter_name, weight in zip(adapter_names, weights):
262-
for module_name, module in model.named_modules():
263-
if isinstance(module, BaseTunerLayer):
264-
# For backward compatbility with previous PEFT versions
265-
if hasattr(module, "set_adapter"):
266-
module.set_adapter(adapter_name)
267-
else:
268-
module.active_adapter = adapter_name
269-
module.set_scale(adapter_name, get_module_weight(weight, module_name))
270-
271-
# set multiple active adapters
272-
for module in model.modules():
260+
for module_name, module in model.named_modules():
273261
if isinstance(module, BaseTunerLayer):
274-
# For backward compatbility with previous PEFT versions
262+
# For backward compatibility with previous PEFT versions, set multiple active adapters
275263
if hasattr(module, "set_adapter"):
276264
module.set_adapter(adapter_names)
277265
else:
278266
module.active_adapter = adapter_names
279267

268+
# Set the scaling weight for each adapter for this module
269+
for adapter_name, weight in zip(adapter_names, weights):
270+
module.set_scale(adapter_name, get_module_weight(weight, module_name))
271+
280272

281273
def check_peft_version(min_version: str) -> None:
282274
r"""

0 commit comments

Comments
 (0)