Skip to content

Commit

Permalink
allow_sdxl_controlnet
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed May 21, 2024
1 parent 1467ad7 commit 2b6482c
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

# 4.35.1

* Added allow_sdxl_controlnet worker key

# 4.35.0

* Added ability to generate QR-code images
Expand Down
12 changes: 12 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ def __init__(self):
help="If True, this worker will pick up requests requesting ControlNet.",
location="json",
)
self.job_pop_parser.add_argument(
"allow_sdxl_controlnet",
type=bool,
required=False,
default=False,
help="If True, this worker will pick up requests requesting SDXL ControlNet.",
location="json",
)
self.job_pop_parser.add_argument(
"allow_lora",
type=bool,
Expand Down Expand Up @@ -519,6 +527,10 @@ def __init__(self, api):
default=True,
description="If True, this worker will pick up requests requesting ControlNet.",
),
"allow_sdxl_controlnet": fields.Boolean(
default=True,
description="If True, this worker will pick up requests requesting SDXL ControlNet.",
),
"allow_lora": fields.Boolean(
default=True,
description="If True, this worker will pick up requests requesting LoRas.",
Expand Down
8 changes: 8 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,14 @@ def __init__(self, api):
default=None,
description="If True, this worker supports and allows lora requests.",
),
"controlnet": fields.Boolean(
default=None,
description="If True, this worker supports and allows controlnet requests.",
),
"sdxl_controlnet": fields.Boolean(
default=None,
description="If True, this worker supports and allows SDXL controlnet requests.",
),
"max_length": fields.Integer(
example=80,
description="The maximum tokens this worker can generate.",
Expand Down
3 changes: 2 additions & 1 deletion horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def validate(self):
if self.params.get("workflow") == "qr_code":
# QR-code pipeline cannot do batching currently
self.args["disable_batching"] = True
if not all(model_reference.get_model_baseline(model_name).startswith("stable diffusion 1") for model_name in self.args.models):
if not all(model_reference.get_model_baseline(model_name) in ['stable_diffusion 1', 'stable_diffusion_xl'] for model_name in self.args.models):
raise e.BadRequest("QR Code controlnet only works with SD 1.5 models currently", rc="ControlNetMismatch.")
if self.params.get("extra_texts") is None or len(self.params.get("extra_texts")) == 0:
raise e.BadRequest("This request requires you pass the required extra texts for this workflow.", rc="MissingExtraTexts.")
Expand Down Expand Up @@ -588,6 +588,7 @@ def check_in(self):
allow_unsafe_ipaddr=self.args.allow_unsafe_ipaddr,
allow_post_processing=self.args.allow_post_processing,
allow_controlnet=self.args.allow_controlnet,
allow_sdxl_controlnet=self.args.allow_sdxl_controlnet,
allow_lora=self.args.allow_lora,
priority_usernames=self.priority_usernames,
)
Expand Down
10 changes: 10 additions & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ImageWorker(Worker):
allow_painting = db.Column(db.Boolean, default=True, nullable=False)
allow_post_processing = db.Column(db.Boolean, default=True, nullable=False)
allow_controlnet = db.Column(db.Boolean, default=False, nullable=False)
allow_sdxl_controlnet = db.Column(db.Boolean, default=False, nullable=False)
allow_lora = db.Column(db.Boolean, default=False, nullable=False)
wtype = "image"

Expand All @@ -36,6 +37,7 @@ def check_in(self, max_pixels, **kwargs):
self.allow_painting = kwargs.get("allow_painting", True)
self.allow_post_processing = kwargs.get("allow_post_processing", True)
self.allow_controlnet = kwargs.get("allow_controlnet", False)
self.allow_sdxl_controlnet = kwargs.get("allow_sdxl_controlnet", False)
self.allow_lora = kwargs.get("allow_lora", False)
if len(self.get_model_names()) == 0:
self.set_models(["stable_diffusion"])
Expand Down Expand Up @@ -116,7 +118,14 @@ def can_generate(self, waiting_prompt):
if not check_bridge_capability("image_is_control", self.bridge_agent):
return [False, "bridge_version"]
if not self.allow_controlnet:
return [False, "controlnet"]
if waiting_prompt.params.get("workflow") == "qr_code":
if not check_bridge_capability("controlnet", self.bridge_agent):
return [False, "bridge_version"]
if not check_bridge_capability("qr_code", self.bridge_agent):
return [False, "bridge_version"]
if not self.allow_sdxl_controlnet:
return [False, "controlnet"]
if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent):
return [False, "bridge_version"]
if (
Expand Down Expand Up @@ -169,6 +178,7 @@ def get_details(self, details_privilege=0):
ret_dict["painting"] = self.allow_painting if check_bridge_capability("inpainting", self.bridge_agent) else False
ret_dict["post-processing"] = self.allow_post_processing
ret_dict["controlnet"] = self.allow_controlnet
ret_dict["sdxl_controlnet"] = self.allow_sdxl_controlnet
ret_dict["lora"] = self.allow_lora
return ret_dict

Expand Down
1 change: 1 addition & 0 deletions sql_statements/4.35.1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE workers ADD COLUMN allow_sdxl_controlnet BOOLEAN default false not null;
1 change: 1 addition & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"allow_unsafe_ipaddr": True,
"allow_post_processing": True,
"allow_controlnet": True,
"allow_sdxl_controlnet": True,
"allow_lora": True,
}
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
Expand Down
1 change: 1 addition & 0 deletions tests/test_image_extra_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"allow_unsafe_ipaddr": True,
"allow_post_processing": True,
"allow_controlnet": True,
"allow_sdxl_controlnet": True,
"allow_lora": True,
}
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
Expand Down

0 comments on commit 2b6482c

Please sign in to comment.