Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Sep 13, 2024
1 parent 5998aaa commit 30e05f4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
15 changes: 15 additions & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,21 @@ def can_generate(self, waiting_prompt):
self.bridge_agent,
waiting_prompt.gen_payload.get("karras", False),
):
logger.debug("bridge_version")
return [False, "bridge_version"]
# logger.warning(datetime.utcnow())
if len(waiting_prompt.gen_payload.get("post_processing", [])) >= 1 and not check_bridge_capability(
"post-processing",
self.bridge_agent,
):
logger.debug("bridge_version")
return [False, "bridge_version"]
for pp in KNOWN_POST_PROCESSORS:
if pp in waiting_prompt.gen_payload.get("post_processing", []) and not check_bridge_capability(
pp,
self.bridge_agent,
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if waiting_prompt.source_image and not self.allow_img2img:
return [False, "img2img"]
Expand All @@ -112,47 +115,59 @@ def can_generate(self, waiting_prompt):
):
return [False, "models"]
if waiting_prompt.params.get("tiling") and not check_bridge_capability("tiling", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if waiting_prompt.params.get("return_control_map") and not check_bridge_capability(
"return_control_map",
self.bridge_agent,
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if waiting_prompt.params.get("control_type"):
if not check_bridge_capability("controlnet", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if not check_bridge_capability("image_is_control", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if not self.allow_controlnet:
logger.debug("bridge_version")
return [False, "controlnet"]
if waiting_prompt.params.get("workflow") == "qr_code":
if not check_bridge_capability("controlnet", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if not check_bridge_capability("qr_code", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if "stable_diffusion_xl" in model_reference.get_all_model_baselines(self.get_model_names()) and not self.allow_sdxl_controlnet:
return [False, "controlnet"]
if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent):
logger.debug("bridge_version")
return [False, "bridge_version"]
if (
waiting_prompt.params.get("hires_fix")
and "stable_cascade" in model_reference.get_all_model_baselines(self.get_model_names())
and not check_bridge_capability("stable_cascade_2pass", self.bridge_agent)
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if "flux_1" in model_reference.get_all_model_baselines(self.get_model_names()) and not check_bridge_capability(
"flux", self.bridge_agent
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if waiting_prompt.params.get("clip_skip", 1) > 1 and not check_bridge_capability(
"clip_skip",
self.bridge_agent,
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if any(lora.get("is_version") for lora in waiting_prompt.params.get("loras", [])) and not check_bridge_capability(
"lora_versions",
self.bridge_agent,
):
logger.debug("bridge_version")
return [False, "bridge_version"]
if not waiting_prompt.safe_ip and not self.allow_unsafe_ipaddr:
return [False, "unsafe_ip"]
Expand Down
39 changes: 35 additions & 4 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert pop_results["id"] is None, pop_results
assert pop_results["skipped"].get("step_count") == 1, pop_results
except AssertionError as err:
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
print("Request cancelled")
# requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
# print("Request cancelled")
raise err

# Test extra_slow_worker
Expand Down Expand Up @@ -233,7 +233,38 @@ def test_flux_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
assert retrieve_results["done"] is True
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)

def quick_pop(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
print("quick_pop")
headers = {"apikey": api_key, "Client-Agent": f"aihorde_ci_client:{CIVERSION}:(discord)db0#1625"} # ci/cd user
protocol = "http"
if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]:
protocol = "https"
# print(async_results)
pop_dict = {
"name": "CICD Fake Dreamer",
"models": TEST_MODELS_FLUX,
"bridge_agent": "AI Horde Worker reGen:9.1.0-citests:https://github.com/Haidra-Org/horde-worker-reGen",
"nsfw": True,
"amount": 10,
"max_pixels": 4194304,
"allow_img2img": True,
"allow_painting": True,
"allow_unsafe_ipaddr": True,
"allow_post_processing": True,
"allow_controlnet": True,
"allow_sdxl_controlnet": True,
"allow_lora": True,
"extra_slow_worker": False,
"limit_max_steps": True,
}

# Test limit_max_steps
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
print(pop_req.text)


if __name__ == "__main__":
# "ci/cd#12285"
test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
test_flux_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
# test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0")
# test_flux_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0")
quick_pop("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.2.0")

0 comments on commit 30e05f4

Please sign in to comment.