Initial GGUF support for flux models #6890

Merged Oct 2, 2024 (23 commits)

Changes from all commits

Commits
617485f
Initial GGUF support for flux models
brandonrising Sep 19, 2024
196102b
Add unit tests for torch patcher
brandonrising Sep 20, 2024
b508899
Run Ruff
brandonrising Sep 20, 2024
e860726
recognize .gguf files when scanning a folder for import
Sep 22, 2024
3278f28
Initial experimentation with Tensor-like extension for GGUF.
RyanJDick Sep 27, 2024
d4e4b5b
Get alternative GGUF implementation working... barely.
RyanJDick Sep 30, 2024
87f9668
Add gguf as a pyproject dependency
brandonrising Sep 26, 2024
317b64e
Update invokeai/backend/model_manager/load/model_loaders/flux.py
brandonrising Sep 27, 2024
6959fb4
Various updates to gguf performance
brandonrising Sep 30, 2024
de09ba2
Run ruff and fix typing in torch patcher
brandonrising Sep 30, 2024
4cbb54c
Run ruff and update imports
brandonrising Oct 1, 2024
527b05d
Remove no longer used code paths, general cleanup of new dequantizati…
brandonrising Oct 1, 2024
9d5aa77
Fix type errors in GGMLTensor.
RyanJDick Oct 1, 2024
1a8144c
Add unit tests for GGMLTensor.
RyanJDick Oct 1, 2024
a926fca
Add a compute_dtype field to GGMLTensor.
RyanJDick Oct 1, 2024
4505114
Add workaround for FLUX GGUF models with incorrect img_in.weight shape.
RyanJDick Oct 1, 2024
2a48587
Add comment describing why we're not using the meta device during pro…
brandonrising Oct 2, 2024
fffb620
Add __init__.py file to scripts dir for pytest
brandonrising Oct 2, 2024
f61183a
Remove no longer used dequantize_tensor function
brandonrising Oct 2, 2024
0d96d53
Ignore paths that don't exist in probe for unit tests
brandonrising Oct 2, 2024
7d23ba0
Update test_probe_handles_state_dict_with_integer_keys() to make sure…
RyanJDick Oct 2, 2024
0c784dd
Update ui ModelFormatBadge to support GGUF.
RyanJDick Oct 2, 2024
e6500d5
Add unit test to confirm that GGMLTensor sizes (bytes) are being calc…
RyanJDick Oct 2, 2024
6 changes: 5 additions & 1 deletion invokeai/app/invocations/flux_denoise.py
@@ -213,7 +213,11 @@ def _run_diffusion(
cached_weights=cached_weights,
)
)
elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
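For context on the sidecar path taken for GGUF (and bitsandbytes) formats above: since the LoRA delta cannot be merged into quantized weights in place, it is applied as an extra layer alongside the frozen quantized layer. The sketch below is a minimal illustration of that idea, not InvokeAI's actual sidecar implementation; the class and parameter names are hypothetical.

import torch


class LoRASidecarLinear(torch.nn.Module):
    """Hypothetical sketch: wrap a frozen (possibly quantized) linear layer and
    add a low-rank LoRA correction to its output instead of patching its weights."""

    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0):
        super().__init__()
        self.base = base    # quantized layer; its weights are left untouched
        self.down = down    # (rank, in_features)
        self.up = up        # (out_features, rank)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantized base output plus the separately computed LoRA delta.
        return self.base(x) + self.scale * (x @ self.down.T @ self.up.T)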
19 changes: 18 additions & 1 deletion invokeai/backend/model_manager/config.py
@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"


class SchedulerPredictionType(str, Enum):
@@ -197,7 +198,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""

format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
@@ -363,6 +364,21 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")


class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""

prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.GGUFQuantized

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")


class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""

@@ -466,6 +482,7 @@ def get_model_discriminator_value(v: Any) -> str:
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
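For context on how MainGGUFCheckpointConfig gets selected from the AnyModelConfig union: each member is tagged with a "<type>.<format>" string, so the new class is chosen when the discriminator resolves to the value sketched below (assuming ModelType.Main.value is "main", as used elsewhere in the codebase).

# Rough sketch of the discriminator value that routes to MainGGUFCheckpointConfig.
model_type = "main"              # ModelType.Main.value (assumed)
model_format = "gguf_quantized"  # ModelFormat.GGUFQuantized.value
tag = f"{model_type}.{model_format}"
print(tag)  # "main.gguf_quantized", matching MainGGUFCheckpointConfig.get_tag()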
49 changes: 49 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -26,6 +26,7 @@
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
@@ -35,6 +36,8 @@
from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
@@ -204,6 +207,52 @@ def _load_from_singlefile(
return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class FluxGGUFCheckpointModel(ModelLoader):
"""Class to load GGUF main models."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")

match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)

raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)

def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)

with SilenceWarnings():
model = Flux(params[config.config_path])

# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)

# HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
# We override the shape here to fix the issue.
# Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants
img_in_weight = sd.get("img_in.weight", None)
if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
expected_img_in_weight_shape = model.img_in.weight.shape
img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape)
img_in_weight.tensor_shape = expected_img_in_weight_shape

model.load_state_dict(sd, assign=True)
return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
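A brief note on the loading pattern above: the state dict produced by gguf_sd_loader contains GGMLTensor values, and load_state_dict(..., assign=True) assigns them onto the module rather than copying them into pre-allocated float parameters, which would defeat the quantization. A hedged standalone sketch of the same pattern follows; the model path is illustrative and not part of this PR.

import torch
from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Hypothetical path to a GGUF-quantized FLUX transformer.
gguf_path = Path("/models/flux1-dev-Q8_0.gguf")

sd = gguf_sd_loader(gguf_path, compute_dtype=torch.bfloat16)

# With assign=True, the GGMLTensor objects themselves become the module's
# parameters, preserving their quantized storage:
# transformer.load_state_dict(sd, assign=True)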
19 changes: 12 additions & 7 deletions invokeai/backend/model_manager/probe.py
@@ -30,6 +30,8 @@
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -187,6 +189,7 @@ def probe(
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
@@ -220,7 +223,7 @@ def get_model_name(cls, model_path: Path) -> str:

@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

if model_path.name == "learned_embeds.bin":
@@ -278,12 +281,10 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
return ModelType.SpandrelImageToImage
except spandrel.UnsupportedModelError:
pass
except RuntimeError as e:
if "No such file or directory" in str(e):
# This error is expected if the model_path does not exist (which is the case in some unit tests).
pass
else:
raise e
except Exception as e:
logger.warning(
f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}"
)

raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

@@ -408,6 +409,8 @@ def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
model = torch.load(model_path, map_location="cpu")
assert isinstance(model, dict)
return model
elif model_path.suffix.endswith(".gguf"):
return gguf_sd_loader(model_path, compute_dtype=torch.float32)
else:
return safetensors.torch.load_file(model_path)

@@ -477,6 +480,8 @@ def get_format(self) -> ModelFormat:
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
return ModelFormat.GGUFQuantized
return ModelFormat("checkpoint")

def get_variant_type(self) -> ModelVariantType:
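To summarize the detection logic spread across the probe changes above: a .gguf suffix only gets the file scanned and loaded; the format itself is decided by the tensors in the resulting state dict. A condensed sketch of that decision order (not the full get_format method, which checks additional bitsandbytes keys):

# Condensed sketch of the format decision made in get_format.
def detect_format(state_dict: dict) -> ModelFormat:
    if any("bitsandbytes__nf4" in k for k in state_dict):
        return ModelFormat.BnbQuantizednf4b
    if any(isinstance(v, GGMLTensor) for v in state_dict.values()):
        return ModelFormat.GGUFQuantized
    return ModelFormat.Checkpoint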
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/search.py
@@ -130,7 +130,7 @@ def _walk_directory(self, path: Path, max_depth: int = 20) -> None:
return

for n in file_names:
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt", ".gguf")):
try:
self.model_found(absolute_path / n)
except KeyboardInterrupt:
8 changes: 7 additions & 1 deletion invokeai/backend/model_manager/util/model_util.py
@@ -8,6 +8,8 @@
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader


def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
checkpoint = {}
@@ -54,7 +56,11 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
if str(path).endswith(".gguf"):
# The GGUF reader uses numpy memmap, so these tensors are not loaded into memory during this function
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint


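To make the memory behavior of the new branch concrete: GGUF tensors are memory-mapped by the reader, while other checkpoints are loaded onto the meta device, so read_checkpoint_meta can report keys and shapes in either case without materializing the weights. A small usage sketch, with an illustrative file path:

from invokeai.backend.model_manager.util.model_util import read_checkpoint_meta

# Hypothetical GGUF checkpoint path.
meta_sd = read_checkpoint_meta("/models/flux1-schnell-Q4_K_S.gguf")

# Keys and (dequantized) shapes are available without pulling weights into RAM.
for name, tensor in list(meta_sd.items())[:5]:
    print(name, tuple(tensor.shape))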
152 changes: 152 additions & 0 deletions invokeai/backend/quantization/gguf/ggml_tensor.py
@@ -0,0 +1,152 @@
from typing import overload

import gguf
import torch

from invokeai.backend.quantization.gguf.utils import (
DEQUANTIZE_FUNCTIONS,
TORCH_COMPATIBLE_QTYPES,
dequantize,
)


def dequantize_and_run(func, args, kwargs):
"""A helper function for running math ops on GGMLTensor inputs.

Dequantizes the inputs, and runs the function.
"""
dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
dequantized_kwargs = {
k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
}
return func(*dequantized_args, **dequantized_kwargs)


def apply_to_quantized_tensor(func, args, kwargs):
"""A helper function to apply a function to a quantized GGML tensor, and re-wrap the result in a GGMLTensor.

Assumes that the first argument is a GGMLTensor.
"""
# We expect the first argument to be a GGMLTensor, and all other arguments to be non-GGMLTensors.
ggml_tensor = args[0]
assert isinstance(ggml_tensor, GGMLTensor)
assert all(not isinstance(a, GGMLTensor) for a in args[1:])
assert all(not isinstance(v, GGMLTensor) for v in kwargs.values())

new_data = func(ggml_tensor.quantized_data, *args[1:], **kwargs)

if new_data.dtype != ggml_tensor.quantized_data.dtype:
# This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")

return GGMLTensor(
new_data, ggml_tensor._ggml_quantization_type, ggml_tensor.tensor_shape, ggml_tensor.compute_dtype
)


GGML_TENSOR_OP_TABLE = {
# Ops to run on the quantized tensor.
torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore
torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore
# Ops to run on dequantized tensors.
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
}


class GGMLTensor(torch.Tensor):
"""A torch.Tensor sub-class holding a quantized GGML tensor.

The underlying tensor is quantized, but the GGMLTensor class provides a dequantized view of the tensor on-the-fly
when it is used in operations.
"""

@staticmethod
def __new__(
cls,
data: torch.Tensor,
ggml_quantization_type: gguf.GGMLQuantizationType,
tensor_shape: torch.Size,
compute_dtype: torch.dtype,
):
# Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
return torch.Tensor._make_wrapper_subclass( # pyright: ignore
cls,
data.shape,
dtype=data.dtype,
layout=data.layout,
device=data.device,
strides=data.stride(),
storage_offset=data.storage_offset(),
)

def __init__(
self,
data: torch.Tensor,
ggml_quantization_type: gguf.GGMLQuantizationType,
tensor_shape: torch.Size,
compute_dtype: torch.dtype,
):
self.quantized_data = data
self._ggml_quantization_type = ggml_quantization_type
# The dequantized shape of the tensor.
self.tensor_shape = tensor_shape
self.compute_dtype = compute_dtype

def __repr__(self, *, tensor_contents=None):
return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self.tensor_shape})"

@overload
def size(self, dim: None = None) -> torch.Size: ...

@overload
def size(self, dim: int) -> int: ...

def size(self, dim: int | None = None):
"""Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
if dim is not None:
return self.tensor_shape[dim]
return self.tensor_shape

@property
def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
"""The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
return self.size()

@property
def quantized_shape(self) -> torch.Size:
"""The shape of the quantized tensor."""
return self.quantized_data.shape

def requires_grad_(self, mode: bool = True) -> torch.Tensor:
"""The GGMLTensor class is currently only designed for inference (not training). Setting requires_grad to True
is not supported. This method is a no-op.
"""
return self

def get_dequantized_tensor(self):
"""Return the dequantized tensor.

Args:
dtype: The dtype of the dequantized tensor.
"""
if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
return self.quantized_data.to(self.compute_dtype)
elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
# TODO(ryand): Look into how the dtype param is intended to be used.
return dequantize(
data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
).to(self.compute_dtype)
else:
# There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# We will likely hit cases here in the future where a new op is encountered that is not yet supported.
# The new op simply needs to be added to the GGML_TENSOR_OP_TABLE.
if func in GGML_TENSOR_OP_TABLE:
return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
return NotImplemented
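To see the dispatch table in action, the following hedged example wraps an ordinary float16 tensor (F16 is one of the torch-compatible quantization types, so dequantization here is just a dtype cast) and runs a multiplication through __torch_dispatch__:

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

data = torch.randn(4, 8, dtype=torch.float16)
t = GGMLTensor(
    data,
    ggml_quantization_type=gguf.GGMLQuantizationType.F16,
    tensor_shape=data.shape,
    compute_dtype=torch.bfloat16,
)

# aten.mul.Tensor is in GGML_TENSOR_OP_TABLE, so this call dequantizes the
# operand (a cast to compute_dtype for F16) and returns a regular tensor.
result = t * 2.0
print(result.dtype, tuple(result.shape))  # expected: torch.bfloat16 (4, 8)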
22 changes: 22 additions & 0 deletions invokeai/backend/quantization/gguf/loaders.py
@@ -0,0 +1,22 @@
from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)

sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
)
return sd
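A short usage sketch for the loader above (the path is illustrative): each returned value is a GGMLTensor that records its GGML quantization type and its dequantized shape, with the dimensions reversed from GGUF's ordering to torch's convention.

import torch
from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("/models/flux1-dev-Q4_K_S.gguf"), compute_dtype=torch.bfloat16)

for name, tensor in list(sd.items())[:5]:
    # tensor.shape is the dequantized shape; the quantized payload may be smaller.
    print(name, tensor._ggml_quantization_type.name, tuple(tensor.shape))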