diff --git a/CHANGELOG.md b/CHANGELOG.md index 136b7445..b0803755 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +# 4.34.0 + +* Adds support for Stable Cascade img2img +* Adds support for Stable Cascade remix. Allows sending up to 5 extra images to mash together. + # 4.33.0 * When there is any potential issues with the request, the warnings key will be returned containing an array of potential issues. This should be returned to the user to inform them to potentially cancel the request in advance. diff --git a/horde/apis/models/kobold_v2.py b/horde/apis/models/kobold_v2.py index 8c2c8a9d..75060426 100644 --- a/horde/apis/models/kobold_v2.py +++ b/horde/apis/models/kobold_v2.py @@ -259,6 +259,7 @@ def __init__(self, api): example="00000000-0000-0000-0000-000000000000", ), ), + "extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)), "skipped": fields.Nested(self.response_model_generations_skipped, skip_none=True), "softprompt": fields.String(description="The soft prompt requested for this generation."), "model": fields.String(description="Which of the available models to use for this request."), @@ -325,6 +326,7 @@ def __init__(self, api): "from which this request is coming from." ), ), + "extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)), "disable_batching": fields.Boolean( default=False, description=( diff --git a/horde/apis/models/stable_v2.py b/horde/apis/models/stable_v2.py index 1d3e18ef..cde937f6 100644 --- a/horde/apis/models/stable_v2.py +++ b/horde/apis/models/stable_v2.py @@ -182,6 +182,7 @@ def __init__(self, api): "censorship", "source_image", "source_mask", + "extra_source_images", "batch_index", ], description="The relevance of the metadata field", @@ -471,7 +472,7 @@ def __init__(self, api): "source_processing": fields.String( required=False, default="img2img", - enum=["img2img", "inpainting", "outpainting"], + enum=["img2img", "inpainting", "outpainting", "remix"], description="If source_image is provided, specifies how to process it.", ), "source_mask": fields.String( @@ -481,6 +482,7 @@ def __init__(self, api): "If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel." ), ), + "extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)), "r2_upload": fields.String(description="The r2 upload link to use to upload this image."), "r2_uploads": fields.List( fields.String(description="The r2 upload link to use to upload this image."), @@ -583,7 +585,7 @@ def __init__(self, api): "source_processing": fields.String( required=False, default="img2img", - enum=["img2img", "inpainting", "outpainting"], + enum=["img2img", "inpainting", "outpainting", "remix"], description="If source_image is provided, specifies how to process it.", ), "source_mask": fields.String( @@ -593,6 +595,7 @@ def __init__(self, api): "If this arg is not passed, the inpainting/outpainting mask has to be embedded as alpha channel." ), ), + "extra_source_images": fields.List(fields.Nested(self.model_extra_source_images)), "r2": fields.Boolean( default=True, description="If True, the image will be sent via cloudflare r2 download link.", diff --git a/horde/apis/models/v2.py b/horde/apis/models/v2.py index b6481b9b..61363cd8 100644 --- a/horde/apis/models/v2.py +++ b/horde/apis/models/v2.py @@ -37,6 +37,13 @@ def __init__(self): help="Extra generate params to send to the worker.", location="json", ) + self.generate_parser.add_argument( + "extra_source_images", + type=list, + required=False, + help="Extra images to send to the worker to processing", + location="json", + ) self.generate_parser.add_argument( "trusted_workers", type=bool, @@ -316,6 +323,7 @@ def __init__(self, api): "code": fields.String(description="A unique identifier for this warning.", enum=[i.name for i in WarningMessage]), "message": fields.String( description="Something that you should be aware about this request, in plain text.", + min_length=1, ), }, ) @@ -1463,3 +1471,10 @@ def __init__(self, api): "regex": fields.String(required=True, description="The full regex for this filter type."), }, ) + self.model_extra_source_images = api.model( + "ExtraSourceImage", + { + "image": fields.String(description="The Base64-encoded webp to use for further processing."), + "strength": fields.Float(description="Optional field, determining the strength to use for the processing", default=1.0), + }, + ) diff --git a/horde/apis/v2/base.py b/horde/apis/v2/base.py index 76e1823e..0de57d0f 100644 --- a/horde/apis/v2/base.py +++ b/horde/apis/v2/base.py @@ -28,6 +28,7 @@ from horde.database import functions as database from horde.detection import prompt_checker from horde.flask import HORDE, cache, db +from horde.image import ensure_source_image_uploaded from horde.limiter import limiter from horde.logger import logger from horde.patreon import patrons @@ -325,7 +326,14 @@ def initiate_waiting_prompt(self): # We split this into its own function, so that it may be overriden and extended def activate_waiting_prompt(self): - self.wp.activate(self.downgrade_wp_priority) + if self.args.extra_source_images: + for iiter, eimg in enumerate(self.args.extra_source_images): + ( + eimg["image"], + _, + _, + ) = ensure_source_image_uploaded(eimg["image"], f"{self.wp.id}_exra_src_{iiter}", force_r2=True) + self.wp.activate(self.downgrade_wp_priority, extra_source_images=self.args.extra_source_images) class SyncGenerate(GenerateTemplate): diff --git a/horde/apis/v2/kobold.py b/horde/apis/v2/kobold.py index 531af12f..e1cd167d 100644 --- a/horde/apis/v2/kobold.py +++ b/horde/apis/v2/kobold.py @@ -150,20 +150,23 @@ def get_size_too_big_message(self): def validate(self): super().validate() if self.params.get("max_context_length", 1024) < self.params.get("max_length", 80): - raise e.BadRequest("You cannot request more tokens than your context length.") + raise e.BadRequest("You cannot request more tokens than your context length.", rc="TokenOverflow") if "sampler_order" in self.params and len(set(self.params["sampler_order"])) < 7: raise e.BadRequest( "When sending a custom sampler order, you need to specify all possible samplers in the order", + rc="MissingFullSamplerOrder", ) + if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0: + raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.") if "stop_sequence" in self.params: stop_seqs = set(self.params["stop_sequence"]) if len(stop_seqs) > 128: - raise e.BadRequest("Too many stop sequences specified (max allowed is 128).") + raise e.BadRequest("Too many stop sequences specified (max allowed is 128).", rc="TooManyStopSequences") total_stop_seq_len = 0 for seq in stop_seqs: total_stop_seq_len += len(seq) if total_stop_seq_len > 2000: - raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).") + raise e.BadRequest("Your total stop sequence length exceeds the allowed limit (2000 chars).", rc="ExcessiveStopSequence") def get_hashed_params_dict(self): gen_payload = self.params.copy() diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 775606a0..a450557d 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -151,9 +151,9 @@ def validate(self): if "samplers" in model_req_dict and self.params.get("sampler_name", "k_euler_a") not in model_req_dict["samplers"]: self.warnings.add(WarningMessage.SamplerMismatch) # FIXME: Scheduler workaround until we support multiple schedulers - scheduler = 'karras' + scheduler = "karras" if not self.params.get("karras", True): - scheduler = 'simple' + scheduler = "simple" if "schedulers" in model_req_dict and scheduler not in model_req_dict["schedulers"]: self.warnings.add(WarningMessage.SchedulerMismatch) if "control_type" in self.params and any(model_name in ["pix2pix"] for model_name in self.args.models): @@ -170,8 +170,6 @@ def validate(self): if "control_type" in self.params: raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetMismatch") if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models): - if self.args.source_image: - raise e.BadRequest("Img2Img does not work with Stable Cascade currently.", rc="Img2ImgMismatch") if self.params.get("hires_fix", False) is True: raise e.BadRequest("hires fix does not work with Stable Cascade currently.", rc="HiResFixMismatch") if "control_type" in self.params: @@ -184,6 +182,15 @@ def validate(self): raise e.BadRequest("explicit LoRa version requests have to be a version ID (i.e integer).", rc="BadLoraVersion") if "tis" in self.params and len(self.params["tis"]) > 20: raise e.BadRequest("You cannot request more than 20 Textual Inversions per generation.", rc="TooManyTIs") + if self.args.source_processing == "remix" and any( + not model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models + ): + raise e.BadRequest("Image Remix is only available for Stable Cascade models.", rc="InvalidRemix") + if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0: + if len(self.args.extra_source_images) > 5: + raise e.BadRequest("You can send a maximum of 5 extra source images.", rc="TooManyExtraSourceImages.") + if self.args.source_processing != "remix": + raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.") if self.params.get("init_as_image") and self.params.get("return_control_map"): raise e.UnsupportedModel( "Invalid ControlNet parameters - cannot send inital map and return the same map", @@ -364,10 +371,18 @@ def activate_waiting_prompt(self): "Inpainting requests must either include a mask, or an alpha channel.", rc="InpaintingMissingMask", ) + if self.args.extra_source_images: + for iiter, eimg in enumerate(self.args.extra_source_images): + ( + eimg["image"], + _, + _, + ) = ensure_source_image_uploaded(eimg["image"], f"{self.wp.id}_exra_src_{iiter}", force_r2=True) self.wp.activate( downgrade_wp_priority=self.downgrade_wp_priority, source_image=self.source_image, source_mask=self.source_mask, + extra_source_images=self.args.extra_source_images, ) diff --git a/horde/bridge_reference.py b/horde/bridge_reference.py index f4bb9611..7344894a 100644 --- a/horde/bridge_reference.py +++ b/horde/bridge_reference.py @@ -5,6 +5,7 @@ BRIDGE_CAPABILITIES = { "AI Horde Worker reGen": { + 5: {"extra_source_images"}, 3: {"lora_versions"}, 2: {"textual_inversion", "lora"}, 1: { diff --git a/horde/classes/base/news.py b/horde/classes/base/news.py index d96b8df1..0bb78693 100644 --- a/horde/classes/base/news.py +++ b/horde/classes/base/news.py @@ -3,6 +3,15 @@ class News: HORDE_NEWS = [ + { + "date_published": "2024-03-24", + "newspiece": ( + "The AI Horde now supports [Stable Cascade](https://stability.ai/news/introducing-stable-cascade) along with its" + "[image variations / remix](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/#image-variations) capabilities!" + ), + "tags": ["Stable Cascade", "db0", "nlnet"], + "importance": "Information", + }, { "date_published": "2024-02-13", "newspiece": ( diff --git a/horde/classes/base/waiting_prompt.py b/horde/classes/base/waiting_prompt.py index 820b499e..d2c0cbad 100644 --- a/horde/classes/base/waiting_prompt.py +++ b/horde/classes/base/waiting_prompt.py @@ -15,6 +15,7 @@ from horde.flask import SQLITE_MODE, db from horde.logger import logger from horde.utils import get_db_uuid, get_expiry_date +from horde.bridge_reference import check_bridge_capability procgen_classes = { "template": ProcessingGeneration, @@ -81,6 +82,7 @@ class WaitingPrompt(db.Model): params = db.Column(MutableDict.as_mutable(json_column_type), default={}, nullable=False) gen_payload = db.Column(MutableDict.as_mutable(json_column_type), default={}, nullable=False) + extra_source_images = db.Column(MutableDict.as_mutable(json_column_type), nullable=True) nsfw = db.Column(db.Boolean, default=False, nullable=False) ipaddr = db.Column(db.String(39)) # ipv6 safe_ip = db.Column(db.Boolean, default=False, nullable=False) @@ -156,11 +158,13 @@ def set_models(self, model_names=None): model_entry = WPModels(model=model, wp_id=self.id) db.session.add(model_entry) - def activate(self, downgrade_wp_priority=False): + def activate(self, downgrade_wp_priority=False, extra_source_images=None): """We separate the activation from __init__ as often we want to check if there's a valid worker for it Before we add it to the queue """ self.active = True + if extra_source_images is not None and len(extra_source_images) > 0: + self.extra_source_images = {"esi": extra_source_images} if self.user.flagged and self.user.kudos > 10: self.extra_priority = round(self.user.kudos / 1000) elif self.user.flagged: @@ -272,6 +276,9 @@ def get_pop_payload(self, procgen_list, payload): "model": procgen_list[0].model, "ids": [g.id for g in procgen_list], } + if self.extra_source_images and check_bridge_capability("extra_source_images", procgen_list[0].worker.bridge_agent): + prompt_payload["extra_source_images"] = self.extra_source_images['esi'] + return prompt_payload def is_completed(self): diff --git a/horde/classes/kobold/waiting_prompt.py b/horde/classes/kobold/waiting_prompt.py index ab052269..0a48a189 100644 --- a/horde/classes/kobold/waiting_prompt.py +++ b/horde/classes/kobold/waiting_prompt.py @@ -57,10 +57,10 @@ def prepare_job_payload(self, initial_dict=None): self.gen_payload["n"] = 1 db.session.commit() - def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None): + def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None): # We separate the activation from __init__ as often we want to check if there's a valid worker for it # Before we add it to the queue - super().activate(downgrade_wp_priority) + super().activate(downgrade_wp_priority, extra_source_images=extra_source_images) proxied_account = "" if self.proxied_account: proxied_account = f":{self.proxied_account}" diff --git a/horde/classes/stable/kudos.py b/horde/classes/stable/kudos.py index 6381aa43..da712080 100644 --- a/horde/classes/stable/kudos.py +++ b/horde/classes/stable/kudos.py @@ -202,7 +202,11 @@ def payload_to_tensor(cls, payload): payload["sampler_name"] if payload["sampler_name"] in KudosModel.KNOWN_SAMPLERS else "k_euler", ) data_control_types.append(payload.get("control_type", "None")) - data_source_processing_types.append(payload.get("source_processing", "txt2img")) + sp = payload.get("source_processing", "txt2img") + # Little hack until new model is out + if sp == "remix": + sp = "img2img" + data_source_processing_types.append(sp) data_post_processors = payload.get("post_processing", [])[:] # logger.debug([data,data_control_types,data_source_processing_types,data_post_processors]) diff --git a/horde/classes/stable/waiting_prompt.py b/horde/classes/stable/waiting_prompt.py index 90253aaa..a552eb34 100644 --- a/horde/classes/stable/waiting_prompt.py +++ b/horde/classes/stable/waiting_prompt.py @@ -208,6 +208,8 @@ def get_pop_payload(self, procgen_list, payload): src_msk = download_source_mask(self.id) if src_msk: prompt_payload["source_mask"] = convert_pil_to_b64(src_msk, 50) + if self.extra_source_images and check_bridge_capability("extra_source_images", procgen.worker.bridge_agent): + prompt_payload["extra_source_images"] = self.extra_source_images['esi'] # We always ask the workers to upload the generation to R2 instead of sending it back as b64 # If they send it back as b64 anyway, we upload it outselves prompt_payload["r2_upload"] = generate_procgen_upload_url(str(procgen.id), self.shared) @@ -219,10 +221,10 @@ def get_pop_payload(self, procgen_list, payload): # logger.debug([payload,prompt_payload]) return prompt_payload - def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None): + def activate(self, downgrade_wp_priority=False, source_image=None, source_mask=None, extra_source_images=None): # We separate the activation from __init__ as often we want to check if there's a valid worker for it # Before we add it to the queue - super().activate(downgrade_wp_priority) + super().activate(downgrade_wp_priority, extra_source_images=extra_source_images) if source_image or source_mask: self.source_image = source_image self.source_mask = source_mask diff --git a/horde/classes/stable/worker.py b/horde/classes/stable/worker.py index 96207739..de1308ec 100644 --- a/horde/classes/stable/worker.py +++ b/horde/classes/stable/worker.py @@ -62,11 +62,16 @@ def can_generate(self, waiting_prompt): if waiting_prompt.source_image and not check_bridge_capability("img2img", self.bridge_agent): return [False, "img2img"] # logger.warning(datetime.utcnow()) - if waiting_prompt.source_processing != "img2img": + if waiting_prompt.source_processing in [ + "inpainting", + "outpainting", + ]: if not check_bridge_capability("inpainting", self.bridge_agent): return [False, "painting"] if not model_reference.has_inpainting_models(self.get_model_names()): return [False, "models"] + if not self.allow_painting: + return [False, "painting"] # If the only model loaded is the inpainting ones, we skip the worker when this kind of work is not required if waiting_prompt.source_processing not in [ "inpainting", @@ -124,8 +129,6 @@ def can_generate(self, waiting_prompt): self.bridge_agent, ): return [False, "bridge_version"] - if waiting_prompt.source_processing != "img2img" and not self.allow_painting: - return [False, "painting"] if not waiting_prompt.safe_ip and not self.allow_unsafe_ipaddr: return [False, "unsafe_ip"] # We do not give untrusted workers anon or VPN generations, to avoid anything slipping by and spooking them. diff --git a/horde/consts.py b/horde/consts.py index 76cad596..fbe0f023 100644 --- a/horde/consts.py +++ b/horde/consts.py @@ -1,4 +1,4 @@ -HORDE_VERSION = "4.33.0" +HORDE_VERSION = "4.34.0" WHITELISTED_SERVICE_IPS = { "212.227.227.178", # Turing Bot diff --git a/horde/database/functions.py b/horde/database/functions.py index 2081fe9f..0bab7a24 100644 --- a/horde/database/functions.py +++ b/horde/database/functions.py @@ -785,6 +785,10 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p ImageWaitingPrompt.source_processing.not_in(["inpainting", "outpainting"]), worker.allow_painting == True, # noqa E712 ), + or_( + ImageWaitingPrompt.extra_source_images == None, # noqa E712 + check_bridge_capability("extra_source_images", worker.bridge_agent), + ), or_( ImageWaitingPrompt.safe_ip == True, # noqa E712 worker.allow_unsafe_ipaddr == True, # noqa E712 diff --git a/horde/exceptions.py b/horde/exceptions.py index ba6fda81..6ad64738 100644 --- a/horde/exceptions.py +++ b/horde/exceptions.py @@ -133,6 +133,13 @@ "InvalidPriorityUsername", "OnlyServiceAccountProxy", "RequiresTrust", + "InvalidRemixModel", + "InvalidExtraSourceImages", + "TooManyExtraSourceImages", + "MissingFullSamplerOrder", + "TooManyStopSequences", + "ExcessiveStopSequence", + "TokenOverflow", ] diff --git a/sql_statements/4.34.0.txt b/sql_statements/4.34.0.txt new file mode 100644 index 00000000..6c4b9cc1 --- /dev/null +++ b/sql_statements/4.34.0.txt @@ -0,0 +1 @@ +ALTER TABLE waiting_prompts ADD COLUMN extra_source_images JSONB; diff --git a/tests/test_image.py b/tests/test_image.py index fb7fab31..2e2bb901 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -12,10 +12,13 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: "r2": True, "shared": True, "trusted_workers": True, - "width": 1024, - "height": 1024, - "steps": 8, - "cfg_scale": 1.5, + "params": { + "width": 1024, + "height": 1024, + "steps": 8, + "cfg_scale": 1.5, + "sampler_name": "k_euler_a", + }, "sampler_name": "k_euler_a", "models": TEST_MODELS, "loras": [{"name": "247778", "is_version": True}], @@ -30,7 +33,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: # print(async_results) pop_dict = { "name": "CICD Fake Dreamer", - "models": ["Fustercluck", "AlbedoBase XL (SDXL)"], + "models": TEST_MODELS, "bridge_agent": "AI Horde Worker reGen:4.1.0-citests:https://github.com/Haidra-Org/horde-worker-reGen", "amount": 10, "max_pixels": 4194304, @@ -42,12 +45,22 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: "allow_lora": True, } pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) - assert pop_req.ok, pop_req.text + try: + assert pop_req.ok, pop_req.text + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err pop_results = pop_req.json() # print(json.dumps(pop_results, indent=4)) job_id = pop_results["id"] - assert job_id is not None, pop_results + try: + assert job_id is not None, pop_results + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err submit_dict = { "id": job_id, "generation": "R2", diff --git a/tests/test_image_extra_sources.py b/tests/test_image_extra_sources.py new file mode 100644 index 00000000..a8a75451 --- /dev/null +++ b/tests/test_image_extra_sources.py @@ -0,0 +1,111 @@ +import requests +import json +from PIL import Image +from io import BytesIO +import base64 + +def load_image_as_b64(image_path): + final_src_img = Image.open(image_path) + buffer = BytesIO() + final_src_img.save(buffer, format="Webp", quality=50, exact=True) + return base64.b64encode(buffer.getvalue()).decode("utf8") + +TEST_MODELS = ["Stable Cascade 1.0"] + + +def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None: + headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user + async_dict = { + "prompt": "A remix", + "nsfw": True, + "censor_nsfw": False, + "r2": True, + "shared": True, + "trusted_workers": True, + "params": { + "width": 1024, + "height": 1024, + "steps": 20, + "cfg_scale": 4, + "sampler_name": "k_euler_a", + }, + "models": TEST_MODELS, + "source_image": load_image_as_b64('img_stable/0.jpg'), + "source_processing": "remix", + "extra_source_images": [ + { + "image": load_image_as_b64('img_stable/1.jpg'), + "strength": 0.5, + }, + { + "image": load_image_as_b64('img_stable/2.jpg'), + }, + ], + } + protocol = "http" + if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]: + protocol = "https" + async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/async", json=async_dict, headers=headers) + assert async_req.ok, async_req.text + async_results = async_req.json() + req_id = async_results["id"] + # print(async_results) + pop_dict = { + "name": "CICD Fake Dreamer", + "models": TEST_MODELS, + "bridge_agent": "AI Horde Worker reGen:5.3.0-citests:https://github.com/Haidra-Org/horde-worker-reGen", + "amount": 10, + "max_pixels": 4194304, + "allow_img2img": True, + "allow_painting": True, + "allow_unsafe_ipaddr": True, + "allow_post_processing": True, + "allow_controlnet": True, + "allow_lora": True, + } + pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers) + try: + assert pop_req.ok, pop_req.text + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + + pop_results = pop_req.json() + # print(json.dumps(pop_results, indent=4)) + + job_id = pop_results["id"] + try: + assert job_id is not None, pop_results + except AssertionError as err: + requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + print("Request cancelled") + raise err + submit_dict = { + "id": job_id, + "generation": "R2", + "state": "ok", + "seed": 0, + } + submit_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/submit", json=submit_dict, headers=headers) + assert submit_req.ok, submit_req.text + submit_results = submit_req.json() + assert submit_results["reward"] > 0 + retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers) + assert retrieve_req.ok, retrieve_req.text + retrieve_results = retrieve_req.json() + # print(json.dumps(retrieve_results,indent=4)) + assert len(retrieve_results["generations"]) == 1 + gen = retrieve_results["generations"][0] + assert len(gen["gen_metadata"]) == 0 + assert gen["seed"] == "0" + assert gen["worker_name"] == "CICD Fake Dreamer" + assert gen["model"] in TEST_MODELS + assert gen["state"] == "ok" + assert retrieve_results["kudos"] > 1 + assert retrieve_results["done"] is True + + +if __name__ == "__main__": + # "ci/cd#12285" + test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")