Skip to content

Commit

Permalink
feat: single collections
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Nov 1, 2024
1 parent d645537 commit 2e5ab12
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 55 deletions.
23 changes: 23 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,29 @@

class Parsers:
def __init__(self):
# A Basic parser which only expects a Client-Agent
self.basic_parser = reqparse.RequestParser()
self.basic_parser.add_argument(
"Client-Agent",
default="unknown:0:unknown",
type=str,
required=False,
help="The client name and version",
location="headers",
)

# A Basic parser which only expects a Client-Agent and an API Key
self.apikey_parser = reqparse.RequestParser()
self.apikey_parser.add_argument("apikey", type=str, required=True, help="A mod API key.", location="headers")
self.apikey_parser.add_argument(
"Client-Agent",
default="unknown:0:unknown",
type=str,
required=False,
help="The client name and version",
location="headers",
)

self.generate_parser = reqparse.RequestParser()
self.generate_parser.add_argument(
"apikey",
Expand Down
2 changes: 2 additions & 0 deletions horde/apis/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
api.add_resource(kobold.SingleTextStyle, "/styles/text/<string:style_id>")
api.add_resource(kobold.SingleImageStyleByName, "/styles/text_by_name/<string:style_name>")
api.add_resource(base.Collection, "/collections")
api.add_resource(base.SingleCollection, "/collections/<string:collection_id>")
api.add_resource(base.SingleCollectionByName, "/collection_by_name/<string:collection_name>")
api.add_resource(base.Users, "/users")
api.add_resource(base.UserSingle, "/users/<string:user_id>")
api.add_resource(base.FindUser, "/find_user")
Expand Down
224 changes: 189 additions & 35 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3173,16 +3173,6 @@ def validate(self):
class SingleStyleTemplateGet(Resource):
gentype = "template"

get_parser = reqparse.RequestParser()
get_parser.add_argument(
"Client-Agent",
default="unknown:0:unknown",
type=str,
required=False,
help="The client name and version.",
location="headers",
)

def get_existing_style(self):
if self.existing_style.style_type != self.gentype:
raise e.BadRequest(
Expand All @@ -3192,9 +3182,9 @@ def get_existing_style(self):
return self.existing_style.get_details()

def get_through_id(self, style_id):
self.existing_style = database.get_style_by_uuid(style_id)
self.existing_style = database.get_style_by_uuid(style_id, is_collection=False)
if not self.existing_style:
raise e.ThingNotFound("Image Style", style_id)
raise e.ThingNotFound(f"{self.gentype} Style", style_id)
return self.get_existing_style()


Expand All @@ -3217,9 +3207,7 @@ def patch(self, style_id):
self.user = database.find_user_by_api_key(self.args["apikey"])
if not self.user:
raise e.InvalidAPIKey("Style PATCH")
if self.user.is_anon():
raise e.Forbidden("Anonymous users cannot update styles", rc="StylesAnonForbidden")
self.existing_style = database.get_style_by_uuid(style_id)
self.existing_style = database.get_style_by_uuid(style_id, is_collection=False)
if not self.existing_style:
raise e.ThingNotFound("Style", style_id)
if self.existing_style.user_id != self.user.id:
Expand Down Expand Up @@ -3274,31 +3262,20 @@ def patch(self, style_id):
def validate(self):
pass

delete_parser = reqparse.RequestParser()
delete_parser.add_argument("apikey", type=str, required=True, help="A mod API key.", location="headers")
delete_parser.add_argument(
"Client-Agent",
default="unknown:0:unknown",
type=str,
required=False,
help="The client name and version",
location="headers",
)

def delete(self, style_id):
self.args = self.delete_parser.parse_args()
self.args = parsers.apikey_parser.parse_args()
self.user = database.find_user_by_api_key(self.args["apikey"])
if not self.user:
raise e.InvalidAPIKey("Style DELETE")
if self.user.is_anon():
raise e.Forbidden("Anonymous users cannot delete styles", rc="StylesAnonForbidden")
self.existing_style = database.get_style_by_uuid(style_id)
self.existing_style = database.get_style_by_uuid(style_id, is_collection=False)
if not self.existing_style:
raise e.ThingNotFound("Style", style_id)
if self.existing_style.user_id != self.user.id and not self.existing_style.user.moderator:
raise e.Forbidden(f"This Style is not owned by user {self.user.get_unique_alias()}")
if self.existing_style.user_id != self.user.id and self.existing_style.user.moderator:
logger.info(f"Moderator {self.existing_style.user.moderator} deleted style {self.existing_style.id}")
if self.existing_style.user_id != self.user.id and self.user.moderator:
logger.info(f"Moderator {self.user.moderator} deleted style {self.existing_style.id}")
self.existing_style.delete()
return ({"message": "OK"}, 200)

Expand Down Expand Up @@ -3357,13 +3334,11 @@ def get(self):
raise e.BadRequest("'model_state' needs to be one of ['popular', 'age']")
if self.args.type not in ["all", "image", "text"]:
raise e.BadRequest("'type' needs to be one of ['all', 'image', 'text']")
logger.debug([self.args.sort, self.args.page - 1, self.args.type])
collections = database.retrieve_available_collections(
sort=self.args.sort,
page=self.args.page - 1,
collection_type=self.args.type if self.args.type in ["image", "text"] else None,
)
logger.debug(collections)
collections_ret = [co.get_details() for co in collections]
return collections_ret, 200

Expand Down Expand Up @@ -3405,7 +3380,7 @@ def get(self):
post_parser.add_argument(
"styles",
type=list,
required=False,
required=True,
location="json",
)

Expand Down Expand Up @@ -3443,9 +3418,9 @@ def post(self):
if self.user.is_anon():
raise e.Forbidden("Anonymous users cannot create collections", rc="StylesAnonForbidden")
for st in self.args.styles:
existing_style = database.get_style_by_uuid(st)
existing_style = database.get_style_by_uuid(st, is_collection=False)
if not existing_style:
existing_style = database.get_style_by_name(st)
existing_style = database.get_style_by_name(st, is_collection=False)
if not existing_style:
raise e.BadRequest(f"A style with name '{st}' cannot be found")
if styles_type is None:
Expand All @@ -3469,6 +3444,185 @@ def post(self):
}, 200


class SingleCollectionGet(Resource):

def get_through_id(self, style_id):
self.existing_collection = database.get_style_by_uuid(style_id, is_collection=True)
if not self.existing_collection:
raise e.ThingNotFound("Collection", style_id)
return self.existing_collection.get_details()


class SingleCollection(SingleCollectionGet):
args = None

@cache.cached(timeout=30, query_string=True)
@api.expect(parsers.basic_parser)
@api.marshal_with(
models.response_model_collection,
code=200,
description="Lists collection information",
as_list=False,
)
def get(self, collection_id):
return super().get_through_id(collection_id)

patch_parser = reqparse.RequestParser()
patch_parser.add_argument(
"apikey",
type=str,
required=True,
help="The API Key corresponding to a registered user.",
location="headers",
)
patch_parser.add_argument(
"Client-Agent",
default="unknown:0:unknown",
type=str,
required=False,
help="The client name and version",
location="headers",
)
patch_parser.add_argument(
"name",
type=str,
required=False,
location="json",
)
patch_parser.add_argument(
"info",
type=str,
required=False,
location="json",
)
patch_parser.add_argument(
"public",
type=bool,
required=False,
location="json",
)
patch_parser.add_argument(
"styles",
type=list,
required=False,
location="json",
)

decorators = [
limiter.limit(
limit_value=lim.get_request_90min_limit_per_ip,
key_func=lim.get_request_path,
),
limiter.limit(limit_value=lim.get_request_2sec_limit_per_ip, key_func=lim.get_request_path),
]

@api.expect(patch_parser, models.input_model_collection, validate=True)
@api.marshal_with(
models.response_model_styles_post,
code=202,
description="Collection Modified",
skip_none=True,
)
@api.response(400, "Validation Error", models.response_model_validation_errors)
@api.response(401, "Invalid API Key", models.response_model_error)
def patch(self, collection_id):
self.warnings = set()
# For styles, we just store the models in the params
self.styles = []
styles_type = None
self.args = self.patch_parser.parse_args()
if self.args.styles:
if len(self.args.styles) < 1:
raise e.BadRequest("A collection has to include at least 1 style")
for st in self.args.styles:
existing_style = database.get_style_by_uuid(st, is_collection=False)
if not existing_style:
existing_style = database.get_style_by_name(st, is_collection=False)
if not existing_style:
raise e.BadRequest(f"A style with name '{st}' cannot be found")
if styles_type is None:
styles_type = existing_style.style_type
elif styles_type != existing_style.style_type:
raise e.BadRequest("Cannot mix image and text styles in the same collection", "StyleMismatch")
self.styles.append(existing_style)
self.user = database.find_user_by_api_key(self.args["apikey"])
if not self.user:
raise e.InvalidAPIKey("Collection PATCH")
self.existing_collection = database.get_style_by_uuid(collection_id, is_collection=True)
if not self.existing_collection:
raise e.ThingNotFound("Collection", collection_id)
if self.existing_collection.user_id != self.user.id:
raise e.Forbidden(f"This Collection is not owned by user {self.user.get_unique_alias()}")
if self.existing_collection.style_type != styles_type:
raise e.BadRequest("Cannot mix image and text styles in the same collection", "StyleMismatch")
collection_modified = False
if self.args.name:
self.existing_collection.name = ensure_clean(self.args.name, "collection name")
collection_modified = True
if self.args.info is not None:
self.existing_collection.info = ensure_clean(self.args.info, "style info")
collection_modified = True
if self.args.public is not None:
self.existing_collection.public = self.args.public
collection_modified = True
if len(self.styles) > 0:
self.existing_collection.styles.clear()
for st in self.styles:
self.existing_collection.styles.append(st)
collection_modified = True
if not collection_modified:
return {
"id": self.existing_collection.id,
"message": "OK",
}, 200
db.session.commit()
return {
"id": self.existing_collection.id,
"message": "OK",
"warnings": self.warnings,
}, 200

@api.expect(parsers.apikey_parser)
@api.marshal_with(
models.response_model_simple_response,
code=200,
description="Operation Completed",
skip_none=True,
)
@api.response(400, "Validation Error", models.response_model_validation_errors)
@api.response(401, "Invalid API Key", models.response_model_error)
def delete(self, collection_id):
self.args = parsers.apikey_parser.parse_args()
self.user = database.find_user_by_api_key(self.args["apikey"])
if not self.user:
raise e.InvalidAPIKey("Collection PATCH")
self.existing_collection = database.get_style_by_uuid(collection_id, is_collection=True)
if not self.existing_collection:
raise e.ThingNotFound("Collection", collection_id)
if self.existing_collection.user_id != self.user.id and not self.user.moderator:
raise e.Forbidden(f"This Collection is not owned by user {self.user.get_unique_alias()}")
if self.existing_collection.user_id != self.user.id and self.user.moderator:
logger.info(f"Moderator {self.user.moderator} deleted collection {self.existing_collection.id}")
self.existing_collection.delete()
return ({"message": "OK"}, 200)


class SingleCollectionByName(SingleCollectionGet):
@cache.cached(timeout=30)
@api.expect(parsers.basic_parser)
@api.marshal_with(
models.response_model_collection,
code=200,
description="Lists collection information by name",
as_list=False,
)
def get(self, collection_name):
self.existing_collection = database.get_style_by_name(collection_name)
if not self.existing_collection:
raise e.ThingNotFound("Collection", collection_name)
return self.existing_collection.get_details()


# style: sfw bool
# style: tags
# style: transfer kudos on use
Expand Down
8 changes: 4 additions & 4 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ class SingleTextStyle(SingleStyleTemplate):
gentype = "text"

@cache.cached(timeout=30)
@api.expect(SingleStyleTemplate.get_parser)
@api.expect(parsers.basic_parser)
@api.marshal_with(
models.response_model_style,
code=200,
Expand Down Expand Up @@ -627,7 +627,7 @@ def validate(self):
param_validator.check_for_special()
param_validator.validate_text_prompt(prompt)

@api.expect(SingleStyleTemplate.delete_parser)
@api.expect(parsers.apikey_parser)
@api.marshal_with(
models.response_model_simple_response,
code=200,
Expand All @@ -644,15 +644,15 @@ class SingleImageStyleByName(SingleStyleTemplateGet):
gentype = "text"

@cache.cached(timeout=30)
@api.expect(SingleStyleTemplate.get_parser)
@api.expect(parsers.basic_parser)
@api.marshal_with(
models.response_model_style,
code=200,
description="Lists image style information by name",
as_list=False,
)
def get(self, style_name):
self.existing_style = database.get_style_by_name(style_name)
self.existing_style = database.get_style_by_name(style_name, is_collection=False)
if not self.existing_style:
raise e.ThingNotFound("Style", style_name)
return super().get_existing_style()
Loading

0 comments on commit 2e5ab12

Please sign in to comment.