Skip to content

Commit

Permalink
feat: Allow quering only for custom models
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Apr 8, 2024
1 parent b5c58dd commit c0d4c4f
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
6 changes: 3 additions & 3 deletions horde/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
def after_request(response):
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS, PUT, DELETE, PATCH"
response.headers["Access-Control-Allow-Headers"] = (
"Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, apikey, Client-Agent, X-Fields"
)
response.headers[
"Access-Control-Allow-Headers"
] = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, apikey, Client-Agent, X-Fields"
response.headers["Horde-Node"] = f"{socket.gethostname()}:{args.port}:{HORDE_VERSION}"
return response

Expand Down
17 changes: 16 additions & 1 deletion horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,10 +1565,22 @@ class Models(Resource):
help="Filter the models that have at most this amount of threads serving.",
location="args",
)
get_parser.add_argument(
"model_state",
required=False,
default="all",
type=str,
help=(
"If 'known', only show stats for known models in the model reference. "
"If 'custom' only show stats for custom models. "
"If 'all' shows stats for all models."
),
location="args",
)

@logger.catch(reraise=True)
@cache.cached(timeout=2, query_string=True)
@api.expect(get_parser)
@api.response(400, "Validation Error", models.response_model_error)
@api.marshal_with(
models.response_model_active_model,
code=200,
Expand All @@ -1578,10 +1590,13 @@ class Models(Resource):
def get(self):
"""Returns a list of models active currently in this horde"""
self.args = self.get_parser.parse_args()
if self.args.model_state not in ["known", "custom", "all"]:
raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']")
models_ret = database.retrieve_available_models(
model_type=self.args.type,
min_count=self.args.min_count,
max_count=self.args.max_count,
model_state=self.args.model_state,
)
return (models_ret, 200)

Expand Down
17 changes: 11 additions & 6 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,17 +1274,22 @@ class ImageHordeStatsModels(Resource):
location="headers",
)
get_parser.add_argument(
"model_type",
"model_state",
required=False,
default='known',
default="known",
type=str,
help="If 'known', only show stats for known models in the model reference. If 'custom' only show stats for custom models. If 'all' shows stats for all models.",
help=(
"If 'known', only show stats for known models in the model reference. "
"If 'custom' only show stats for custom models. "
"If 'all' shows stats for all models."
),
location="args",
)

@logger.catch(reraise=True)
# @cache.cached(timeout=50, query_string=True)
@api.expect(get_parser)
@api.response(400, "Validation Error", models.response_model_error)
@api.marshal_with(
models.response_model_stats_models,
code=200,
Expand All @@ -1293,6 +1298,6 @@ class ImageHordeStatsModels(Resource):
def get(self):
"""Details how many images were generated per model for the past day, month and total"""
self.args = self.get_parser.parse_args()
if self.args.model_type not in ['known', 'custom', 'all']:
return e.BadRequest("'model_type' needs to be one of ['known', 'custom', 'all']")
return compile_imagegen_stats_models(self.args.model_type), 200
if self.args.model_state not in ["known", "custom", "all"]:
raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']")
return compile_imagegen_stats_models(self.args.model_state), 200
23 changes: 12 additions & 11 deletions horde/classes/stable/genstats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from loguru import logger
from datetime import datetime, timedelta

from sqlalchemy import Enum, func
Expand Down Expand Up @@ -208,30 +207,32 @@ def compile_imagegen_stats_totals():
return stats_dict


def compile_imagegen_stats_models(model_type = 'known'):
def compile_imagegen_stats_models(model_state="known"):
query = db.session.query(ImageGenerationStatistic.model, func.count()).group_by(ImageGenerationStatistic.model)
def check_model_type(model_name):
if model_type == 'known' and model_reference.is_known_image_model(model_name):
return True
if model_type == 'custom' and not model_reference.is_known_image_model(model_name):
return True
if model_type == 'all':

def check_model_state(model_name):
if model_state == "known" and model_reference.is_known_image_model(model_name):
return True
if model_state == "custom" and not model_reference.is_known_image_model(model_name):
return True
if model_state == "all":
return True
return False

return {
"total": {model: count for model, count in query.all() if check_model_type(model)},
"total": {model: count for model, count in query.all() if check_model_state(model)},
"day": {
model: count
for model, count in query.filter(
ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1),
).all()
if check_model_type(model)
if check_model_state(model)
},
"month": {
model: count
for model, count in query.filter(
ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).all()
if check_model_type(model)
if check_model_state(model)
},
}
19 changes: 18 additions & 1 deletion horde/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def get_available_models(filter_model_name: str = None):
return list(models_dict.values())


def retrieve_available_models(model_type=None, min_count=None, max_count=None):
def retrieve_available_models(model_type=None, min_count=None, max_count=None, model_state="known"):
"""Retrieves model details from Redis cache, or from DB if cache is unavailable"""
if hr.horde_r is None:
return get_available_models()
Expand All @@ -384,6 +384,23 @@ def retrieve_available_models(model_type=None, min_count=None, max_count=None):
models_ret = [md for md in models_ret if md["count"] >= min_count]
if max_count is not None:
models_ret = [md for md in models_ret if md["count"] <= max_count]

def check_model_state(model_name):
if model_type is None:
return True
model_check = model_reference.is_known_image_model
if model_type == "text":
model_check = model_reference.is_known_text_model
if model_state == "known" and model_check(model_name):
return True
if model_state == "custom" and not model_check(model_name):
return True
if model_state == "all":
return True
return False

models_ret = [md for md in models_ret if check_model_state(md["name"])]

return models_ret


Expand Down

0 comments on commit c0d4c4f

Please sign in to comment.