FIX Small fixes to hotswapping
A couple of smaller issues that surfaced while working on the diffusers
integration are now fixed:

- Better detection of whether the model is compiled in
  prepare_model_for_compiled_hotswap (see the sketch after this list)
- Fix handling of models that are compiled but where compilation is not
  detected (from "inside" the model)
- Handle the device of swapped-in adapter weights
- Fix wrong adapter name in the compiled diffusion model test
- Add hotswap tests for different alphas and ranks with the model not being
  compiled (Linear and Conv2d)
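
For context, a minimal, standalone sketch of the detection logic that the first two bullets refer to: a module wrapped by torch.compile exposes the original module as `_orig_mod`, while in-place compilation (nn.Module.compile()) sets `_compiled_call_impl` instead. The helper name `is_torch_compiled` below is hypothetical and not part of PEFT; it only illustrates the two checks, assuming PyTorch >= 2.2.

```python
# Hypothetical helper (not part of PEFT) illustrating the two compilation checks.
import torch
from torch import nn


def is_torch_compiled(module: nn.Module) -> bool:
    # torch.compile(module) returns an OptimizedModule wrapper that keeps the
    # original module under `_orig_mod`; module.compile() compiles in place and
    # sets `_compiled_call_impl` instead, so both attributes are checked.
    return hasattr(module, "_orig_mod") or getattr(module, "_compiled_call_impl", None) is not None


lin = nn.Linear(4, 4)
print(is_torch_compiled(lin))                 # False
print(is_torch_compiled(torch.compile(lin)))  # True
```
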
BenjaminBossan committed Feb 7, 2025
1 parent eaab05e commit 023d242
Showing 3 changed files with 131 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/peft/utils/hotswap.py
@@ -296,7 +296,7 @@ def prepare_model_for_compiled_hotswap(
         # do inference with adapter 1
         ```
     """
-    is_compiled = hasattr(model, "_orig_mod")
+    is_compiled = hasattr(model, "_orig_mod") or getattr(model, "_compiled_call_impl", False)
     if is_compiled:
         raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")

@@ -416,9 +416,17 @@ def hotswap_adapter_from_state_dict(
         # swap actual weights
         # no need to account for potential _orig_mod in key here, as torch handles that
         old_val = attrgetter(key)(model)
+        new_val = new_val.to(old_val.data.device)
+
+        # We try to detect if the model is compiled, but this does not always work, e.g. if hotswapping is called
+        # from within the model itself. In that case, swap_tensors raises a RuntimeError and we should continue
+        # without swap_tensors.
         if not is_compiled and not is_compiled_inplace:
-            torch.utils.swap_tensors(old_val, new_val)
-            continue
+            try:
+                torch.utils.swap_tensors(old_val, new_val)
+                continue
+            except RuntimeError:
+                is_compiled = True
 
         # Compiled models don't work with swap_tensors because there are weakrefs for the tensor. It is unclear
         # whether this workaround could cause trouble, but the tests indicate that it works.
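
To make the fallback in the hunk above concrete, here is a standalone sketch (not PEFT code, with purely illustrative variable names): try torch.utils.swap_tensors first and, if it raises a RuntimeError (as it can for compiled models whose parameters carry weak references), fall back to rebinding the parameter's data.

```python
# Standalone sketch of the swap-or-fallback pattern; assumes PyTorch >= 2.2 for swap_tensors.
import torch
from torch import nn

model = nn.Linear(2, 2)
old_val = model.weight                 # existing parameter inside the model
new_val = torch.randn_like(old_val)    # incoming weight, e.g. from a state_dict

try:
    # Preferred path: exchange the tensor objects so existing references stay valid.
    torch.utils.swap_tensors(old_val, new_val)
except RuntimeError:
    # swap_tensors refuses tensors that have weakrefs (as can happen for compiled
    # models); one possible fallback is to rebind the parameter's data instead.
    old_val.data = new_val
```
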
4 changes: 2 additions & 2 deletions tests/test_gpu_examples.py
@@ -4337,10 +4337,10 @@ def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings):
         unet(**dummy_input)["sample"]
 
         if do_hotswap:
-            unet.load_lora_adapter(file_name1, adapter_name="default_0", hotswap=True)
+            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
         else:
             # offloading the old and loading the new adapter will result in recompilation
-            self.set_lora_device(unet, adapter_names=["default_0"], device="cpu")
+            self.set_lora_device(unet, adapter_names=["adapter0"], device="cpu")
             unet.load_lora_adapter(file_name1, adapter_name="other_name", hotswap=False)
 
         # we need to call forward to potentially trigger recompilation
118 changes: 118 additions & 0 deletions tests/test_initialization.py
@@ -3027,6 +3027,124 @@ def test_hotswap_extra_key_raises(self, tmp_path):
         with pytest.raises(RuntimeError, match=msg):
             hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
 
+    @pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
+    def test_hotswap_works_different_ranks_alphas(self, ranks, tmp_path):
+        # same as test_hotswap_works but different rank and alpha
+        # Load 2 different adapters and check that we can hotswap between them, with the model optionally being
+        # compiled.
+        atol, rtol = 1e-4, 1e-4
+        inputs = torch.rand(3, 10).to(self.torch_device)
+
+        # create adapter 0
+        config0 = LoraConfig(target_modules=["lin0", "lin1"], r=ranks[0], lora_alpha=ranks[0], init_lora_weights=False)
+        model = self.get_model()
+        torch.manual_seed(0)
+        model = get_peft_model(model, config0)
+        model.eval()
+        with torch.inference_mode():
+            output0 = model(inputs)
+        model.save_pretrained(tmp_path / "adapter0")
+
+        del model
+
+        # create adapter 1
+        config1 = LoraConfig(target_modules=["lin0"], r=ranks[1], lora_alpha=ranks[1], init_lora_weights=False)
+        model = self.get_model()
+        torch.manual_seed(1)
+        model = get_peft_model(model, config1)
+        model.eval()
+        with torch.inference_mode():
+            output1 = model(inputs)
+        model.save_pretrained(tmp_path / "adapter1")
+
+        # sanity check: they're not the same
+        assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)
+
+        del model
+
+        # load adapter 0
+        model = self.get_model()
+        model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
+        with torch.inference_mode():
+            output_loaded0 = model(inputs)
+
+        # sanity check: same output after loading for adapter 0
+        assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)
+
+        # hotswap with adapter 1
+        hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
+        with torch.inference_mode():
+            output_loaded1 = model(inputs)
+
+        # real check: model now behaves like adapter 1
+        assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)
+
+        # hotswap back to adapter 0
+        hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
+        with torch.inference_mode():
+            output_loaded_back0 = model(inputs)
+
+        # real check: model now behaves again like adapter 0
+        assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)
+
+    @pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
+    def test_hotswap_works_different_ranks_alphas_conv2d(self, ranks, tmp_path):
+        # same as previous test, but for a Conv2d model
+        atol, rtol = 1e-4, 1e-4
+        inputs = torch.rand(3, 3, 10, 10).to(self.torch_device)
+
+        # create adapter 0
+        config0 = LoraConfig(target_modules=["conv"], r=ranks[0], init_lora_weights=False)
+        model = self.get_model_conv2d()
+        torch.manual_seed(0)
+        model = get_peft_model(model, config0)
+        model.eval()
+        with torch.inference_mode():
+            output0 = model(inputs)
+        model.save_pretrained(tmp_path / "adapter0")
+
+        del model
+
+        # create adapter 1
+        config1 = LoraConfig(target_modules=["conv"], r=ranks[1], init_lora_weights=False)
+        model = self.get_model_conv2d()
+        torch.manual_seed(1)
+        model = get_peft_model(model, config1)
+        model.eval()
+        with torch.inference_mode():
+            output1 = model(inputs)
+        model.save_pretrained(tmp_path / "adapter1")
+
+        # sanity check: they're not the same
+        assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)
+
+        del model
+
+        # load adapter 0
+        model = self.get_model_conv2d()
+        model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
+        with torch.inference_mode():
+            output_loaded0 = model(inputs)
+
+        # sanity check: same output after loading for adapter 0
+        assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)
+
+        # hotswap with adapter 1
+        hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
+        with torch.inference_mode():
+            output_loaded1 = model(inputs)
+
+        # real check: model now behaves like adapter 1
+        assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)
+
+        # hotswap back to adapter 0
+        hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
+        with torch.inference_mode():
+            output_loaded_back0 = model(inputs)
+
+        # real check: model now behaves again like adapter 0
+        assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)
+
     def test_prepare_model_for_compiled_hotswap_scalings_are_tensors(self):
         config = LoraConfig(target_modules=["lin0", "lin1"])
         model = self.get_model()
