Skip to content

Commit

Permalink
4x rewards for SC
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Feb 22, 2024
1 parent 66e9b5e commit e8075d7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def get_details(self):

def get_gen_kudos(self):
# We have pre-calculated them as they don't change per worker
if model_reference.get_model_baseline(self.model) in ["stable_diffusion_xl", "stable_cascade"]:
if model_reference.get_model_baseline(self.model) in ["stable_diffusion_xl"]:
return self.wp.kudos * 2
if model_reference.get_model_baseline(self.model) in ["stable_cascade"]:
return self.wp.kudos * 4
return self.wp.kudos

def log_aborted_generation(self):
Expand Down
7 changes: 6 additions & 1 deletion horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def require_upfront_kudos(self, counted_totals, total_threads):
# Using more than 10 steps with LCM requires upfront kudos
if self.is_using_lcm() and self.get_accurate_steps() > 10:
return (True, max_res)
# Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse.
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30:
return (True, max_res)
if self.get_accurate_steps() > 50:
return (True, max_res)
if self.width * self.height > max_res * max_res:
Expand Down Expand Up @@ -439,8 +442,10 @@ def extrapolate_dry_run_kudos(self):
model_name = self.models[0].model
else:
model_name = "SDXL 1.0"
if model_reference.get_model_baseline(model_name) in ["stable_diffusion_xl", "stable_cascade"]:
if model_reference.get_model_baseline(model_name) in ["stable_diffusion_xl"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1
if model_reference.get_model_baseline(model_name) in ["stable_cascade"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 4) + 1
# The +1 is the extra kudos burn per request
return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1

Expand Down

0 comments on commit e8075d7

Please sign in to comment.