Skip to content

Commit

Permalink
feat: Horde will now return information about potential problems (#392)
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 authored Mar 16, 2024
1 parent 4a86871 commit 69b145e
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

# 4.33.0

* When there is any potential issues with the request, the warnings key will be returned containing an array of potential issues. This should be returned to the user to inform them to potentially cancel the request in advance.

# 4.32.5

* parses url-escaped values for model names
Expand Down
14 changes: 14 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from horde.exceptions import KNOWN_RC
from horde.vars import horde_noun, horde_title
from horde.enums import WarningMessage


class Parsers:
Expand Down Expand Up @@ -309,6 +310,18 @@ def __init__(self, api):
"generations": fields.List(fields.Nested(self.response_model_generation_result)),
},
)
self.response_model_warning = api.model(
"RequestSingleWarning",
{
"code": fields.String(
description="A unique identifier for this warning.",
enum=[i.name for i in WarningMessage]
),
"message": fields.String(
description="Something that you should be aware about this request, in plain text.",
),
},
)
self.response_model_async = api.model(
"RequestAsync",
{
Expand All @@ -320,6 +333,7 @@ def __init__(self, api):
default=None,
description="Any extra information from the horde about this request.",
),
"warnings": fields.List(fields.Nested(self.response_model_warning)),
},
)
self.response_model_generation_payload = api.model(
Expand Down
1 change: 1 addition & 0 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def post(self):
# I have to extract and store them this way, because if I use the defaults
# It causes them to be a shared object from the parsers class
self.params = {}
self.warnings = set()
if self.args.params:
self.params = self.args.params
self.models = []
Expand Down
19 changes: 18 additions & 1 deletion horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from horde.patreon import patrons
from horde.utils import hash_dictionary
from horde.vars import horde_title
from horde.enums import WarningMessage

models = ImageModels(api)
parsers = ImageParsers()
Expand Down Expand Up @@ -82,10 +83,13 @@ def post(self):
return ret_dict, 200
ret_dict = {
"id": self.wp.id,
"kudos": round(self.kudos),
"kudos": round(self.kudos),
}
if not database.wp_has_valid_workers(self.wp) and not settings.mode_raid():
ret_dict["message"] = self.get_size_too_big_message()
self.warnings.add(WarningMessage.NoAvailableWorker)
if len(self.warnings) > 0:
ret_dict["warnings"] = list(self.warnings)
return ret_dict, 202

def get_size_too_big_message(self):
Expand Down Expand Up @@ -131,6 +135,19 @@ def validate(self):
model_reference.get_model_baseline(model_name).startswith("stable diffusion 2") for model_name in self.args.models
):
raise e.UnsupportedModel("No current model available for this particular ControlNet for SD2.x", rc="ControlNetUnsupported")
for model_req_dict in [model_reference.get_model_requirements(m) for m in self.args.models]:
if "clip_skip" in model_req_dict and model_req_dict["clip_skip"] != self.params.get("clip_skip", 1):
self.warnings.add(WarningMessage.ClipSkipMismatch)
if "min_steps" in model_req_dict and model_req_dict["min_steps"] > self.params.get("steps", 30):
self.warnings.add(WarningMessage.StepsTooFew)
if "max_steps" in model_req_dict and model_req_dict["max_steps"] < self.params.get("steps", 30):
self.warnings.add(WarningMessage.StepsTooMany)
if "cfg_scale" in model_req_dict and model_req_dict["cfg_scale"] != self.params.get("cfg_scale", 7.5):
self.warnings.add(WarningMessage.CfgScaleMismatch)
if "min_cfg_scale" in model_req_dict and model_req_dict["min_cfg_scale"] > self.params.get("cfg_scale", 7.5):
self.warnings.add(WarningMessage.CfgScaleTooSmall)
if "max_cfg_scale" in model_req_dict and model_req_dict["max_cfg_scale"] < self.params.get("cfg_scale", 7.5):
self.warnings.add(WarningMessage.CfgScaleTooLarge)
if "control_type" in self.params and any(model_name in ["pix2pix"] for model_name in self.args.models):
raise e.UnsupportedModel("You cannot use ControlNet with these models.", rc="ControlNetUnsupported")
# if self.params.get("image_is_control"):
Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
HORDE_VERSION = "4.32.5"
HORDE_VERSION = "4.33.0"

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
23 changes: 23 additions & 0 deletions horde/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,26 @@ class UserRoleTypes(enum.Enum):
SPECIAL = 6
SERVICE = 7
EDUCATION = 8

class ReturnedEnum(enum.Enum):

@property
def code(self):
return self.name

@property
def message(self):
return self.value

class WarningMessage(ReturnedEnum):
NoAvailableWorker = (
"Warning: No available workers can fulfill this request. "
"It will expire in 20 minutes unless a worker appears. "
"Please confider reducing its size of the request or choosing a different model."
)
ClipSkipMismatch = "The clip skip specified for this generation does not match the requirements of one of the requested models."
StepsTooFew = "The steps specified for this generation are too few for this model."
StepsTooMany = "The steps specified for this generation are too many for this model."
CfgScaleMismatch = "The cfg scale specified for this generation does not match the requirements of one of the requested models."
CfgScaleTooSmall = "The cfg_scale specified for this generation is too small for this model."
CfgScaleTooLarge = "The cfg_scale specified for this generation is too large for this model."
10 changes: 4 additions & 6 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ class ModelReference(PrimaryTimedFunction):
# However due to a racing or caching issue, this causes them to still pick jobs using those models
# Need to investigate more to remove this workaround
testing_models = {
"Juggernaut XL",
"Animagine XL",
"DreamShaper XL",
"Stable Cascade 1.0",
"Anime Illust Diffusion XL",
"Pony Diffusion XL",
}

def call_function(self):
Expand Down Expand Up @@ -94,6 +88,10 @@ def get_model_baseline(self, model_name):
model_details = self.reference.get(model_name, {})
return model_details.get("baseline", "stable diffusion 1")

def get_model_requirements(self, model_name):
model_details = self.reference.get(model_name, {})
return model_details.get("requirements", {})

def get_model_csam_whitelist(self, model_name):
model_details = self.reference.get(model_name, {})
return set(model_details.get("csam_whitelist", []))
Expand Down

0 comments on commit 69b145e

Please sign in to comment.