Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various tweaks around Flux onboarding #451

Merged
merged 46 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d03cd56
fix: prevent workers without flux support picking up flux jobs
db0 Sep 12, 2024
c194d62
feat: adjusted TTL formula to be algorithmic
db0 Sep 12, 2024
a49fdf0
wip
db0 Sep 12, 2024
acb31be
reporting missing requests due to step count
db0 Sep 13, 2024
b0c6615
fix: missing orm declaration:
db0 Sep 13, 2024
5b81fb0
fix: missing orm declaration
db0 Sep 13, 2024
a63bad8
feat: support betas
db0 Sep 13, 2024
9c591e9
test
db0 Sep 13, 2024
b6dc735
test
db0 Sep 13, 2024
c2f6bc4
test
db0 Sep 13, 2024
ba20af2
test
db0 Sep 13, 2024
c6db21f
test
db0 Sep 13, 2024
76706e7
test
db0 Sep 13, 2024
7805b6c
test
db0 Sep 13, 2024
0497891
test
db0 Sep 13, 2024
f6bf397
fix
db0 Sep 13, 2024
541d32a
test
db0 Sep 13, 2024
39d1585
fix
db0 Sep 13, 2024
9e7fdcb
extra slow workers count
db0 Sep 13, 2024
b98922a
test
db0 Sep 13, 2024
ae80b12
test
db0 Sep 13, 2024
dc7895e
test
db0 Sep 13, 2024
7935adb
test
db0 Sep 13, 2024
64a102a
test
db0 Sep 13, 2024
64b21df
test
db0 Sep 13, 2024
daceb7e
test
db0 Sep 13, 2024
5998aaa
test
db0 Sep 13, 2024
30e05f4
test
db0 Sep 13, 2024
97dbfa1
test
db0 Sep 13, 2024
b752f1f
test
db0 Sep 13, 2024
727cba1
test
db0 Sep 13, 2024
ce6ada5
test
db0 Sep 13, 2024
ef38f73
flush
db0 Sep 13, 2024
e7d96bd
bad use of flush
db0 Sep 13, 2024
e137975
bad use of flush
db0 Sep 13, 2024
7227bcc
test
db0 Sep 13, 2024
9b78eb6
test
db0 Sep 13, 2024
beb69d3
test
db0 Sep 13, 2024
ef9ca96
test
db0 Sep 13, 2024
d81082c
lint
db0 Sep 13, 2024
e79453f
lint
db0 Sep 13, 2024
ef065ef
removed_debugs
db0 Sep 13, 2024
2871b33
removed_debugs
db0 Sep 13, 2024
e497e9e
fix: not_nulls
db0 Sep 13, 2024
5ccf1f9
avoid_timeouts_on_testys
db0 Sep 13, 2024
aa4d0d1
avoid_timeouts_on_testys
db0 Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.43.0

* Adjusted TTL formula to be algorithmic
* prevent workers without flux support picking up flux jobs
* Adds `extra_slow_workers` bool for image gen async
* Adds `extra_slow_worker` bool for worker pop
* Adds `limit_max_steps` for worker pop

# 4.42.0

* Adds support for the Flux family of models
Expand Down
6 changes: 6 additions & 0 deletions horde/apis/models/kobold_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ def __init__(self, api):
"The request will include the details of the job as well as the request ID."
),
),
"extra_slow_workers": fields.Boolean(
default=False,
description=(
"When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot."
),
),
},
)
self.response_model_contrib_details = api.inherit(
Expand Down
24 changes: 24 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ def __init__(self):
help="If True, this worker will pick up requests requesting LoRas.",
location="json",
)
self.job_pop_parser.add_argument(
"limit_max_steps",
type=bool,
required=False,
default=False,
help="If True, This worker will not pick up jobs with more steps than the average allowed for that model.",
location="json",
)
self.job_submit_parser.add_argument(
"seed",
type=int,
Expand Down Expand Up @@ -451,6 +459,9 @@ def __init__(self, api):
"max_pixels": fields.Integer(
description="How many waiting requests were skipped because they demanded a higher size than this worker provides.",
),
"step_count": fields.Integer(
description="How many waiting requests were skipped because they demanded a higher step count that the worker wants.",
),
"unsafe_ip": fields.Integer(
description="How many waiting requests were skipped because they came from an unsafe IP.",
),
Expand Down Expand Up @@ -544,6 +555,13 @@ def __init__(self, api):
default=True,
description="If True, this worker will pick up requests requesting LoRas.",
),
"limit_max_steps": fields.Boolean(
default=True,
description=(
"If True, This worker will not pick up jobs with more steps than the average allowed for that model."
" this is for use by workers which might run into issues doing too many steps."
),
),
},
)
self.input_model_job_submit = api.inherit(
Expand Down Expand Up @@ -591,6 +609,12 @@ def __init__(self, api):
default=True,
description="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
),
"extra_slow_workers": fields.Boolean(
default=False,
description=(
"When True, allows very slower workers to pick up this request. " "Use this when you don't mind waiting a lot."
),
),
"censor_nsfw": fields.Boolean(
default=False,
description="If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image.",
Expand Down
22 changes: 22 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def __init__(self):
help="When True, allows slower workers to pick up this request. Disabling this incurs an extra kudos cost.",
location="json",
)
self.generate_parser.add_argument(
"extra_slow_workers",
type=bool,
default=False,
required=False,
help="When True, allows very slower workers to pick up this request. Use this when you don't mind waiting a lot.",
location="json",
)
self.generate_parser.add_argument(
"dry_run",
type=bool,
Expand Down Expand Up @@ -204,6 +212,13 @@ def __init__(self):
help="How many jobvs to pop at the same time",
location="json",
)
self.job_pop_parser.add_argument(
"extra_slow_worker",
type=bool,
default=False,
required=False,
location="json",
)

self.job_submit_parser = reqparse.RequestParser()
self.job_submit_parser.add_argument(
Expand Down Expand Up @@ -537,6 +552,13 @@ def __init__(self, api):
min=1,
max=20,
),
"extra_slow_worker": fields.Boolean(
default=True,
description=(
"If True, marks the worker as very slow. You should only use this if your mps/s is lower than 0.1."
"Extra slow workers are excluded from normal requests but users can opt in to use them."
),
),
},
)
self.response_model_worker_details = api.inherit(
Expand Down
3 changes: 1 addition & 2 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ def post(self):
# as they're typically countermeasures to raids
if skipped_reason != "secret":
self.skipped[skipped_reason] = self.skipped.get(skipped_reason, 0) + 1
# logger.warning(datetime.utcnow())

continue
# There is a chance that by the time we finished all the checks, another worker picked up the WP.
Expand All @@ -477,7 +476,7 @@ def post(self):
# We report maintenance exception only if we couldn't find any jobs
if self.worker.maintenance:
raise e.WorkerMaintenance(self.worker.maintenance_msg)
# logger.warning(datetime.utcnow())
# logger.debug(self.skipped)
return {"id": None, "ids": [], "skipped": self.skipped}, 200

def get_sorted_wp(self, priority_user_ids=None):
Expand Down
7 changes: 7 additions & 0 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def initiate_waiting_prompt(self):
validated_backends=self.args.validated_backends,
worker_blacklist=self.args.worker_blacklist,
slow_workers=self.args.slow_workers,
extra_slow_workers=self.args.extra_slow_workers,
source_processing=self.args.source_processing,
ipaddr=self.user_ip,
safe_ip=self.safe_ip,
Expand Down Expand Up @@ -599,6 +600,10 @@ def post(self):
db_skipped["kudos"] = post_ret["skipped"]["kudos"]
if "blacklist" in post_ret.get("skipped", {}):
db_skipped["blacklist"] = post_ret["skipped"]["blacklist"]
if "step_count" in post_ret.get("skipped", {}):
db_skipped["step_count"] = post_ret["skipped"]["step_count"]
if "bridge_version" in post_ret.get("skipped", {}):
db_skipped["bridge_version"] = db_skipped.get("bridge_version", 0) + post_ret["skipped"]["bridge_version"]
post_ret["skipped"] = db_skipped
# logger.debug(post_ret)
return post_ret, retcode
Expand All @@ -621,6 +626,8 @@ def check_in(self):
allow_controlnet=self.args.allow_controlnet,
allow_sdxl_controlnet=self.args.allow_sdxl_controlnet,
allow_lora=self.args.allow_lora,
extra_slow_worker=self.args.extra_slow_worker,
limit_max_steps=self.args.limit_max_steps,
priority_usernames=self.priority_usernames,
)

Expand Down
3 changes: 3 additions & 0 deletions horde/bridge_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

BRIDGE_CAPABILITIES = {
"AI Horde Worker reGen": {
9: {"flux"},
8: {"layer_diffuse"},
7: {"qr_code", "extra_texts", "workflow"},
6: {"stable_cascade_2pass"},
Expand Down Expand Up @@ -185,6 +186,7 @@ def parse_bridge_agent(bridge_agent):
@logger.catch(reraise=True)
def check_bridge_capability(capability, bridge_agent):
bridge_name, bridge_version = parse_bridge_agent(bridge_agent)
# logger.debug([bridge_name, bridge_version])
if bridge_name not in BRIDGE_CAPABILITIES:
return False
total_capabilities = set()
Expand All @@ -194,6 +196,7 @@ def check_bridge_capability(capability, bridge_agent):
if checked_semver.compare(bridge_version) <= 0:
total_capabilities.update(BRIDGE_CAPABILITIES[bridge_name][version])
# logger.debug([total_capabilities, capability, capability in total_capabilities])
# logger.debug([bridge_name, BRIDGE_CAPABILITIES[bridge_name]])
return capability in total_capabilities


Expand Down
13 changes: 11 additions & 2 deletions horde/classes/base/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ProcessingGeneration(db.Model):
nullable=False,
server_default=expression.literal(False),
)
job_ttl = db.Column(db.Integer, default=150, nullable=False, index=True)

wp_id = db.Column(
uuid_column_type(),
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, *args, **kwargs):
self.model = matching_models[0]
else:
self.model = kwargs["model"]
self.set_job_ttl()
db.session.commit()

def set_generation(self, generation, things_per_sec, **kwargs):
Expand Down Expand Up @@ -163,10 +165,10 @@ def is_completed(self):
def is_faulted(self):
return self.faulted

def is_stale(self, ttl):
def is_stale(self):
if self.is_completed() or self.is_faulted():
return False
return (datetime.utcnow() - self.start_time).total_seconds() > ttl
return (datetime.utcnow() - self.start_time).total_seconds() > self.job_ttl

def delete(self):
db.session.delete(self)
Expand Down Expand Up @@ -224,3 +226,10 @@ def send_webhook(self, kudos):
break
except Exception as err:
logger.debug(f"Exception when sending generation webhook: {err}. Will retry {3-riter-1} more times...")

def set_job_ttl(self):
"""Returns how many seconds each job request should stay waiting before considering it stale and cancelling it
This function should be overriden by the invididual hordes depending on how the calculating ttl
"""
self.job_ttl = 150
db.session.commit()
23 changes: 11 additions & 12 deletions horde/classes/base/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from horde.classes.stable.processing_generation import ImageProcessingGeneration
from horde.flask import SQLITE_MODE, db
from horde.logger import logger
from horde.utils import get_db_uuid, get_expiry_date
from horde.utils import get_db_uuid, get_expiry_date, get_extra_slow_expiry_date

procgen_classes = {
"template": ProcessingGeneration,
Expand Down Expand Up @@ -93,6 +93,7 @@ class WaitingPrompt(db.Model):
trusted_workers = db.Column(db.Boolean, default=False, nullable=False, index=True)
validated_backends = db.Column(db.Boolean, default=True, nullable=False, index=True)
slow_workers = db.Column(db.Boolean, default=True, nullable=False, index=True)
extra_slow_workers = db.Column(db.Boolean, default=False, nullable=False, index=True)
worker_blacklist = db.Column(db.Boolean, default=False, nullable=False, index=True)
faulted = db.Column(db.Boolean, default=False, nullable=False, index=True)
active = db.Column(db.Boolean, default=False, nullable=False, index=True)
Expand All @@ -105,6 +106,7 @@ class WaitingPrompt(db.Model):
things = db.Column(db.BigInteger, default=0, nullable=False)
total_usage = db.Column(db.Float, default=0, nullable=False)
extra_priority = db.Column(db.Integer, default=0, nullable=False, index=True)
# TODO: Delete. Obsoleted.
job_ttl = db.Column(db.Integer, default=150, nullable=False)
disable_batching = db.Column(db.Boolean, default=False, nullable=False)
webhook = db.Column(db.String(1024))
Expand Down Expand Up @@ -204,7 +206,6 @@ def extract_params(self):
self.things = 0
self.total_usage = round(self.things * self.n, 2)
self.prepare_job_payload()
self.set_job_ttl()
db.session.commit()

def prepare_job_payload(self):
Expand Down Expand Up @@ -241,7 +242,7 @@ def start_generation(self, worker, amount=1):
self.n -= safe_amount
payload = self.get_job_payload(current_n)
# This does a commit as well
self.refresh()
self.refresh(worker)
procgen_class = procgen_classes[self.wp_type]
gens_list = []
model = None
Expand Down Expand Up @@ -457,8 +458,13 @@ def abort_for_maintenance(self):
except Exception as err:
logger.warning(f"Error when aborting WP. Skipping: {err}")

def refresh(self):
self.expiry = get_expiry_date()
def refresh(self, worker=None):
if worker is not None and worker.extra_slow_worker is True:
self.expiry = get_extra_slow_expiry_date()
else:
new_expiry = get_expiry_date()
if self.expiry < new_expiry:
self.expiry = new_expiry
db.session.commit()

def is_stale(self):
Expand All @@ -469,13 +475,6 @@ def is_stale(self):
def get_priority(self):
return self.extra_priority

def set_job_ttl(self):
"""Returns how many seconds each job request should stay waiting before considering it stale and cancelling it
This function should be overriden by the invididual hordes depending on how the calculating ttl
"""
self.job_ttl = 150
db.session.commit()

def refresh_worker_cache(self):
worker_ids = [worker.worker_id for worker in self.workers]
worker_string_ids = [str(worker.worker_id) for worker in self.workers]
Expand Down
19 changes: 10 additions & 9 deletions horde/classes/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class WorkerTemplate(db.Model):
# Used by all workers to record how much they can pick up to generate
# The value of this column is dfferent per worker type
max_power = db.Column(db.Integer, default=20, nullable=False)
extra_slow_worker = db.Column(db.Boolean, default=False, nullable=False, index=True)

paused = db.Column(db.Boolean, default=False, nullable=False)
maintenance = db.Column(db.Boolean, default=False, nullable=False)
Expand Down Expand Up @@ -196,7 +197,7 @@ def report_suspicion(self, amount=1, reason=Suspicions.WORKER_PROFANITY, formats
f"Last suspicion log: {reason.name}.\n"
f"Total Suspicion {self.get_suspicion()}",
)
db.session.commit()
db.session.flush()

def get_suspicion_reasons(self):
return set([s.suspicion_id for s in self.suspicions])
Expand Down Expand Up @@ -261,10 +262,6 @@ def toggle_paused(self, is_paused_active):

# This should be extended by each worker type
def check_in(self, **kwargs):
# To avoid excessive commits,
# we only record new changes on the worker every 30 seconds
if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30:
return
self.ipaddr = kwargs.get("ipaddr", None)
self.bridge_agent = sanitize_string(kwargs.get("bridge_agent", "unknown:0:unknown"))
self.threads = kwargs.get("threads", 1)
Expand All @@ -275,6 +272,10 @@ def check_in(self, **kwargs):
self.prioritized_users = kwargs.get("prioritized_users", [])
if not kwargs.get("safe_ip", True) and not self.user.trusted:
self.report_suspicion(reason=Suspicions.UNSAFE_IP)
# To avoid excessive commits,
# we only record new uptime on the worker every 30 seconds
if (datetime.utcnow() - self.last_check_in).total_seconds() < 30 and (datetime.utcnow() - self.created).total_seconds() > 30:
return
if not self.is_stale() and not self.paused and not self.maintenance:
self.uptime += (datetime.utcnow() - self.last_check_in).total_seconds()
# Every 10 minutes of uptime gets 100 kudos rewarded
Expand All @@ -293,7 +294,6 @@ def check_in(self, **kwargs):
# So that they have to stay up at least 10 mins to get uptime kudos
self.last_reward_uptime = self.uptime
self.last_check_in = datetime.utcnow()
db.session.commit()

def get_human_readable_uptime(self):
if self.uptime < 60:
Expand Down Expand Up @@ -511,7 +511,8 @@ def check_in(self, **kwargs):
self.set_models(kwargs.get("models"))
self.nsfw = kwargs.get("nsfw", True)
self.set_blacklist(kwargs.get("blacklist", []))
db.session.commit()
self.extra_slow_worker = kwargs.get("extra_slow_worker", False)
# Commit should happen on calling extensions

def set_blacklist(self, blacklist):
# We don't allow more workers to claim they can server more than 50 models atm (to prevent abuse)
Expand All @@ -527,7 +528,7 @@ def set_blacklist(self, blacklist):
for word in blacklist:
blacklisted_word = WorkerBlackList(worker_id=self.id, word=word[0:15])
db.session.add(blacklisted_word)
db.session.commit()
db.session.flush()

def refresh_model_cache(self):
models_list = [m.model for m in self.models]
Expand Down Expand Up @@ -563,7 +564,7 @@ def set_models(self, models):
return
# logger.debug([existing_model_names,models, existing_model_names == models])
db.session.query(WorkerModel).filter_by(worker_id=self.id).delete()
db.session.commit()
db.session.flush()
for model_name in models:
model = WorkerModel(worker_id=self.id, model=model_name)
db.session.add(model)
Expand Down
Loading
Loading