Skip to content

Commit

Permalink
Run Ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Sep 20, 2024
1 parent 0a1a0ca commit b1999e2
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions tests/backend/quantization/gguf/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
import torch
import pytest
import torch
import torch.nn as nn

from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher
from invokeai.backend.quantization.gguf.layers import GGUFLayer
from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher

# State dict used by the tests in this file. The keys are named after the
# ``linear`` submodule of DummyModule so it can be passed to load_state_dict().
# Both tensors are read from disk when this module is imported, and the paths
# are relative, so they resolve against the current working directory.
# NOTE(review): presumably these are GGUF-quantized tensors (per the asset
# names) — confirm against how the assets were generated.
quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}


class TestGGUFPatcher(TorchPatcher):
    """TorchPatcher subclass used by the tests in this file.

    It declares a nested ``Linear`` class built on ``GGUFLayer``.
    NOTE(review): presumably ``TorchPatcher.wrap()`` substitutes this class for
    ``nn.Linear`` while the context is active — confirm in torch_patcher.
    """

    class Linear(GGUFLayer, nn.Linear):
        """Linear layer whose forward pass uses parameters from ``cast_bias_weight``."""

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # cast_bias_weight yields the pair in (weight, bias) order.
            w, b = self.cast_bias_weight(input)
            return nn.functional.linear(input, w, b)


class Test2GGUFPatcher(TorchPatcher):
    """Second TorchPatcher subclass with its own nested ``Linear`` class.

    The nested class has the same body as ``TestGGUFPatcher.Linear`` but is a
    distinct class object.
    NOTE(review): presumably used to exercise a second, independent patcher
    (e.g. nested contexts) — confirm against the tests that use it.
    """

    class Linear(GGUFLayer, nn.Linear):
        """Linear layer whose forward pass uses parameters from ``cast_bias_weight``."""

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # cast_bias_weight yields the pair in (weight, bias) order.
            w, b = self.cast_bias_weight(input)
            return nn.functional.linear(input, w, b)


# Define a dummy module for testing
class DummyModule(nn.Module):
    """Minimal module holding a single 3072 -> 18432 ``nn.Linear`` layer.

    Args:
        device: Device on which the linear layer's parameters are created.
        dtype: Data type of the linear layer's parameters.
    """

    def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # The attribute name ``linear`` determines the state-dict key prefix.
        self.linear = nn.Linear(3072, 18432, **factory_kwargs)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)


# Test that TorchPatcher patches and unpatches nn.Linear correctly
def test_torch_patcher_patches_nn_linear():
original_linear = nn.Linear
Expand All @@ -41,6 +45,7 @@ def test_torch_patcher_patches_nn_linear():
assert nn.Linear is original_linear
assert nn.Linear is original_linear


# Test that GGUFPatcher patches and unpatches nn.Linear correctly
def test_gguf_patcher_patches_nn_linear():
original_linear = nn.Linear
Expand All @@ -54,14 +59,15 @@ def test_gguf_patcher_patches_nn_linear():
# nn.Linear should be restored
assert nn.Linear is original_linear


# Test that unpatching restores the original behavior
def test_gguf_patcher_unpatch_restores_behavior():
device = 'cpu'
device = "cpu"
dtype = torch.float32

input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
model = DummyModule(device=device, dtype=dtype)
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
model.load_state_dict(quantized_sd)

with TestGGUFPatcher.wrap():
Expand All @@ -74,14 +80,15 @@ def test_gguf_patcher_unpatch_restores_behavior():
assert nn.Linear is not TestGGUFPatcher.Linear
assert isinstance(nn.Linear(4, 8), nn.Linear)


# Test that the patched Linear layer behaves as expected
def test_gguf_patcher_linear_layer_behavior():
device = 'cpu'
device = "cpu"
dtype = torch.float32

input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
model = DummyModule(device=device, dtype=dtype)
with pytest.raises(Exception):
with pytest.raises(Exception): # noqa: B017
model.load_state_dict(quantized_sd)

with TestGGUFPatcher.wrap():
Expand Down Expand Up @@ -113,4 +120,3 @@ def test_torch_patcher_nested_contexts():

# After exiting outer context, nn.Linear should be restored to original
assert nn.Linear is original_linear

0 comments on commit b1999e2

Please sign in to comment.