From d0612513bea093d559c4860344003a72bea28c4a Mon Sep 17 00:00:00 2001 From: Adrien Carpentier Date: Wed, 19 Jun 2024 14:47:36 +0200 Subject: [PATCH] feat: add change model command --- app/commands.py | 38 ++++++++++++++++++++++++++++++++++++-- app/pyalbert_utils.py | 10 ++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/app/commands.py b/app/commands.py index a976214..1bd0800 100755 --- a/app/commands.py +++ b/app/commands.py @@ -12,7 +12,13 @@ from matrix_bot.config import bot_lib_config, logger from matrix_bot.eventparser import EventNotConcerned, EventParser from nio import Event, RoomMemberEvent, RoomMessageText -from pyalbert_utils import generate, generate_sources, get_available_modes, new_chat +from pyalbert_utils import ( + generate, + generate_sources, + get_available_models, + get_available_modes, + new_chat, +) @dataclass @@ -228,6 +234,34 @@ async def albert_debug(ep: EventParser, matrix_client: MatrixClient): await matrix_client.send_markdown_message(ep.room.room_id, debug_message) +@register_feature( + group="albert", + onEvent=RoomMessageText, + command="model", + help=f"Pour modifier le modèle, utilisez **{COMMAND_PREFIX}model** MODEL_NAME", + hidden=True, +) +async def albert_model(ep: EventParser, matrix_client: MatrixClient): + config = user_configs[ep.sender] + await matrix_client.room_typing(ep.room.room_id) + commands = ep.event.body.split() + # Get all available models + all_models = get_available_models(config) + if len(commands) <= 1: + message = ( + f"La commande !model nécessite de donner un modèle parmi : {', '.join(all_models)}" + ) + else: + model = commands[1] + if model not in all_models: + message = f"Modèle inconnu. Les modèles disponibles sont : {', '.join(all_models)}" + else: + previous_model = config.albert_model_name + config.albert_model_name = model + message = f"Le modèle a été modifié : {previous_model} -> {model}" + await matrix_client.send_text_message(ep.room.room_id, message) + + @register_feature( group="albert", onEvent=RoomMessageText, @@ -251,7 +285,7 @@ async def albert_mode(ep: EventParser, matrix_client: MatrixClient): else: old_mode = config.albert_mode config.albert_mode = mode - message = f"Le mode a été modifié: {old_mode} -> {mode}" + message = f"Le mode a été modifié : {old_mode} -> {mode}" await matrix_client.send_text_message(ep.room.room_id, message) diff --git a/app/pyalbert_utils.py b/app/pyalbert_utils.py index ff23799..559a009 100755 --- a/app/pyalbert_utils.py +++ b/app/pyalbert_utils.py @@ -19,6 +19,16 @@ def log_and_raise_for_status(response: requests.Response) -> None: response.raise_for_status() +def get_available_models(config: Config) -> list[str] | None: + api_token = config.albert_api_token + url = config.albert_api_url + headers = {"Authorization": f"Bearer {api_token}"} + response = requests.get(f"{url}/models", headers=headers) + log_and_raise_for_status(response) + model_prompts = response.json() + return list(model_prompts.keys()) + + def get_available_modes(config: Config) -> list[str] | None: api_token = config.albert_api_token url = config.albert_api_url