Remove dependency on flux config files
brandonrising committed Aug 25, 2024
1 parent fa2003e commit 35069cf
Showing 7 changed files with 33 additions and 95 deletions.
7 changes: 2 additions & 5 deletions invokeai/app/invocations/model.py
@@ -11,6 +11,7 @@
invocation,
invocation_output,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
@@ -188,17 +189,13 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=flux_conf["max_seq_len"],
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
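In place of the deleted YAML read, the invocation resolves the sequence length from an in-code table. A minimal sketch of the new lookup, using only names that appear in this diff (the real code takes the key from transformer_config.config_path):

max_seq_lengths = {"flux-dev": 512, "flux-schnell": 256}

def resolve_max_seq_len(config_path: str) -> int:
    # config_path is now a bare key such as "flux-dev", not a path like "flux/flux1-dev.yaml"
    return max_seq_lengths[config_path]

assert resolve_max_seq_len("flux-schnell") == 256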
20 changes: 20 additions & 0 deletions invokeai/backend/flux/util.py
@@ -2,6 +2,7 @@

import os
from dataclasses import dataclass
from typing import Dict, Literal

from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@@ -18,6 +19,25 @@ class ModelSpec:
repo_ae: str | None


max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}


ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)


configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
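The intent is that other modules consume these module-level constants directly instead of parsing the legacy YAML files. A hedged usage sketch (assumes the InvokeAI package is importable; ModelSpec.params holding a FluxParams instance is inferred from the loader changes below):

from invokeai.backend.flux.util import ae_params, configs, max_seq_lengths

flux_params = configs["flux-dev"].params   # FluxParams for the dev transformer
seq_len = max_seq_lengths["flux-dev"]      # 512
vae_params = ae_params                     # shared AutoEncoderParams for the FLUX VAE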
43 changes: 9 additions & 34 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,19 +1,18 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""

from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
from typing import Optional

import accelerate
import torch
import yaml
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs, ae_params
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -58,17 +57,9 @@ def _load_model(
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)

dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)

with SilenceWarnings():
model = AutoEncoder(params)
model = AutoEncoder(ae_params)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
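Condensed, the new VAE load path reduces to the shared constant plus the existing state-dict load. A sketch under the assumption that model_path points at a FLUX VAE safetensors checkpoint (SilenceWarnings and dtype casting elided, helper name illustrative):

from safetensors.torch import load_file

from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params

def load_flux_vae(model_path: str) -> AutoEncoder:
    model = AutoEncoder(ae_params)  # previously built from AutoEncoderParams filtered out of the YAML file
    sd = load_file(model_path)
    model.load_state_dict(sd, assign=True)
    return model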
@@ -182,14 +173,10 @@ def _load_model(
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)

match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
return self._load_from_singlefile(config)

raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -198,16 +185,12 @@ def _load_model(
def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)

with SilenceWarnings():
model = Flux(params)
model = Flux(configs[config.config_path].params)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
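The same pattern applies to the transformer: _load_from_singlefile no longer takes a flux_conf argument, and FluxParams comes from the configs table keyed by config.config_path. A simplified sketch (helper name is illustrative, not from the diff):

from safetensors.torch import load_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs

def load_flux_transformer(config_path_key: str, model_path: str) -> Flux:
    # config_path_key is "flux-dev" or "flux-schnell", as recorded by the probe change further below
    model = Flux(configs[config_path_key].params)
    sd = load_file(model_path)
    model.load_state_dict(sd, assign=True)
    return model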
@@ -224,14 +207,10 @@ def _load_model(
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)

match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
return self._load_from_singlefile(config)

raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -240,21 +219,17 @@ def _load_model(
def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)

with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params)
model = Flux(configs[config.config_path].params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
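For the NF4-quantized variant the only difference is that the model is built with empty weights and quantized before the state dict is loaded; the FluxParams lookup is identical. A sketch with the quantize_model_nf4 import path assumed (it is called but not imported in this hunk):

import accelerate
import torch
from safetensors.torch import load_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4  # import path assumed

def load_flux_transformer_nf4(config_path_key: str, model_path: str) -> Flux:
    with accelerate.init_empty_weights():
        model = Flux(configs[config_path_key].params)
    model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
    sd = load_file(model_path)
    model.load_state_dict(sd, assign=True)
    return model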
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/probe.py
@@ -329,9 +329,9 @@ def _get_checkpoint_config_path(
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
config_file = "flux/flux1-dev.yaml"
config_file = "flux-dev"
else:
config_file = "flux/flux1-schnell.yaml"
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
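The probe's branch on guidance_in.out_layer.weight now yields one of those bare keys: FLUX.1-dev is guidance-distilled and carries the guidance_in module, while schnell does not. A one-line sketch of the selection:

def flux_config_key(state_dict: dict) -> str:
    # "guidance_in.out_layer.weight" is only present in guidance-distilled (dev-style) checkpoints
    return "flux-dev" if "guidance_in.out_layer.weight" in state_dict else "flux-schnell"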
19 changes: 0 additions & 19 deletions invokeai/configs/flux/flux1-dev.yaml

This file was deleted.

19 changes: 0 additions & 19 deletions invokeai/configs/flux/flux1-schnell.yaml

This file was deleted.

16 changes: 0 additions & 16 deletions invokeai/configs/flux/flux1-vae.yaml

This file was deleted.
