Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: caching improvements for stats endpoints #420

Merged
merged 11 commits into from
Jun 10, 2024
8 changes: 4 additions & 4 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from horde.apis.v2.base import GenerateTemplate, JobPopTemplate, JobSubmitTemplate, api
from horde.classes.base import settings
from horde.classes.kobold.genstats import (
compile_textgen_stats_models,
compile_textgen_stats_totals,
get_compiled_textgen_stats_models,
get_compiled_textgen_stats_totals,
)
from horde.classes.kobold.waiting_prompt import TextWaitingPrompt
from horde.classes.kobold.worker import TextWorker
Expand Down Expand Up @@ -356,7 +356,7 @@ def get(self):
"""Details how many texts have been generated in the past minux,hour,day,month and total
Also shows the amount of pixelsteps for the same timeframe.
"""
return compile_textgen_stats_totals(), 200
return get_compiled_textgen_stats_totals(), 200


class TextHordeStatsModels(Resource):
Expand All @@ -380,7 +380,7 @@ class TextHordeStatsModels(Resource):
)
def get(self):
"""Details how many texts were generated per model for the past day, month and total"""
return compile_textgen_stats_models(), 200
return get_compiled_textgen_stats_models(), 200


class KoboldKudosTransfer(Resource):
Expand Down
9 changes: 5 additions & 4 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from horde.classes.base import settings
from horde.classes.base.user import User
from horde.classes.stable.genstats import (
compile_imagegen_stats_models,
compile_imagegen_stats_totals,
get_compiled_imagegen_stats_models,
get_compiled_imagegen_stats_totals,
)
from horde.classes.stable.interrogation import Interrogation
from horde.classes.stable.interrogation_worker import InterrogationWorker
Expand Down Expand Up @@ -573,6 +573,7 @@ def post(self):
if "blacklist" in post_ret.get("skipped", {}):
db_skipped["blacklist"] = post_ret["skipped"]["blacklist"]
post_ret["skipped"] = db_skipped

return post_ret, retcode

def check_in(self):
Expand Down Expand Up @@ -1272,7 +1273,7 @@ def get(self):
"""Details how many images have been generated in the past minux,hour,day,month and total
Also shows the amount of pixelsteps for the same timeframe.
"""
return compile_imagegen_stats_totals(), 200
return get_compiled_imagegen_stats_totals(), 200


class ImageHordeStatsModels(Resource):
Expand Down Expand Up @@ -1312,4 +1313,4 @@ def get(self):
self.args = self.get_parser.parse_args()
if self.args.model_state not in ["known", "custom", "all"]:
raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']")
return compile_imagegen_stats_models(self.args.model_state), 200
return get_compiled_imagegen_stats_models(self.args.model_state), 200
29 changes: 28 additions & 1 deletion horde/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import horde.classes.base.stats # noqa 401
from horde.argparser import args
from horde.classes.base.detection import Filter # noqa 401
Expand All @@ -10,13 +12,14 @@
from horde.classes.kobold.worker import TextWorker # noqa 401
from horde.classes.stable.interrogation import Interrogation # noqa 401
from horde.classes.stable.interrogation_worker import InterrogationWorker # noqa 401
from horde.classes.stable.known_image_models import KnownImageModel # noqa 401

# Importing for DB creation

# noqa 401
from horde.classes.stable.waiting_prompt import ImageWaitingPrompt # noqa 401
from horde.classes.stable.worker import ImageWorker # noqa 401
from horde.flask import HORDE, db
from horde.logger import logger
from horde.utils import hash_api_key

with HORDE.app_context():
Expand All @@ -29,6 +32,29 @@
# sys.exit()
db.create_all()

# Raw SQL scripts live in `sql_statements/`, three directory levels above this file.
sql_statement_dir = Path(__file__).parent.parent.parent / "sql_statements"

# The order of these directories is important. `cron` creates a stored procedure that is
# used by queries in all other `cron_jobs/` directories.
all_dirs_to_run = [
    "cron/",  # Must be first
    "stored_procedures/",
    "stored_procedures/cron_jobs/",
]

all_dirs_to_run = [sql_statement_dir / dir_name for dir_name in all_dirs_to_run]

# Any failure is logged and then re-raised (reraise=True), so a broken script is not silently skipped.
with logger.catch(reraise=True):
    for sql_dir in all_dirs_to_run:
        logger.info(f"Running files in {sql_dir}")
        # `iterdir()` yields entries in arbitrary, filesystem-dependent order; sorting makes
        # the execution order of the scripts inside a directory deterministic across hosts.
        for sql_file in sorted(sql_dir.iterdir()):
            if sql_file.suffix != ".sql":
                continue
            logger.info(f"Running {sql_file}")
            # Read with an explicit encoding so the result does not depend on the host locale.
            # NOTE(review): a plain string is passed to `execute()`; SQLAlchemy 2.0 only accepts
            # `text()` constructs there - confirm against the pinned SQLAlchemy version.
            db.session.execute(sql_file.read_text(encoding="utf-8"))

    # All scripts are committed together, only after every one of them executed.
    db.session.commit()

if args.convert_flag == "roles":
# from horde.conversions import convert_user_roles

Expand Down Expand Up @@ -65,6 +91,7 @@
"TextWaitingPrompt",
"Interrogation",
"InterrogationWorker",
"KnownImageModel",
"User",
"Team",
"ImageWorker",
Expand Down
2 changes: 1 addition & 1 deletion horde/classes/base/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class User(db.Model):
oauth_id = db.Column(db.String(50), unique=True, nullable=False, index=True)
api_key = db.Column(db.String(100), unique=True, nullable=False, index=True)
client_id = db.Column(db.String(50), unique=True, default=generate_client_id, nullable=False)
created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
created = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
last_active = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
contact = db.Column(db.String(50), default=None)
admin_comment = db.Column(db.Text, default=None)
Expand Down
162 changes: 79 additions & 83 deletions horde/classes/kobold/genstats.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
from datetime import datetime, timedelta
from datetime import datetime

from sqlalchemy import Enum, func
from sqlalchemy import Enum

from horde.enums import ImageGenState
from horde.flask import db


class TextGenerationStatistic(db.Model):
    """One recorded text generation, stored in the `text_gen_stats` table."""

    __tablename__ = "text_gen_stats"
    id = db.Column(db.Integer, primary_key=True)
    # Naive (timezone=False) timestamp; defaults to `datetime.utcnow()` when no value is supplied.
    finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow)
    # Created comes from the procgen
    created = db.Column(db.DateTime(timezone=False), nullable=True)
    # Model name; indexed.
    model = db.Column(db.String(255), index=True, nullable=False)
    max_length = db.Column(db.Integer, nullable=False)
    max_context_length = db.Column(db.Integer, nullable=False)
    softprompt = db.Column(db.Integer, nullable=True)
    prompt_length = db.Column(db.Integer, nullable=False)
    # Both agent columns are indexed and fall back to the "unknown:0:unknown" placeholder.
    client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
    bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
    # Uses the ImageGenState enum for text generations as well; defaults to OK.
    state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True)


def record_text_statistic(procgen):
state = ImageGenState.OK
# Currently there's no way to record cancelled images, but maybe there will be in the future
Expand All @@ -44,71 +28,83 @@ def record_text_statistic(procgen):
db.session.commit()


def compile_textgen_stats_totals():
count_query = db.session.query(TextGenerationStatistic)
count_minute = count_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1),
).count()
count_hour = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1)).count()
count_day = count_query.filter(TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1)).count()
count_month = count_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).count()
count_total = count_query.count()
tokens_query = db.session.query(func.sum(TextGenerationStatistic.max_length))
tokens_minute = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(minutes=1),
).scalar()
tokens_hour = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(hours=1),
).scalar()
tokens_day = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1),
).scalar()
tokens_month = tokens_query.filter(
TextGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30),
).scalar()
tokens_total = tokens_query.scalar()
stats_dict = {
"minute": {
"requests": count_minute,
"tokens": tokens_minute,
},
"hour": {
"requests": count_hour,
"tokens": tokens_hour,
},
"day": {
"requests": count_day,
"tokens": tokens_day,
},
"month": {
"requests": count_month,
"tokens": tokens_month,
},
"total": {
"requests": count_total,
"tokens": tokens_total,
},
}
class TextGenerationStatistic(db.Model):
    """One recorded text generation, stored in the `text_gen_stats` table."""

    __tablename__ = "text_gen_stats"
    id = db.Column(db.Integer, primary_key=True)
    # Naive (timezone=False) timestamp; defaults to `datetime.utcnow()` when no value is supplied.
    # Indexed.
    finished = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
    # Created comes from the procgen
    created = db.Column(db.DateTime(timezone=False), nullable=True)
    # Model name; indexed.
    model = db.Column(db.String(255), nullable=False, index=True)
    max_length = db.Column(db.Integer, nullable=False)
    max_context_length = db.Column(db.Integer, nullable=False)
    softprompt = db.Column(db.Integer, nullable=True)
    prompt_length = db.Column(db.Integer, nullable=False)
    # Both agent columns are indexed and fall back to the "unknown:0:unknown" placeholder.
    client_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
    bridge_agent = db.Column(db.Text, default="unknown:0:unknown", nullable=False, index=True)
    # Uses the ImageGenState enum for text generations as well; defaults to OK.
    state = db.Column(Enum(ImageGenState), default=ImageGenState.OK, nullable=False, index=True)


class CompiledTextGensStatsTotals(db.Model):
    """A pre-computed snapshot of text generation totals (`compiled_text_gen_stats_totals` table).

    Each row carries request and token counts for the minute, hour, day, month and total periods.
    NOTE(review): the code that writes these rows is not visible here; presumably they are
    produced by the SQL stored procedures / cron jobs - confirm.
    """

    __tablename__ = "compiled_text_gen_stats_totals"
    id = db.Column(db.Integer, primary_key=True)
    # Snapshot timestamp (naive, defaults to `datetime.utcnow()`); indexed.
    created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
    minute_requests = db.Column(db.Integer, nullable=False)
    minute_tokens = db.Column(db.Integer, nullable=False)
    hour_requests = db.Column(db.Integer, nullable=False)
    hour_tokens = db.Column(db.Integer, nullable=False)
    day_requests = db.Column(db.Integer, nullable=False)
    day_tokens = db.Column(db.Integer, nullable=False)
    month_requests = db.Column(db.Integer, nullable=False)
    month_tokens = db.Column(db.Integer, nullable=False)
    total_requests = db.Column(db.Integer, nullable=False)
    # Only the all-time token count is a BigInteger; every other counter is a plain Integer.
    total_tokens = db.Column(db.BigInteger, nullable=False)


def get_compiled_textgen_stats_totals() -> dict[str, dict[str, int]]:
    """Return the most recently compiled text generation totals, grouped by period.

    The newest `CompiledTextGensStatsTotals` row (by `created`) is used. When the table
    holds no row, every period reports zero requests and zero tokens.

    Returns:
        dict[str, dict[str, int]]: Maps each of "minute", "hour", "day", "month" and "total"
        to a dictionary with the keys "requests" and "tokens".
    """
    period_names = ("minute", "hour", "day", "month", "total")

    newest_snapshot = (
        db.session.query(CompiledTextGensStatsTotals)
        .order_by(CompiledTextGensStatsTotals.created.desc())
        .first()
    )

    # Nothing compiled yet: report zeros for every period.
    if not newest_snapshot:
        return {name: {"requests": 0, "tokens": 0} for name in period_names}

    # Columns are named "<period>_requests" / "<period>_tokens".
    return {
        name: {
            "requests": getattr(newest_snapshot, f"{name}_requests"),
            "tokens": getattr(newest_snapshot, f"{name}_tokens"),
        }
        for name in period_names
    }


def compile_textgen_stats_models():
    """Count recorded text generations per model for all time, the last day and the last 30 days.

    Returns:
        dict: Has the keys "total", "day" and "month"; each value maps a model name to the
        number of `TextGenerationStatistic` rows for that model in the window.
    """
    per_model_counts = db.session.query(TextGenerationStatistic.model, func.count()).group_by(
        TextGenerationStatistic.model,
    )

    def _counts_within(window):
        # Restrict the grouped count to generations finished inside the trailing window.
        cutoff = datetime.utcnow() - window
        return dict(per_model_counts.filter(TextGenerationStatistic.finished >= cutoff).all())

    all_time = dict(per_model_counts.all())
    last_day = _counts_within(timedelta(days=1))
    last_month = _counts_within(timedelta(days=30))
    return {"total": all_time, "day": last_day, "month": last_month}
class CompiledTextGenStatsModels(db.Model):
    """Pre-computed per-model text generation request counts (`compiled_text_gen_stats_models` table).

    Each row holds the day, month and total request counts of a single model.
    NOTE(review): the code that writes these rows is not visible here; presumably they are
    produced by the SQL stored procedures / cron jobs - confirm.
    """

    __tablename__ = "compiled_text_gen_stats_models"
    id = db.Column(db.Integer, primary_key=True)
    # Timestamp of the row (naive, defaults to `datetime.utcnow()`); indexed.
    created = db.Column(db.DateTime(timezone=False), default=datetime.utcnow, index=True)
    # Name of the model the counters belong to; indexed.
    model = db.Column(db.String(255), nullable=False, index=True)
    day_requests = db.Column(db.Integer, nullable=False)
    month_requests = db.Column(db.Integer, nullable=False)
    total_requests = db.Column(db.Integer, nullable=False)


def get_compiled_textgen_stats_models() -> dict[str, dict[str, int]]:
    """Get the compiled text generation statistics for the day, month, and total periods for each model.

    Rows are read newest-first and only the newest row of each model is used, so an older
    compiled row can never overwrite a fresher one when the table holds several rows for
    the same model.

    Returns:
        dict[str, dict[str, int]]: A dictionary with the period ("day", "month", "total") as the key
        and, for each period, a mapping of model name to its request count. The mappings are
        empty when no statistics have been compiled.
    """
    # `.all()` returns a list, ordered here from the most recent row to the oldest.
    rows: list[CompiledTextGenStatsModels] = (
        db.session.query(CompiledTextGenStatsModels).order_by(CompiledTextGenStatsModels.created.desc()).all()
    )

    periods = ("day", "month", "total")
    stats: dict[str, dict[str, int]] = {period: {} for period in periods}

    for row in rows:
        # Every period is filled together, so checking one of them is enough to know that a
        # newer row for this model was already recorded.
        if row.model in stats["total"]:
            continue
        for period in periods:
            stats[period][row.model] = getattr(row, f"{period}_requests")

    return stats
Loading
Loading