Skip to content

Commit

Permalink
allow tests to run against dev.stablehorde.net
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Feb 22, 2024
1 parent 93d71b9 commit 66e9b5e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
15 changes: 9 additions & 6 deletions tests/test_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def test_simple_alchemy(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
],
"source_image": "https://github.com/Haidra-Org/AI-Horde/blob/main/icon.png?raw=true",
}
async_req = requests.post(f"http://{HORDE_URL}/api/v2/interrogate/async", json=async_dict, headers=headers)
protocol = "http"
if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]:
protocol = "https"
async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/interrogate/async", json=async_dict, headers=headers)
assert async_req.ok, async_req.text
async_results = async_req.json()
req_id = async_results["id"]
Expand All @@ -21,9 +24,9 @@ def test_simple_alchemy(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"max_tiles": 96,
}
try:
pop_req = requests.post(f"http://{HORDE_URL}/api/v2/interrogate/pop", json=pop_dict, headers=headers)
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/interrogate/pop", json=pop_dict, headers=headers)
except Exception:
requests.delete(f"http://{HORDE_URL}/api/v2/interrogate/status/{req_id}", headers=headers)
requests.delete(f"{protocol}://{HORDE_URL}/api/v2/interrogate/status/{req_id}", headers=headers)
raise
assert pop_req.ok, pop_req.text
pop_results = pop_req.json()
Expand All @@ -36,11 +39,11 @@ def test_simple_alchemy(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"result": {"caption": "Test"},
"state": "ok",
}
submit_req = requests.post(f"http://{HORDE_URL}/api/v2/interrogate/submit", json=submit_dict, headers=headers)
submit_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/interrogate/submit", json=submit_dict, headers=headers)
assert submit_req.ok, submit_req.text
submit_results = submit_req.json()
assert submit_results["reward"] > 0
retrieve_req = requests.get(f"http://{HORDE_URL}/api/v2/interrogate/status/{req_id}", headers=headers)
retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/interrogate/status/{req_id}", headers=headers)
assert retrieve_req.ok, retrieve_req.text
retrieve_results = retrieve_req.json()
# print(json.dumps(retrieve_results,indent=4))
Expand All @@ -57,4 +60,4 @@ def test_simple_alchemy(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:


if __name__ == "__main__":
test_simple_alchemy()
test_simple_alchemy("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
13 changes: 8 additions & 5 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"models": TEST_MODELS,
"loras": [{"name": "247778", "is_version": True}],
}
async_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/async", json=async_dict, headers=headers)
protocol = "http"
if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]:
protocol = "https"
async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/async", json=async_dict, headers=headers)
assert async_req.ok, async_req.text
async_results = async_req.json()
req_id = async_results["id"]
Expand All @@ -38,7 +41,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"allow_controlnet": True,
"allow_lora": True,
}
pop_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
assert pop_req.ok, pop_req.text
pop_results = pop_req.json()
# print(json.dumps(pop_results, indent=4))
Expand All @@ -51,11 +54,11 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"state": "ok",
"seed": 0,
}
submit_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/submit", json=submit_dict, headers=headers)
submit_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/submit", json=submit_dict, headers=headers)
assert submit_req.ok, submit_req.text
submit_results = submit_req.json()
assert submit_results["reward"] > 0
retrieve_req = requests.get(f"http://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/status/{req_id}", headers=headers)
assert retrieve_req.ok, retrieve_req.text
retrieve_results = retrieve_req.json()
# print(json.dumps(retrieve_results,indent=4))
Expand All @@ -71,4 +74,4 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:


if __name__ == "__main__":
test_simple_image_gen()
test_simple_image_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")
13 changes: 8 additions & 5 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"temperature": 1,
"models": TEST_MODELS,
}
async_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/text/async", json=async_dict, headers=headers)
protocol = "http"
if HORDE_URL in ["dev.stablehorde.net", "stablehorde.net"]:
protocol = "https"
async_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/text/async", json=async_dict, headers=headers)
assert async_req.ok, async_req.text
async_results = async_req.json()
req_id = async_results["id"]
Expand All @@ -26,7 +29,7 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"max_context_length": 4096,
"max_length": 512,
}
pop_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/text/pop", json=pop_dict, headers=headers)
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/text/pop", json=pop_dict, headers=headers)
assert pop_req.ok, pop_req.text
pop_results = pop_req.json()
# print(json.dumps(pop_results, indent=4))
Expand All @@ -39,11 +42,11 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"state": "ok",
"seed": 0,
}
submit_req = requests.post(f"http://{HORDE_URL}/api/v2/generate/text/submit", json=submit_dict, headers=headers)
submit_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/text/submit", json=submit_dict, headers=headers)
assert submit_req.ok, submit_req.text
submit_results = submit_req.json()
assert submit_results["reward"] > 0
retrieve_req = requests.get(f"http://{HORDE_URL}/api/v2/generate/text/status/{req_id}", headers=headers)
retrieve_req = requests.get(f"{protocol}://{HORDE_URL}/api/v2/generate/text/status/{req_id}", headers=headers)
assert retrieve_req.ok, retrieve_req.text
retrieve_results = retrieve_req.json()
# print(json.dumps(retrieve_results,indent=4))
Expand All @@ -59,4 +62,4 @@ def test_simple_text_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:


if __name__ == "__main__":
test_simple_text_gen()
test_simple_text_gen("2bc5XkMeLAWiN9O5s7bhfg", "dev.stablehorde.net", "0.1.1")

0 comments on commit 66e9b5e

Please sign in to comment.