diff --git a/CHANGELOG.md b/CHANGELOG.md index b2d37d35..f066797a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +# 4.35.1 + +* Added allow_sdxl_controlnet worker key + # 4.35.0 * Added ability to generate QR-code images diff --git a/horde/apis/models/stable_v2.py b/horde/apis/models/stable_v2.py index c2bba38e..9951fc84 100644 --- a/horde/apis/models/stable_v2.py +++ b/horde/apis/models/stable_v2.py @@ -132,6 +132,14 @@ def __init__(self): help="If True, this worker will pick up requests requesting ControlNet.", location="json", ) + self.job_pop_parser.add_argument( + "allow_sdxl_controlnet", + type=bool, + required=False, + default=False, + help="If True, this worker will pick up requests requesting SDXL ControlNet.", + location="json", + ) self.job_pop_parser.add_argument( "allow_lora", type=bool, @@ -519,6 +527,10 @@ def __init__(self, api): default=True, description="If True, this worker will pick up requests requesting ControlNet.", ), + "allow_sdxl_controlnet": fields.Boolean( + default=True, + description="If True, this worker will pick up requests requesting SDXL ControlNet.", + ), "allow_lora": fields.Boolean( default=True, description="If True, this worker will pick up requests requesting LoRas.", diff --git a/horde/apis/models/v2.py b/horde/apis/models/v2.py index a8994cd6..36ea7877 100644 --- a/horde/apis/models/v2.py +++ b/horde/apis/models/v2.py @@ -620,6 +620,14 @@ def __init__(self, api): default=None, description="If True, this worker supports and allows lora requests.", ), + "controlnet": fields.Boolean( + default=None, + description="If True, this worker supports and allows controlnet requests.", + ), + "sdxl_controlnet": fields.Boolean( + default=None, + description="If True, this worker supports and allows SDXL controlnet requests.", + ), "max_length": fields.Integer( example=80, description="The maximum tokens this worker can generate.", diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 289d3603..e84ce286 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -191,7 +191,7 @@ def validate(self): if self.params.get("workflow") == "qr_code": # QR-code pipeline cannot do batching currently self.args["disable_batching"] = True - if not all(model_reference.get_model_baseline(model_name).startswith("stable diffusion 1") for model_name in self.args.models): + if not all(model_reference.get_model_baseline(model_name) in ['stable_diffusion 1', 'stable_diffusion_xl'] for model_name in self.args.models): raise e.BadRequest("QR Code controlnet only works with SD 1.5 models currently", rc="ControlNetMismatch.") if self.params.get("extra_texts") is None or len(self.params.get("extra_texts")) == 0: raise e.BadRequest("This request requires you pass the required extra texts for this workflow.", rc="MissingExtraTexts.") @@ -588,6 +588,7 @@ def check_in(self): allow_unsafe_ipaddr=self.args.allow_unsafe_ipaddr, allow_post_processing=self.args.allow_post_processing, allow_controlnet=self.args.allow_controlnet, + allow_sdxl_controlnet=self.args.allow_sdxl_controlnet, allow_lora=self.args.allow_lora, priority_usernames=self.priority_usernames, ) diff --git a/horde/classes/stable/worker.py b/horde/classes/stable/worker.py index 193225e1..316531c1 100644 --- a/horde/classes/stable/worker.py +++ b/horde/classes/stable/worker.py @@ -23,6 +23,7 @@ class ImageWorker(Worker): allow_painting = db.Column(db.Boolean, default=True, nullable=False) allow_post_processing = db.Column(db.Boolean, default=True, nullable=False) allow_controlnet = db.Column(db.Boolean, default=False, nullable=False) + allow_sdxl_controlnet = db.Column(db.Boolean, default=False, nullable=False) allow_lora = db.Column(db.Boolean, default=False, nullable=False) wtype = "image" @@ -36,6 +37,7 @@ def check_in(self, max_pixels, **kwargs): self.allow_painting = kwargs.get("allow_painting", True) self.allow_post_processing = kwargs.get("allow_post_processing", True) self.allow_controlnet = kwargs.get("allow_controlnet", False) + self.allow_sdxl_controlnet = kwargs.get("allow_sdxl_controlnet", False) self.allow_lora = kwargs.get("allow_lora", False) if len(self.get_model_names()) == 0: self.set_models(["stable_diffusion"]) @@ -116,7 +118,14 @@ def can_generate(self, waiting_prompt): if not check_bridge_capability("image_is_control", self.bridge_agent): return [False, "bridge_version"] if not self.allow_controlnet: + return [False, "controlnet"] + if waiting_prompt.params.get("workflow") == "qr_code": + if not check_bridge_capability("controlnet", self.bridge_agent): + return [False, "bridge_version"] + if not check_bridge_capability("qr_code", self.bridge_agent): return [False, "bridge_version"] + if not self.allow_sdxl_controlnet: + return [False, "controlnet"] if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent): return [False, "bridge_version"] if ( @@ -169,6 +178,7 @@ def get_details(self, details_privilege=0): ret_dict["painting"] = self.allow_painting if check_bridge_capability("inpainting", self.bridge_agent) else False ret_dict["post-processing"] = self.allow_post_processing ret_dict["controlnet"] = self.allow_controlnet + ret_dict["sdxl_controlnet"] = self.allow_sdxl_controlnet ret_dict["lora"] = self.allow_lora return ret_dict diff --git a/sql_statements/4.35.1.txt b/sql_statements/4.35.1.txt new file mode 100644 index 00000000..1b8bb39d --- /dev/null +++ b/sql_statements/4.35.1.txt @@ -0,0 +1 @@ +ALTER TABLE workers ADD COLUMN allow_sdxl_controlnet BOOLEAN default false not null; diff --git a/tests/test_image.py b/tests/test_image.py index 486fd460..6d6018e4 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -42,6 +42,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: "allow_unsafe_ipaddr": True, "allow_post_processing": True, "allow_controlnet": True, + "allow_sdxl_controlnet": True, "allow_lora": True, } pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) diff --git a/tests/test_image_extra_sources.py b/tests/test_image_extra_sources.py index 54ec1980..b6f0431a 100644 --- a/tests/test_image_extra_sources.py +++ b/tests/test_image_extra_sources.py @@ -63,6 +63,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: "allow_unsafe_ipaddr": True, "allow_post_processing": True, "allow_controlnet": True, + "allow_sdxl_controlnet": True, "allow_lora": True, } pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)