Skip to content

Commit

Permalink
feat: Stable Cascade img2img and remix (#398)
Browse files Browse the repository at this point in the history
* allows sending extra_source_images for both Image and Text gen (currently blocked on textgen)
* allows Stable Cascade img2img
* allows Stable Cascade remix
  • Loading branch information
db0 authored Mar 24, 2024
1 parent 960c041 commit 055778d
Show file tree
Hide file tree
Showing 20 changed files with 240 additions and 27 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

# 4.34.0

* Adds support for Stable Cascade img2img
* Adds support for Stable Cascade remix. Allows sending up to 5 extra images to mash together.

# 4.33.0

* When there is any potential issues with the request, the warnings key will be returned containing an array of potential issues. This should be returned to the user to inform them to potentially cancel the request in advance.
Expand Down
2 changes: 2 additions & 0 deletions horde/apis/models/kobold_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def __init__(self, api):
example="00000000-0000-0000-0000-000000000000",
),
),
"extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)),
"skipped": fields.Nested(self.response_model_generations_skipped, skip_none=True),
"softprompt": fields.String(description="The soft prompt requested for this generation."),
"model": fields.String(description="Which of the available models to use for this request."),
Expand Down Expand Up @@ -325,6 +326,7 @@ def __init__(self, api):
"from which this request is coming from."
),
),
"extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)),
"disable_batching": fields.Boolean(
default=False,
description=(
Expand Down
7 changes: 5 additions & 2 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(self, api):
"censorship",
"source_image",
"source_mask",
"extra_source_images",
"batch_index",
],
description="The relevance of the metadata field",
Expand Down Expand Up @@ -471,7 +472,7 @@ def __init__(self, api):
"source_processing": fields.String(
required=False,
default="img2img",
enum=["img2img", "inpainting", "outpainting"],
enum=["img2img", "inpainting", "outpainting", "remix"],
description="If source_image is provided, specifies how to process it.",
),
"source_mask": fields.String(
Expand All @@ -481,6 +482,7 @@ def __init__(self, api):
"If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel."
),
),
"extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)),
"r2_upload": fields.String(description="The r2 upload link to use to upload this image."),
"r2_uploads": fields.List(
fields.String(description="The r2 upload link to use to upload this image."),
Expand Down Expand Up @@ -583,7 +585,7 @@ def __init__(self, api):
"source_processing": fields.String(
required=False,
default="img2img",
enum=["img2img", "inpainting", "outpainting"],
enum=["img2img", "inpainting", "outpainting", "remix"],
description="If source_image is provided, specifies how to process it.",
),
"source_mask": fields.String(
Expand All @@ -593,6 +595,7 @@ def __init__(self, api):
"If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel."
),
),
"extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)),
"r2": fields.Boolean(
default=True,
description="If True, the image will be sent via cloudflare r2 download link.",
Expand Down
15 changes: 15 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def __init__(self):
help="Extra generate params to send to the worker.",
location="json",
)
self.generate_parser.add_argument(
"extra_source_images",
type=list,
required=False,
help="Extra images to send to the worker to processing",
location="json",
)
self.generate_parser.add_argument(
"trusted_workers",
type=bool,
Expand Down Expand Up @@ -316,6 +323,7 @@ def __init__(self, api):
"code": fields.String(description="A unique identifier for this warning.", enum=[i.name for i in WarningMessage]),
"message": fields.String(
description="Something that you should be aware about this request, in plain text.",
min_length=1,
),
},
)
Expand Down Expand Up @@ -1463,3 +1471,10 @@ def __init__(self, api):
"regex": fields.String(required=True, description="The full regex for this filter type."),
},
)
self.model_extra_source_images = api.model(
"ExtraSourceImage",
{
"image": fields.String(description="The Base64-encoded webp to use for further processing."),
"strength": fields.Float(description="Optional field, determining the strength to use for the processing", default=1.0),
},
)
10 changes: 9 additions & 1 deletion horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from horde.database import functions as database
from horde.detection import prompt_checker
from horde.flask import HORDE, cache, db
from horde.image import ensure_source_image_uploaded
from horde.limiter import limiter
from horde.logger import logger
from horde.patreon import patrons
Expand Down Expand Up @@ -325,7 +326,14 @@ def initiate_waiting_prompt(self):

# We split this into its own function, so that it may be overriden and extended
def activate_waiting_prompt(self):
self.wp.activate(self.downgrade_wp_priority)
if self.args.extra_source_images:
for iiter, eimg in enumerate(self.args.extra_source_images):
(
eimg["image"],
_,
_,
) = ensure_source_image_uploaded(eimg["image"], f"{self.wp.id}_exra_src_{iiter}", force_r2=True)
self.wp.activate(self.downgrade_wp_priority, extra_source_images=self.args.extra_source_images)


class SyncGenerate(GenerateTemplate):
Expand Down
9 changes: 6 additions & 3 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,23 @@ def get_size_too_big_message(self):
def validate(self):
super().validate()
if self.params.get("max_context_length", 1024) < self.params.get("max_length", 80):
raise e.BadRequest("You cannot request more tokens than your context length.")
raise e.BadRequest("You cannot request more tokens than your context length.", rc="TokenOverflow")
if "sampler_order" in self.params and len(set(self.params["sampler_order"])) < 7:
raise e.BadRequest(
"When sending a custom sampler order, you need to specify all possible samplers in the order",
rc="MissingFullSamplerOrder",
)
if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0:
raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.")
if "stop_sequence" in self.params:
stop_seqs = set(self.params["stop_sequence"])
if len(stop_seqs) > 128:
raise e.BadRequest("Too many stop sequences specified (max allowed is 128).")
raise e.BadRequest("Too many stop sequences specified (max allowed is 128).", rc="TooManyStopSequences")
total_stop_seq_len = 0
for seq in stop_seqs:
total_stop_seq_len += len(seq)
if total_stop_seq_len > 2000:
raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).")
raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).", rc="ExcessiveStopSequence")

def get_hashed_params_dict(self):
gen_payload = self.params.copy()
Expand Down
23 changes: 19 additions & 4 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ def validate(self):
if "samplers" in model_req_dict and self.params.get("sampler_name", "k_euler_a") not in model_req_dict["samplers"]:
self.warnings.add(WarningMessage.SamplerMismatch)
# FIXME: Scheduler workaround until we support multiple schedulers
scheduler = 'karras'
scheduler = "karras"
if not self.params.get("karras", True):
scheduler = 'simple'
scheduler = "simple"
if "schedulers" in model_req_dict and scheduler not in model_req_dict["schedulers"]:
self.warnings.add(WarningMessage.SchedulerMismatch)
if "control_type" in self.params and any(model_name in ["pix2pix"] for model_name in self.args.models):
Expand All @@ -170,8 +170,6 @@ def validate(self):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if self.args.source_image:
raise e.BadRequest("Img2Img does not work with Stable Cascade currently.", rc="Img2ImgMismatch")
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("hires fix does not work with Stable Cascade currently.", rc="HiResFixMismatch")
if "control_type" in self.params:
Expand All @@ -184,6 +182,15 @@ def validate(self):
raise e.BadRequest("explicit LoRa version requests have to be a version ID (i.e integer).", rc="BadLoraVersion")
if "tis" in self.params and len(self.params["tis"]) > 20:
raise e.BadRequest("You cannot request more than 20 Textual Inversions per generation.", rc="TooManyTIs")
if self.args.source_processing == "remix" and any(
not model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models
):
raise e.BadRequest("Image Remix is only available for Stable Cascade models.", rc="InvalidRemix")
if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0:
if len(self.args.extra_source_images) > 5:
raise e.BadRequest("You can send a maximum of 5 extra source images.", rc="TooManyExtraSourceImages.")
if self.args.source_processing != "remix":
raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.")
if self.params.get("init_as_image") and self.params.get("return_control_map"):
raise e.UnsupportedModel(
"Invalid ControlNet parameters - cannot send inital map and return the same map",
Expand Down Expand Up @@ -364,10 +371,18 @@ def activate_waiting_prompt(self):
"Inpainting requests must either include a mask, or an alpha channel.",
rc="InpaintingMissingMask",
)
if self.args.extra_source_images:
for iiter, eimg in enumerate(self.args.extra_source_images):
(
eimg["image"],
_,
_,
) = ensure_source_image_uploaded(eimg["image"], f"{self.wp.id}_exra_src_{iiter}", force_r2=True)
self.wp.activate(
downgrade_wp_priority=self.downgrade_wp_priority,
source_image=self.source_image,
source_mask=self.source_mask,
extra_source_images=self.args.extra_source_images,
)


Expand Down
1 change: 1 addition & 0 deletions horde/bridge_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

BRIDGE_CAPABILITIES = {
"AI Horde Worker reGen": {
5: {"extra_source_images"},
3: {"lora_versions"},
2: {"textual_inversion", "lora"},
1: {
Expand Down
9 changes: 9 additions & 0 deletions horde/classes/base/news.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@

class News:
HORDE_NEWS = [
{
"date_published": "2024-03-24",
"newspiece": (
"The AI Horde now supports [Stable Cascade](https://stability.ai/news/introducing-stable-cascade) along with its"
"[image variations / remix](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/#image-variations) capabilities!"
),
"tags": ["Stable Cascade", "db0", "nlnet"],
"importance": "Information",
},
{
"date_published": "2024-02-13",
"newspiece": (
Expand Down
9 changes: 8 additions & 1 deletion horde/classes/base/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from horde.flask import SQLITE_MODE, db
from horde.logger import logger
from horde.utils import get_db_uuid, get_expiry_date
from horde.bridge_reference import check_bridge_capability

procgen_classes = {
"template": ProcessingGeneration,
Expand Down Expand Up @@ -81,6 +82,7 @@ class WaitingPrompt(db.Model):

params = db.Column(MutableDict.as_mutable(json_column_type), default={}, nullable=False)
gen_payload = db.Column(MutableDict.as_mutable(json_column_type), default={}, nullable=False)
extra_source_images = db.Column(MutableDict.as_mutable(json_column_type), nullable=True)
nsfw = db.Column(db.Boolean, default=False, nullable=False)
ipaddr = db.Column(db.String(39)) # ipv6
safe_ip = db.Column(db.Boolean, default=False, nullable=False)
Expand Down Expand Up @@ -156,11 +158,13 @@ def set_models(self, model_names=None):
model_entry = WPModels(model=model, wp_id=self.id)
db.session.add(model_entry)

def activate(self, downgrade_wp_priority=False):
def activate(self, downgrade_wp_priority=False, extra_source_images=None):
"""We separate the activation from __init__ as often we want to check if there's a valid worker for it
Before we add it to the queue
"""
self.active = True
if extra_source_images is not None and len(extra_source_images) > 0:
self.extra_source_images = {"esi": extra_source_images}
if self.user.flagged and self.user.kudos > 10:
self.extra_priority = round(self.user.kudos / 1000)
elif self.user.flagged:
Expand Down Expand Up @@ -272,6 +276,9 @@ def get_pop_payload(self, procgen_list, payload):
"model": procgen_list[0].model,
"ids": [g.id for g in procgen_list],
}
if self.extra_source_images and check_bridge_capability("extra_source_images", procgen_list[0].worker.bridge_agent):
prompt_payload["extra_source_images"] = self.extra_source_images['esi']

return prompt_payload

def is_completed(self):
Expand Down
4 changes: 2 additions & 2 deletions horde/classes/kobold/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def prepare_job_payload(self, initial_dict=None):
self.gen_payload["n"] = 1
db.session.commit()

def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None):
def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None):
# We separate the activation from __init__ as often we want to check if there's a valid worker for it
# Before we add it to the queue
super().activate(downgrade_wp_priority)
super().activate(downgrade_wp_priority, extra_source_images=extra_source_images)
proxied_account = ""
if self.proxied_account:
proxied_account = f":{self.proxied_account}"
Expand Down
6 changes: 5 additions & 1 deletion horde/classes/stable/kudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ def payload_to_tensor(cls, payload):
payload["sampler_name"] if payload["sampler_name"] in KudosModel.KNOWN_SAMPLERS else "k_euler",
)
data_control_types.append(payload.get("control_type", "None"))
data_source_processing_types.append(payload.get("source_processing", "txt2img"))
sp = payload.get("source_processing", "txt2img")
# Little hack until new model is out
if sp == "remix":
sp = "img2img"
data_source_processing_types.append(sp)
data_post_processors = payload.get("post_processing", [])[:]
# logger.debug([data,data_control_types,data_source_processing_types,data_post_processors])

Expand Down
6 changes: 4 additions & 2 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def get_pop_payload(self, procgen_list, payload):
src_msk = download_source_mask(self.id)
if src_msk:
prompt_payload["source_mask"] = convert_pil_to_b64(src_msk, 50)
if self.extra_source_images and check_bridge_capability("extra_source_images", procgen.worker.bridge_agent):
prompt_payload["extra_source_images"] = self.extra_source_images['esi']
# We always ask the workers to upload the generation to R2 instead of sending it back as b64
# If they send it back as b64 anyway, we upload it outselves
prompt_payload["r2_upload"] = generate_procgen_upload_url(str(procgen.id), self.shared)
Expand All @@ -219,10 +221,10 @@ def get_pop_payload(self, procgen_list, payload):
# logger.debug([payload,prompt_payload])
return prompt_payload

def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None):
def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None):
# We separate the activation from __init__ as often we want to check if there's a valid worker for it
# Before we add it to the queue
super().activate(downgrade_wp_priority)
super().activate(downgrade_wp_priority, extra_source_images=extra_source_images)
if source_image or source_mask:
self.source_image = source_image
self.source_mask = source_mask
Expand Down
9 changes: 6 additions & 3 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@ def can_generate(self, waiting_prompt):
if waiting_prompt.source_image and not check_bridge_capability("img2img", self.bridge_agent):
return [False, "img2img"]
# logger.warning(datetime.utcnow())
if waiting_prompt.source_processing != "img2img":
if waiting_prompt.source_processing in [
"inpainting",
"outpainting",
]:
if not check_bridge_capability("inpainting", self.bridge_agent):
return [False, "painting"]
if not model_reference.has_inpainting_models(self.get_model_names()):
return [False, "models"]
if not self.allow_painting:
return [False, "painting"]
# If the only model loaded is the inpainting ones, we skip the worker when this kind of work is not required
if waiting_prompt.source_processing not in [
"inpainting",
Expand Down Expand Up @@ -124,8 +129,6 @@ def can_generate(self, waiting_prompt):
self.bridge_agent,
):
return [False, "bridge_version"]
if waiting_prompt.source_processing != "img2img" and not self.allow_painting:
return [False, "painting"]
if not waiting_prompt.safe_ip and not self.allow_unsafe_ipaddr:
return [False, "unsafe_ip"]
# We do not give untrusted workers anon or VPN generations, to avoid anything slipping by and spooking them.
Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
HORDE_VERSION = "4.33.0"
HORDE_VERSION = "4.34.0"

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
4 changes: 4 additions & 0 deletions horde/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,10 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p
ImageWaitingPrompt.source_processing.not_in(["inpainting", "outpainting"]),
worker.allow_painting == True, # noqa E712
),
or_(
ImageWaitingPrompt.extra_source_images == None, # noqa E712
check_bridge_capability("extra_source_images", worker.bridge_agent),
),
or_(
ImageWaitingPrompt.safe_ip == True, # noqa E712
worker.allow_unsafe_ipaddr == True, # noqa E712
Expand Down
7 changes: 7 additions & 0 deletions horde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@
"InvalidPriorityUsername",
"OnlyServiceAccountProxy",
"RequiresTrust",
"InvalidRemixModel",
"InvalidExtraSourceImages",
"TooManyExtraSourceImages",
"MissingFullSamplerOrder",
"TooManyStopSequences",
"ExcessiveStopSequence",
"TokenOverflow",
]


Expand Down
Loading

0 comments on commit 055778d

Please sign in to comment.