Initial GGUF support for flux models #6890
Merged

Changes from all commits (23 commits):
- 617485f — Initial GGUF support for flux models (brandonrising)
- 196102b — Add unit tests for torch patcher (brandonrising)
- b508899 — Run Ruff (brandonrising)
- e860726 — recognize .gguf files when scanning a folder for import
- 3278f28 — Initial experimentation with Tensor-like extension for GGUF. (RyanJDick)
- d4e4b5b — Get alternative GGUF implementation working... barely. (RyanJDick)
- 87f9668 — Add gguf as a pyproject dependency (brandonrising)
- 317b64e — Update invokeai/backend/model_manager/load/model_loaders/flux.py (brandonrising)
- 6959fb4 — Various updates to gguf performance (brandonrising)
- de09ba2 — Run ruff and fix typing in torch patcher (brandonrising)
- 4cbb54c — Run ruff and update imports (brandonrising)
- 527b05d — Remove no longer used code paths, general cleanup of new dequantizati… (brandonrising)
- 9d5aa77 — Fix type errors in GGMLTensor. (RyanJDick)
- 1a8144c — Add unit tests for GGMLTensor. (RyanJDick)
- a926fca — Add a compute_dtype field to GGMLTensor. (RyanJDick)
- 4505114 — Add workaround for FLUX GGUF models with incorrect img_in.weight shape. (RyanJDick)
- 2a48587 — Add comment describing why we're not using the meta device during pro… (brandonrising)
- fffb620 — Add __init__.py file to scripts dir for pytest (brandonrising)
- f61183a — Remove no longer used dequantize_tensor function (brandonrising)
- 0d96d53 — Ignore paths that don't exist in probe for unit tests (brandonrising)
- 7d23ba0 — Update test_probe_handles_state_dict_with_integer_keys() to make sure… (RyanJDick)
- 0c784dd — Update ui ModelFormatBadge to support GGUF. (RyanJDick)
- e6500d5 — Add unit test to confirm that GGMLTensor sizes (bytes) are being calc… (RyanJDick)
invokeai/backend/quantization/gguf/ggml_tensor.py (new file, +152 lines):

from typing import overload

import gguf
import torch

from invokeai.backend.quantization.gguf.utils import (
    DEQUANTIZE_FUNCTIONS,
    TORCH_COMPATIBLE_QTYPES,
    dequantize,
)


def dequantize_and_run(func, args, kwargs):
    """A helper function for running math ops on GGMLTensor inputs.

    Dequantizes the inputs, and runs the function.
    """
    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
    dequantized_kwargs = {
        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
    }
    return func(*dequantized_args, **dequantized_kwargs)


def apply_to_quantized_tensor(func, args, kwargs):
    """A helper function to apply a function to a quantized GGML tensor, and re-wrap the result in a GGMLTensor.

    Assumes that the first argument is a GGMLTensor.
    """
    # We expect the first argument to be a GGMLTensor, and all other arguments to be non-GGMLTensors.
    ggml_tensor = args[0]
    assert isinstance(ggml_tensor, GGMLTensor)
    assert all(not isinstance(a, GGMLTensor) for a in args[1:])
    assert all(not isinstance(v, GGMLTensor) for v in kwargs.values())

    new_data = func(ggml_tensor.quantized_data, *args[1:], **kwargs)

    if new_data.dtype != ggml_tensor.quantized_data.dtype:
        # This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
        raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")

    return GGMLTensor(
        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor.tensor_shape, ggml_tensor.compute_dtype
    )


GGML_TENSOR_OP_TABLE = {
    # Ops to run on the quantized tensor.
    torch.ops.aten.detach.default: apply_to_quantized_tensor,  # pyright: ignore
    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,  # pyright: ignore
    # Ops to run on dequantized tensors.
    torch.ops.aten.t.default: dequantize_and_run,  # pyright: ignore
    torch.ops.aten.addmm.default: dequantize_and_run,  # pyright: ignore
    torch.ops.aten.mul.Tensor: dequantize_and_run,  # pyright: ignore
}


class GGMLTensor(torch.Tensor):
    """A torch.Tensor sub-class holding a quantized GGML tensor.

    The underlying tensor is quantized, but the GGMLTensor class provides a dequantized view of the tensor on-the-fly
    when it is used in operations.
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        # Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
        return torch.Tensor._make_wrapper_subclass(  # pyright: ignore
            cls,
            data.shape,
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )

    def __init__(
        self,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        self.quantized_data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self.tensor_shape = tensor_shape
        self.compute_dtype = compute_dtype

    def __repr__(self, *, tensor_contents=None):
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape={self.tensor_shape})"

    @overload
    def size(self, dim: None = None) -> torch.Size: ...

    @overload
    def size(self, dim: int) -> int: ...

    def size(self, dim: int | None = None):
        """Return the size of the tensor after dequantization, i.e. the shape that will be used in any math ops."""
        if dim is not None:
            return self.tensor_shape[dim]
        return self.tensor_shape

    @property
    def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride]
        """The shape of the tensor after dequantization, i.e. the shape that will be used in any math ops."""
        return self.size()

    @property
    def quantized_shape(self) -> torch.Size:
        """The shape of the quantized tensor."""
        return self.quantized_data.shape

    def requires_grad_(self, mode: bool = True) -> torch.Tensor:
        """The GGMLTensor class is currently only designed for inference (not training). Setting requires_grad to
        True is not supported. This method is a no-op.
        """
        return self

    def get_dequantized_tensor(self):
        """Return the dequantized tensor, cast to the compute dtype."""
        if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
            return self.quantized_data.to(self.compute_dtype)
        elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
            # TODO(ryand): Look into how the dtype param is intended to be used.
            return dequantize(
                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
            ).to(self.compute_dtype)
        else:
            # There is no GPU implementation for this quantization type, so fall back to the numpy implementation.
            new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
            return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # We will likely hit cases here in the future where a new op is encountered that is not yet supported.
        # The new op simply needs to be added to the GGML_TENSOR_OP_TABLE.
        if func in GGML_TENSOR_OP_TABLE:
            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
        return NotImplemented
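The core idea of this file is a dispatch table: each torch op either runs directly on the quantized payload (and is re-wrapped) or triggers dequantization first. That pattern can be sketched without torch or gguf. In the minimal, hypothetical sketch below, all names (`Quantish`, `OP_TABLE`, the scale-by-10 "quantization") are illustrative only, not part of this PR; the real code keys its table on `torch.ops.aten` overloads and routes through `__torch_dispatch__`:

```python
import operator


def dequantize_and_run_op(func, args):
    """Dequantize any Quantish inputs, then run the op on plain floats."""
    return func(*[a.dequantize() if isinstance(a, Quantish) else a for a in args])


def apply_to_quantized_op(func, args):
    """Run the op directly on the quantized payload and re-wrap the result."""
    quantized = args[0]
    return Quantish(func(quantized.qdata, *args[1:]))


class Quantish:
    """Toy 'quantized' scalar: stores value * 10 as an int."""

    def __init__(self, qdata: int):
        self.qdata = qdata

    @classmethod
    def quantize(cls, value: float) -> "Quantish":
        return cls(int(value * 10))

    def dequantize(self) -> float:
        return self.qdata / 10

    def dispatch(self, func, *args):
        handler = OP_TABLE.get(func)
        if handler is None:
            return NotImplemented  # mirrors GGMLTensor's unsupported-op fallback
        return handler(func, (self, *args))


OP_TABLE = {
    operator.mul: dequantize_and_run_op,  # math ops run on dequantized values
    operator.neg: apply_to_quantized_op,  # negation is safe on this toy linear payload
}

x = Quantish.quantize(1.5)
print(x.dispatch(operator.mul, 2.0))          # 3.0
print(x.dispatch(operator.neg).dequantize())  # -1.5
```

Just as in `GGMLTensor`, an op missing from the table surfaces as `NotImplemented`, which makes unsupported operations fail loudly rather than silently computing on quantized bytes.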
A second new file (+22 lines) adds the GGUF state-dict loader:

from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
    reader = gguf.GGUFReader(path)

    sd: dict[str, GGMLTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGMLTensor(
            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
        )
    return sd