diff --git a/poetry.lock b/poetry.lock index aa36ac5..ccd5118 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiofiles" @@ -174,6 +174,30 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "anthropic" +version = "0.18.1" +description = "The official Python library for the anthropic API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "anthropic-0.18.1-py3-none-any.whl", hash = "sha256:b85aee64f619ce1b1964ba733a09adc4053e7bc4e6d4186001229ec191099dcf"}, + {file = "anthropic-0.18.1.tar.gz", hash = "sha256:f5d1caafd43f6cc933a79753a93531605095f040a384f6a900c3de9c3fb6694e"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.7,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth (>=2,<3)"] + [[package]] name = "anyio" version = "4.4.0" @@ -1249,6 +1273,22 @@ files = [ {file = "giturlparse-0.12.0.tar.gz", hash = "sha256:c0fff7c21acc435491b1779566e038757a205c1ffdcb47e4f81ea52ad8c3859a"}, ] +[[package]] +name = "google-ai-generativelanguage" +version = "0.4.0" +description = "Google Ai Generativelanguage API client library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"}, + {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" + [[package]] name = "google-api-core" version = "2.19.1" @@ -1264,12 +1304,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" @@ -1432,6 +1472,27 @@ files = [ [package.extras] testing = ["pytest"] +[[package]] +name = "google-generativeai" +version = "0.3.2" +description = "Google Generative AI High level API client library 
and tools." +optional = false +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.4.0" +google-api-core = "*" +google-auth = "*" +protobuf = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + [[package]] name = "google-resumable-media" version = "2.7.1" @@ -1861,6 +1922,17 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "janus" version = "1.0.0" @@ -3431,6 +3503,8 @@ files = [ {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, + {file = "orjson-3.10.6-cp313-none-win32.whl", hash = "sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"}, + {file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"}, {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, @@ -3484,6 +3558,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -3504,6 +3579,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3514,8 +3590,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -3685,6 +3761,21 @@ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx- test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] type = ["mypy (>=1.8)"] +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "posthog" version = "3.5.0" @@ -4093,6 +4184,28 @@ files = [ {file = "pysbd-0.3.4-py3-none-any.whl", hash = "sha256:cd838939b7b0b185fcf86b0baf6636667dfb6e474743beeff878e9f42e022953"}, ] +[[package]] +name = "pytest" +version = "8.3.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, + {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -4209,6 +4322,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = 
"PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -6381,5 +6495,5 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" -python-versions = ">=3.10.0,<=3.12.4" -content-hash = "8f17406a3a054db21b34f2b46f5622c693cf0b3b3940ae320449f2c009b460c1" +python-versions = ">=3.10.0,<=3.12.8" +content-hash = "82c79021af0eef05de14366ad295df3202c3c2375462cc0d4219e4b28ed27356" diff --git a/pyproject.toml b/pyproject.toml index c6c166d..0f2caf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ repository = "https://github.com/wandb/wandbot" include = ["src/**/*", "LICENSE", "README.md"] [tool.poetry.dependencies] -python = ">=3.10.0,<=3.12.4" +python = ">=3.10.0,<=3.12.8" numpy = "^1.26.1" pandas = "^2.1.2" pydantic-settings = "^2.0.3" @@ -27,6 +27,8 @@ tree-sitter-languages = "^1.7.1" markdownify = "^0.11.6" uvicorn = "^0.24.0" openai = "^1.3.2" +google-generativeai = ">=0.8.3" +anthropic = "^0.18.1" weave = "^0.50.12" colorlog = "^6.8.0" litellm = "^1.15.1" @@ -53,6 +55,16 @@ ragas = "^0.1.7" dataclasses-json = "^0.6.4" llama-index = "^0.10.30" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.4" + +[tool.pytest.ini_options] +filterwarnings = [ + 'ignore:.*Type google._upb._message.*uses PyType_Spec.*:DeprecationWarning', + 'ignore:.*custom tp_new.*:DeprecationWarning' +] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/src/wandbot/chat/config.py b/src/wandbot/chat/config.py index 549415a..13c1184 100644 --- a/src/wandbot/chat/config.py +++ b/src/wandbot/chat/config.py @@ -36,7 +36,7 @@ class ChatConfig(BaseSettings): english_reranker_model: str = "rerank-english-v2.0" multilingual_reranker_model: str = "rerank-multilingual-v2.0" # Response synthesis settings - response_synthesizer_model: str = "gpt-4-0125-preview" + response_synthesizer_model: str = "openai/gpt-4-0125-preview" # Format: provider/model_name response_synthesizer_temperature: float = 0.1 - response_synthesizer_fallback_model: str = "gpt-4-0125-preview" + response_synthesizer_fallback_model: str = "openai/gpt-4-0125-preview" # Format: provider/model_name response_synthesizer_fallback_temperature: float = 0.1 diff --git a/src/wandbot/chat/models/__init__.py b/src/wandbot/chat/models/__init__.py new file mode 100644 index 0000000..380eb91 --- /dev/null +++ b/src/wandbot/chat/models/__init__.py @@ -0,0 +1,6 @@ +from .base import ChatModel +from .openai_model import OpenAIChatModel +from .gemini_model import GeminiChatModel +from .anthropic_model import AnthropicChatModel + +__all__ = ["ChatModel", "OpenAIChatModel", "GeminiChatModel", "AnthropicChatModel"] \ No newline at end of file diff --git a/src/wandbot/chat/models/anthropic_model.py b/src/wandbot/chat/models/anthropic_model.py new file mode 100644 index 
0000000..bf0afc3
--- /dev/null
+++ b/src/wandbot/chat/models/anthropic_model.py
@@ -0,0 +1,44 @@
+from typing import List, Dict, Any
+
+from anthropic import Anthropic
+
+from .base import ChatModel
+
+class AnthropicChatModel(ChatModel):
+    def __init__(self, model_name: str = "claude-3-opus-20240229", temperature: float = 0.1):
+        super().__init__(model_name, temperature)
+        self.client = Anthropic()
+
+    def _generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        # Convert messages to Anthropic format: the Messages API takes the system
+        # prompt as a top-level `system` parameter rather than as a chat turn.
+        system_prompt = "\n\n".join(
+            msg["content"] for msg in messages if msg["role"] == "system"
+        )
+        anthropic_messages = [
+            {"role": msg["role"], "content": msg["content"]}
+            for msg in messages
+            if msg["role"] in ("user", "assistant")
+        ]
+
+        request_kwargs = dict(
+            model=self.model_name,
+            messages=anthropic_messages,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+        )
+        if system_prompt:
+            request_kwargs["system"] = system_prompt
+
+        response = self.client.messages.create(**request_kwargs)
+
+        return {
+            "content": response.content[0].text,
+            "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+            "prompt_tokens": response.usage.input_tokens,
+            "completion_tokens": response.usage.output_tokens,
+            "error": None,
+        }
+
+    @property
+    def system_role_key(self) -> str:
+        return "system"  # Extracted into the `system` parameter in _generate_response
\ No newline at end of file
diff --git a/src/wandbot/chat/models/base.py b/src/wandbot/chat/models/base.py
new file mode 100644
index 0000000..5c1d3a7
--- /dev/null
+++ b/src/wandbot/chat/models/base.py
@@ -0,0 +1,97 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Optional, Type
+from dataclasses import dataclass
+
+@dataclass
+class ModelError:
+    """Structured error information."""
+    type: str  # e.g., "auth_error", "rate_limit", "context_length", etc.
+    message: str  # Human-readable error message
+    code: Optional[str] = None  # Provider-specific error code if available
+    retryable: bool = False  # Whether the error is potentially retryable
+
+class ChatModel(ABC):
+    def __init__(
+        self,
+        model_name: str,
+        temperature: float = 0.1,
+        fallback_model: Optional['ChatModel'] = None,
+    ):
+        if not 0 <= temperature <= 1:
+            raise ValueError("Temperature must be between 0 and 1")
+
+        self.model_name = model_name
+        self.temperature = temperature
+        self.fallback_model = fallback_model
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        """Generate a response from the model with fallback support.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content' keys
+            max_tokens: Maximum number of tokens to generate
+
+        Returns:
+            Dictionary containing:
+            - content: The generated response text (empty string if error)
+            - total_tokens: Total tokens used (0 if error)
+            - prompt_tokens: Tokens used in the prompt (0 if error)
+            - completion_tokens: Tokens used in the completion (0 if error)
+            - error: None if successful, ModelError instance if failed
+            - model_used: Name of the model that generated the response
+        """
+        try:
+            response = self._generate_response(messages, max_tokens)
+            response["model_used"] = self.model_name
+            return response
+        except Exception as e:
+            error = self._map_error(e) if hasattr(self, '_map_error') else ModelError(
+                type="unknown_error",
+                message=str(e),
+                retryable=True
+            )
+
+            # If error is retryable and we have a fallback model, try it
+            if error.retryable and self.fallback_model:
+                try:
+                    fallback_response = self.fallback_model.generate_response(messages, max_tokens)
+                    if not fallback_response.get("error"):
+                        return fallback_response
+                except Exception:
+                    # If fallback fails, return original error
+                    pass
+
+            return self._create_error_response(error)
+
+    @abstractmethod
+    def _generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        """Internal method to generate a response from the model.
+
+        This method should be implemented by each model provider.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def system_role_key(self) -> str:
+        """Return the key used for system role in messages."""
+        pass
+
+    def _create_error_response(self, error: ModelError) -> Dict[str, Any]:
+        """Create a standardized error response."""
+        return {
+            "content": "",
+            "total_tokens": 0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "error": error,
+            "model_used": self.model_name
+        }
\ No newline at end of file
diff --git a/src/wandbot/chat/models/gemini_model.py b/src/wandbot/chat/models/gemini_model.py
new file mode 100644
index 0000000..3d84889
--- /dev/null
+++ b/src/wandbot/chat/models/gemini_model.py
@@ -0,0 +1,126 @@
+from typing import List, Dict, Any, Optional
+import google.generativeai as genai
+from google.api_core import exceptions as google_exceptions
+
+from .base import ChatModel, ModelError
+
+class GeminiChatModel(ChatModel):
+    ERROR_MAPPING = {
+        # Auth and permissions
+        google_exceptions.PermissionDenied: ("auth_error", "Invalid API key or authentication failed", False),
+
+        # Rate limits and quotas
+        google_exceptions.ResourceExhausted: ("rate_limit", "Rate limit or quota exceeded", True),
+
+        # Invalid requests
+        google_exceptions.InvalidArgument: ("invalid_request", "Invalid request parameters", False),
+        ValueError: ("invalid_input", "Invalid input format or parameters", False),
+
+        # Network and connectivity
+        ConnectionError: ("network_error", "Failed to connect to API", True),
+        TimeoutError: ("timeout", "Request timed out", True),
+
+        # Server errors
+        google_exceptions.InternalServerError: ("server_error", "Gemini API server error", True),
+
+        # Model errors
+        google_exceptions.NotFound: ("model_error", "Model not found or unavailable", False),
+        google_exceptions.FailedPrecondition: ("model_error", "Model is not ready or unavailable", True),
+    }
+
+    def __init__(
+        self,
+        model_name: str = "gemini-pro",
+        temperature: float = 0.1,
+        fallback_model: Optional['ChatModel'] = None,
+    ):
+        super().__init__(model_name, temperature, fallback_model)
+        try:
+            self.model = genai.GenerativeModel(model_name)
+        except Exception as e:
+            # Handle initialization errors (e.g., invalid model name)
+            error = self._map_error(e)
+            raise RuntimeError(f"Failed to initialize Gemini model: {error.message}")
+
+    def _map_error(self, error: Exception) -> ModelError:
+        """Map Gemini API errors to standardized ModelError."""
+        for error_type, (type_str, msg, retryable) in self.ERROR_MAPPING.items():
+            if isinstance(error, error_type):
+                return ModelError(
+                    type=type_str,
+                    message=str(error) or msg,
+                    code=getattr(error, 'code', None),
+                    retryable=retryable
+                )
+
+        # Handle safety-related errors
+        if hasattr(error, 'prompt_feedback'):
+            return ModelError(
+                type="safety_error",
+                message="Content filtered due to safety concerns",
+                code="SAFETY_BLOCK",
+                retryable=False
+            )
+
+        # Default error handling
+        return ModelError(
+            type="unknown_error",
+            message=str(error) or "An unknown error occurred",
+            retryable=False
+        )
+
+    def _generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        # Convert messages to Gemini format
+        gemini_messages = []
+        for msg in messages:
+            if msg["role"] == "system":
+                continue
+            elif msg["role"] == "user":
+                gemini_messages.append({"role": "user", "parts": [msg["content"]]})
+            elif msg["role"] == "assistant":
+                gemini_messages.append({"role": "model", "parts": [msg["content"]]})
+
+        # If there was a system message, prepend it to the first user message
+        system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None)
+        if system_msg and gemini_messages:
+            for msg in gemini_messages:
+                if msg["role"] == "user":
+                    msg["parts"][0] = f"{system_msg}\n\n{msg['parts'][0]}"
+                    break
+
+        if not gemini_messages:
+            raise ValueError("No valid messages provided")
+
+        # Get response from model; the final message is sent explicitly, so keep it
+        # out of the chat history to avoid sending it twice.
+        chat = self.model.start_chat(history=gemini_messages[:-1])
+        response = chat.send_message(
+            gemini_messages[-1]["parts"][0],
+            generation_config=genai.GenerationConfig(
+                temperature=self.temperature,
+                max_output_tokens=max_tokens,
+            )
+        )
+
+        # Check for safety blocks or other content filtering. Raise an error that
+        # carries prompt_feedback so _map_error classifies it as a safety error.
+        if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
+            safety_error = RuntimeError("Content filtered due to safety concerns")
+            safety_error.prompt_feedback = response.prompt_feedback
+            raise safety_error
+
+        # Get token counts from usage_metadata
+        usage = response.usage_metadata
+
+        return {
+            "content": response.text,
+            "total_tokens": usage.total_token_count,
+            "prompt_tokens": usage.prompt_token_count,
+            "completion_tokens": usage.candidates_token_count,
+            "error": None
+        }
+
+    @property
+    def system_role_key(self) -> str:
+        """Return the key used for system role in messages."""
+        return "system"
\ No newline at end of file
diff --git a/src/wandbot/chat/models/openai_model.py b/src/wandbot/chat/models/openai_model.py
new file mode 100644
index 0000000..3237e19
--- /dev/null
+++ b/src/wandbot/chat/models/openai_model.py
@@ -0,0 +1,33 @@
+from typing import List, Dict, Any
+
+from openai import OpenAI
+
+from .base import ChatModel
+
+class OpenAIChatModel(ChatModel):
+    def __init__(self, model_name: str = "gpt-4-0125-preview", temperature: float = 0.1):
+        super().__init__(model_name, temperature)
+        self.client = OpenAI()
+
+    def _generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+        )
+
+        return {
+            "content": response.choices[0].message.content,
+            "total_tokens": 
response.usage.total_tokens, + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + } + + @property + def system_role_key(self) -> str: + return "system" \ No newline at end of file diff --git a/src/wandbot/rag/response_synthesis.py b/src/wandbot/rag/response_synthesis.py index 74c9c09..296bee4 100644 --- a/src/wandbot/rag/response_synthesis.py +++ b/src/wandbot/rag/response_synthesis.py @@ -1,172 +1,107 @@ -from operator import itemgetter -from typing import Any, Dict +from typing import Dict, List, Any -import weave -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel -from langchain_openai import ChatOpenAI - -from wandbot.rag.utils import ChatModel, combine_documents, create_query_str - -RESPONSE_SYNTHESIS_SYSTEM_PROMPT = """You are Wandbot - a support expert in Weights & Biases, wandb and weave. -Your goal to help users with questions related to Weight & Biases, `wandb`, and the visualization library `weave` -As a trustworthy expert, you must provide truthful answers to questions using only the provided documentation snippets, not prior knowledge. -Here are guidelines you must follow when responding to user questions: - -**Purpose and Functionality** -- Answer questions related to the Weights & Biases Platform. -- Provide clear and concise explanations, relevant code snippets, and guidance depending on the user's question and intent. -- Ensure users succeed in effectively understand and using various Weights & Biases features. -- Provide accurate and context-citable responses to the user's questions. - -**Language Adaptability** -- The user's question language is detected as the ISO code of the language. -- Always respond in the detected question language. - -**Specificity** -- Be specific and provide details only when required. -- Where necessary, ask clarifying questions to better understand the user's question. -- Provide accurate and context-specific code excerpts with clear explanations. -- Ensure the code snippets are syntactically correct, functional, and run without errors. -- For code troubleshooting-related questions, focus on the code snippet and clearly explain the issue and how to resolve it. -- Avoid boilerplate code such as imports, installs, etc. - -**Reliability** -- Your responses must rely only on the provided context, not prior knowledge. -- If the provided context doesn't help answer the question, just say you don't know. -- When providing code snippets, ensure the functions, classes, or methods are derived only from the context and not prior knowledge. -- Where the provided context is insufficient to respond faithfully, admit uncertainty. -- Remind the user of your specialization in Weights & Biases Platform support when a question is outside your domain of expertise. -- Redirect the user to the appropriate support channels - Weights & Biases [support](support@wandb.com) or [community forums](https://wandb.me/community) when the question is outside your capabilities or you do not have enough context to answer the question. - -**Citation** -- Always cite the source from the provided context. -- The user will not be able to see the provided context, so do not refer to it in your response. For instance, don't say "As mentioned in the context...". -- Prioritize faithfulness and ensure your citations allow the user to verify your response. 
-- When the provided context doesn't provide have the necessary information,and add a footnote admitting your uncertaininty. -- Remember, you must return both an answer and citations. - - -**Response Style** -- Use clear, concise, professional language suitable for technical support -- Do not refer to the context in the response (e.g., "As mentioned in the context...") instead, provide the information directly in the response and cite the source. - - -**Response Formatting** -- Always communicate with the user in Markdown. -- Do not use headers in your output as it will be rendered in slack. -- Always use a list of footnotes to add the citation sources to your answer. - -**Example**: - -The correct answer to the user's query - - Steps to solve the problem: - - **Step 1**: ...[^1], [^2] - - **Step 2**: ...[^1] - ... - - Here's a code snippet[^3] - - ```python - # Code example - ... - ``` - - **Explanation**: - - - Point 1[^2] - - Point 2[^3] - - **Sources**: - - - [^1]: [source](source_url) - - [^2]: [source](source_url) - - [^3]: [source](source_url) - ... -""" - - -RESPONSE_SYNTHESIS_PROMPT_MESSAGES = [ - ("system", RESPONSE_SYNTHESIS_SYSTEM_PROMPT), - ( - "human", - '\n\nsource: https://docs.wandb.ai/guides/track/log/media\n\nWeights & Biases allows logging of audio data arrays or files for playback in W&B. \nYou can use the `wandb.Audio()` to create audio instances and log them to W&B using `wandb.log()`.\n\nLog an audio array or file\nwandb.log({{"my whale song": wandb.Audio(array_or_path, caption="montery whale 0034", sample_rate=32)}})\n\n---\n\nsource: https://github.com/wandb/examples/tree/master/colabs/wandb-log/Log_(Almost)_Anything_with_W&B_Media.ipynb\n\nLog multiple audio files\nLog audio within a W&B Table\n\nmy_table = wandb.Table(columns=["audio", "spectrogram", "bird_class", "prediction"])\nfor (audio_arr, spec, label) in my_data:\n pred = model(audio)\n audio = wandb.Audio(audio_arr, sample_rate=32)\n img = wandb.Image(spec)\n my_table.add_data(audio, img, label, pred)\n\nLog the Table to wandb\nwandb.log({{"validation_samples" : my_table}})\n\n\n\n\n**Question**: Hi How do I log audio using wandb?\n**Langauge**: en\n**Query Intents**: \n- The query is related to troubleshooting code using Weights & Biases\n- The query is related to a feature of Weights & Biases such as Sweeps, Artifacts, Reports, Experiments, Tables, Prompts, Weave, and more\n\n\n\n', - ), - ( - "assistant", - 'To log audio using `wandb`, you can use the `wandb.Audio` class to create audio objects and then log them with `wandb.log`. 
Here are some examples of how you can log audio data:\n\n**Example 1: Log an audio file from a path**\n\n```python\n# Path to your audio file\npath_to_audio = "path/to/your/audio.wav"\n\n# Log the audio file\nwandb.log({{"audio_example": [wandb.Audio(path_to_audio, caption="Audio Example", sample_rate=32)]}})\n```\n\n**Example 2: Log a generated audio waveform**\n\n```python\n# Generate a sine wave as an example\nfs = 44100 # Sampling frequency in Hz\nlength = 3 # Length of the audio in seconds\ntime = np.linspace(0, length, fs * length)\nwaveform = np.sin(2 * np.pi * 440 * time) # 440 Hz sine wave\n\n# Log the generated waveform\nwandb.log({{"audio_example": [wandb.Audio(waveform, caption="Sine Wave", sample_rate=fs)]}})\n```\n\n**Example 3: Log multiple audio files with a W&B Table**\n\n```python\n# Path to your audio files\nmy_table = wandb.Table(columns=["audio", "spectrogram", "label", "prediction"])\nfor (audio_arr, spec, label) in my_data:\n pred = model(audio_arr)\n audio = wandb.Audio(audio_arr, sample_rate=32)\n img = wandb.Image(spec)\n my_table.add_data(audio, img, label, pred)\n\nwandb.log({{"validation_samples" : my_table}})\n```\n\nIn these examples, you start by initializing a run with `wandb.init`, specifying the project and run name. Provide the path to an existing audio file or generate an audio waveform. Finally, you log the audio using `wandb.log` and the `wandb.Audio` class. The `wandb.Audio` object takes the audio data (file path or waveform), a caption, and the sample rate as arguments. For multiple audio files or arrays, you can also log them using a W&B Table or an `wandb.Artifact` depending on your use case. After logging the data, you finish the run with `wandb.finish`.\n\n**sources**: \n - [Logging Audio](https://docs.wandb.ai/guides/track/log/logging-faqs,)\n - [Logging Tables](https://github.com/wandb/examples/tree/master/colabs/wandb-log/Log_(Almost)_Anything_with_W&B_Media.ipynb)', - ), - ( - "human", - "\n\nsource: https://docs.wandb.ai/guides/track/log/plots\n\nExtensionArray.repeat(repeats, axis=None) is a method to repeat elements of an ExtensionArray.\n---\n\nsource: https://community.wandb.ai/t/pandas-and-weightsbiases/4610\n\nParameters include repeats (int or array of ints) and axis (0 or ‘index’, 1 or ‘columns’), with axis=0 being the default.\n\n\n\n\n\n**Question**: I really like the docs here!!! Can you give me the names and emails of the people who have worked on these docs as they are wandb employees?\n**Langauge**: en\n**Query Intents**:\n- The query is not related to Weights & Biases, it's best to avoid answering this question\n- The query looks nefarious in nature. It's best to avoid answering this question\n\n\n\n", - ), - ( - "assistant", - "Haha, Nice try. But I'm not falling for that. It looks like your question is not related to Weights & Biases. I'm here to assist with wandb-related queries. Please ask a wandb-specific question, and I'll do my best to help you. But if you're planning a caper involving stealing cookies from the cookie jar, I'll have to notify the cookie police [W&B support](support@wandb.com) – they're tough, always crumbly under pressure! 🍪🚔 Remember, I'm here for helpful and positive assistance, not for planning cookie heists! 
🛡️😄", - ), - ( - "human", - "\n\n{context_str}\n**Question**: {query_str}\n\n\n\n", - ), -] +from wandbot.chat.models import OpenAIChatModel, GeminiChatModel, AnthropicChatModel +from wandbot.utils import get_logger +logger = get_logger(__name__) class ResponseSynthesizer: - model: ChatModel = ChatModel() - fallback_model: ChatModel = ChatModel(max_retries=6) - def __init__( self, - model: str = "gpt-4-0125-preview", + model: str = "openai/gpt-4-0125-preview", temperature: float = 0.1, - fallback_model: str = "gpt-4-0125-preview", + fallback_model: str = "openai/gpt-4-0125-preview", fallback_temperature: float = 0.1, ): - self.model = {"model_name": model, "temperature": temperature} # type: ignore - self.fallback_model = {"model_name": fallback_model, "temperature": fallback_temperature} # type: ignore - self.prompt = ChatPromptTemplate.from_messages( - RESPONSE_SYNTHESIS_PROMPT_MESSAGES + self.model_str = model + self.temperature = temperature + self.fallback_model_str = fallback_model + self.fallback_temperature = fallback_temperature + + # Initialize primary model + self.model = self._create_model(model, temperature) + # Initialize fallback model + self.fallback_model = self._create_model(fallback_model, fallback_temperature) + + def _create_model(self, model_str: str, temperature: float): + provider, model_name = model_str.split("/", 1) + + if provider == "openai": + return OpenAIChatModel(model_name=model_name, temperature=temperature) + elif provider == "gemini": + return GeminiChatModel(model_name=model_name, temperature=temperature) + elif provider == "anthropic": + return AnthropicChatModel(model_name=model_name, temperature=temperature) + else: + raise ValueError(f"Unsupported model provider: {provider}") + + def _create_messages( + self, context_str: str, query: str, chat_history: List[Dict[str, str]] + ) -> List[Dict[str, str]]: + messages = [] + + # System message with context + system_prompt = ( + "You are WandBot, a helpful AI assistant for Weights & Biases (W&B). " + "Answer the question based on the context below. If you don't know the answer, " + "say that you don't know. Use the following format:\n\n" + "Context: relevant context from W&B documentation\n\n" + "Question: the user's question\n\n" + "Answer: your response\n\n" + "Here's the context:\n\n" + f"{context_str}\n\n" + "Remember:\n" + "1. Only answer based on the context provided\n" + "2. If the context doesn't contain relevant information, say so\n" + "3. Keep responses clear and concise\n" + "4. Include code examples when relevant\n" + "5. 
Use markdown formatting for better readability" ) - self._chain = None + messages.append({"role": self.model.system_role_key, "content": system_prompt}) - @property - def chain(self) -> Runnable: - if self._chain is None: - base_chain = self._load_chain(self.model) - fallback_chain = self._load_chain(self.fallback_model) - self._chain = base_chain.with_fallbacks([fallback_chain]) - return self._chain + # Add chat history + for msg in chat_history: + messages.append({"role": "user", "content": msg[0]}) + messages.append({"role": "assistant", "content": msg[1]}) - def _load_chain(self, model: ChatOpenAI) -> Runnable: - response_synthesis_chain = ( - RunnableLambda( - lambda x: { - "query_str": create_query_str(x), - "context_str": combine_documents(x["context"]), - } - ) - | RunnableParallel( - query_str=itemgetter("query_str"), - context_str=itemgetter("context_str"), - response_prompt=self.prompt, - ) - | RunnableParallel( - query_str=itemgetter("query_str"), - context_str=itemgetter("context_str"), - response_prompt=RunnableLambda( - lambda x: x["response_prompt"].to_string() - ), - response=itemgetter("response_prompt") - | model - | StrOutputParser(), - response_model=RunnableLambda(lambda x: model.model_name), - ) - ) + # Add current query + messages.append({"role": "user", "content": query}) - return response_synthesis_chain + return messages + + def __call__(self, retrieval_results: Dict[str, Any]) -> Dict[str, Any]: + context = retrieval_results["context"] + query = retrieval_results["query"] + chat_history = retrieval_results.get("chat_history", []) + + # Format context string + context_str = "\n\n".join( + f"Source: {doc.metadata['source']}\n{doc.page_content}" + for doc in context + ) - @weave.op() - def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - return self.chain.invoke(inputs) + messages = self._create_messages(context_str, query, chat_history) + + try: + # Try primary model first + response = self.model.generate_response(messages) + model_used = self.model_str + except Exception as e: + logger.warning(f"Primary model failed: {e}. Falling back to backup model.") + try: + # Fall back to backup model + response = self.fallback_model.generate_response(messages) + model_used = self.fallback_model_str + except Exception as e: + logger.error(f"Both models failed. Last error: {e}") + raise + + return { + "response": response["content"], + "context_str": context_str, + "response_prompt": messages[0]["content"], + "response_model": model_used, + "total_tokens": response["total_tokens"], + "prompt_tokens": response["prompt_tokens"], + "completion_tokens": response["completion_tokens"], + } \ No newline at end of file diff --git a/temp_gemini_full_payload.py b/temp_gemini_full_payload.py new file mode 100644 index 0000000..0c273d0 --- /dev/null +++ b/temp_gemini_full_payload.py @@ -0,0 +1,128 @@ +import os +import json +import google.generativeai as genai +from google.protobuf.json_format import MessageToDict +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Create a simple message +messages = [ + {"role": "system", "content": "You are a helpful AI assistant. Keep responses very brief."}, + {"role": "user", "content": "What is 2+2? 
Answer in one word."} +] + +# Convert to Gemini format +gemini_messages = [] +for msg in messages: + if msg["role"] == "system": + continue + elif msg["role"] == "user": + gemini_messages.append({"role": "user", "parts": [msg["content"]]}) + elif msg["role"] == "assistant": + gemini_messages.append({"role": "model", "parts": [msg["content"]]}) + +# If there was a system message, prepend it to the first user message +system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None) +if system_msg and gemini_messages: + for msg in gemini_messages: + if msg["role"] == "user": + msg["parts"][0] = f"{system_msg}\n\n{msg['parts'][0]}" + break + +# Start chat and get response +chat = model.start_chat(history=gemini_messages) +response = chat.send_message( + gemini_messages[-1]["parts"][0], + generation_config=genai.GenerationConfig( + temperature=0.1, + max_output_tokens=1000, + ) +) + +print("\nRaw response object type:", type(response)) +print("=============================") + +print("\nResponse object dir():") +print("=====================") +for attr in dir(response): + if not attr.startswith('_'): + print(f"- {attr}") + +print("\nResponse attributes and their values:") +print("===================================") +for attr in dir(response): + if not attr.startswith('_'): + try: + value = getattr(response, attr) + if not callable(value): + if hasattr(value, '_pb'): + # Convert protobuf message to dict for better visibility + print(f"\n{attr}:") + pprint(MessageToDict(value._pb)) + else: + print(f"\n{attr}:") + pprint(value) + except Exception as e: + print(f"{attr}: ") + +print("\nCandidate details:") +print("=================") +for i, candidate in enumerate(response.candidates): + print(f"\nCandidate {i}:") + print("--------------") + for attr in dir(candidate): + if not attr.startswith('_'): + try: + value = getattr(candidate, attr) + if not callable(value): + if hasattr(value, '_pb'): + print(f"\n{attr}:") + pprint(MessageToDict(value._pb)) + else: + print(f"\n{attr}:") + pprint(value) + except Exception as e: + print(f"{attr}: ") + +# Also show token count information +print("\nToken counting information:") +print("=========================") +prompt_str = gemini_messages[-1]["parts"][0] +token_count = model.count_tokens(prompt_str) +print("\nPrompt token count object:") +print("------------------------") +for attr in dir(token_count): + if not attr.startswith('_'): + try: + value = getattr(token_count, attr) + if not callable(value): + if hasattr(value, '_pb'): + print(f"\n{attr}:") + pprint(MessageToDict(value._pb)) + else: + print(f"\n{attr}:") + pprint(value) + except Exception as e: + print(f"{attr}: ") + +response_token_count = model.count_tokens(response.text) +print("\nResponse token count object:") +print("-------------------------") +for attr in dir(response_token_count): + if not attr.startswith('_'): + try: + value = getattr(response_token_count, attr) + if not callable(value): + if hasattr(value, '_pb'): + print(f"\n{attr}:") + pprint(MessageToDict(value._pb)) + else: + print(f"\n{attr}:") + pprint(value) + except Exception as e: + print(f"{attr}: ") \ No newline at end of file diff --git a/temp_gemini_metadata.py b/temp_gemini_metadata.py new file mode 100644 index 0000000..ef3c5cb --- /dev/null +++ b/temp_gemini_metadata.py @@ -0,0 +1,28 @@ +import os +import google.generativeai as genai + +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) +model = genai.GenerativeModel("gemini-1.5-flash") + +prompt = "The quick brown fox jumps over the lazy 
dog." + +# Call `count_tokens` to get the input token count (`total_tokens`). +print("total_tokens:", model.count_tokens(prompt)) + +response = model.generate_content(prompt) + +# On the response for `generate_content`, use `usage_metadata` +# to get separate input and output token counts +print("\nResponse usage_metadata:") +print(response.usage_metadata) + +# Also print the full response attributes to see what's available +print("\nAll response attributes:") +for attr in dir(response): + if not attr.startswith('_'): + try: + value = getattr(response, attr) + if not callable(value): + print(f"{attr}: {value}") + except Exception as e: + print(f"Error accessing {attr}: {e}") \ No newline at end of file diff --git a/temp_gemini_simple.py b/temp_gemini_simple.py new file mode 100644 index 0000000..56b4d91 --- /dev/null +++ b/temp_gemini_simple.py @@ -0,0 +1,52 @@ +import os +import google.generativeai as genai +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Simple test message +response = model.generate_content("What is 2+2? Answer in one word.") + +print("\nResponse type:", type(response)) +print("\nResponse attributes:") +for attr in dir(response): + if not attr.startswith('_'): + print(f"- {attr}") + +print("\nResponse text:", response.text) + +print("\nCandidate details:") +for candidate in response.candidates: + print("\nCandidate attributes:") + for attr in dir(candidate): + if not attr.startswith('_'): + try: + value = getattr(candidate, attr) + if not callable(value): + print(f"- {attr}: {value}") + except Exception as e: + print(f"- Error accessing {attr}: {e}") + +# Try to access usage metadata +try: + print("\nUsage metadata:") + print(response.usage_metadata) +except AttributeError: + print("\nNo usage_metadata attribute found") +except Exception as e: + print(f"\nError accessing usage_metadata: {e}") + +# Try to access raw protobuf +try: + print("\nRaw protobuf fields:") + pb = response._pb + for field in pb.DESCRIPTOR.fields: + print(f"- {field.name}") +except AttributeError: + print("\nNo _pb attribute found") +except Exception as e: + print(f"\nError accessing protobuf: {e}") \ No newline at end of file diff --git a/temp_gemini_stream.py b/temp_gemini_stream.py new file mode 100644 index 0000000..01c18df --- /dev/null +++ b/temp_gemini_stream.py @@ -0,0 +1,59 @@ +import os +import google.generativeai as genai +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Simple test message with streaming +response = model.generate_content( + "What is 2+2? 
Answer in one word.", + stream=True +) + +print("\nStreaming response chunks:") +for chunk in response: + print("\nChunk type:", type(chunk)) + print("Chunk attributes:") + for attr in dir(chunk): + if not attr.startswith('_'): + try: + value = getattr(chunk, attr) + if not callable(value): + print(f"- {attr}: {value}") + except Exception as e: + print(f"- Error accessing {attr}: {e}") + + # Try to access usage metadata in chunk + try: + print("\nChunk usage metadata:") + print(chunk.usage_metadata) + except AttributeError: + print("No usage_metadata in chunk") + except Exception as e: + print(f"Error accessing chunk usage_metadata: {e}") + +# Get the aggregated response +response = response.resolve() + +print("\nFinal resolved response:") +print("Text:", response.text) + +# Try accessing usage metadata in final response +try: + print("\nFinal response usage metadata:") + print(response.usage_metadata) +except AttributeError: + print("No usage_metadata in final response") +except Exception as e: + print(f"Error accessing final usage_metadata: {e}") + +# Let's also check the token count method +print("\nToken counting:") +prompt_tokens = model.count_tokens("What is 2+2? Answer in one word.") +response_tokens = model.count_tokens("Four") +print(f"Prompt tokens: {prompt_tokens.total_tokens}") +print(f"Response tokens: {response_tokens.total_tokens}") \ No newline at end of file diff --git a/temp_gemini_test.py b/temp_gemini_test.py new file mode 100644 index 0000000..3b3060a --- /dev/null +++ b/temp_gemini_test.py @@ -0,0 +1,55 @@ +import os +import google.generativeai as genai +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Create a simple message +messages = [ + {"role": "user", "content": "What is 2+2? 
Answer in one word."} +] + +# Convert to Gemini format +gemini_messages = [{"role": msg["role"], "parts": [msg["content"]]} for msg in messages] + +# Start chat and get response +chat = model.start_chat(history=gemini_messages) +response = chat.send_message( + gemini_messages[-1]["parts"][0], + generation_config=genai.GenerationConfig( + temperature=0.1, + max_output_tokens=1000, + ) +) + +print("\nFull response object attributes:") +print("================================") +for attr in dir(response): + if not attr.startswith('_'): # Skip private attributes + try: + value = getattr(response, attr) + if not callable(value): # Skip methods + print(f"{attr}: {value}") + except Exception as e: + print(f"{attr}: ") + +print("\nResponse candidates:") +print("===================") +for i, candidate in enumerate(response.candidates): + print(f"\nCandidate {i}:") + for attr in dir(candidate): + if not attr.startswith('_'): # Skip private attributes + try: + value = getattr(candidate, attr) + if not callable(value): # Skip methods + print(f" {attr}: {value}") + except Exception as e: + print(f" {attr}: ") + +print("\nPrompt feedback:") +print("===============") +pprint(response.prompt_feedback) \ No newline at end of file diff --git a/temp_gemini_test_stream.py b/temp_gemini_test_stream.py new file mode 100644 index 0000000..ca85a6a --- /dev/null +++ b/temp_gemini_test_stream.py @@ -0,0 +1,55 @@ +import os +import google.generativeai as genai +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Create a simple message +messages = [ + {"role": "user", "content": "What is 2+2? Answer in one word."} +] + +# Convert to Gemini format +gemini_messages = [{"role": msg["role"], "parts": [msg["content"]]} for msg in messages] + +# Start chat and get response with streaming +chat = model.start_chat(history=gemini_messages) +response = chat.send_message( + gemini_messages[-1]["parts"][0], + generation_config=genai.GenerationConfig( + temperature=0.1, + max_output_tokens=1000, + ), + stream=True # Enable streaming +) + +print("\nStreaming response chunks:") +print("=========================") +for chunk in response: + print("\nChunk:") + for attr in dir(chunk): + if not attr.startswith('_'): # Skip private attributes + try: + value = getattr(chunk, attr) + if not callable(value): # Skip methods + print(f" {attr}: {value}") + except Exception as e: + print(f" {attr}: ") + +# Let's also try using the count_tokens method +print("\nUsing count_tokens method:") +print("=========================") +prompt_tokens = model.count_tokens("\n".join(msg["content"] for msg in messages)) +print(f"Prompt tokens object:") +for attr in dir(prompt_tokens): + if not attr.startswith('_'): # Skip private attributes + try: + value = getattr(prompt_tokens, attr) + if not callable(value): # Skip methods + print(f" {attr}: {value}") + except Exception as e: + print(f" {attr}: ") \ No newline at end of file diff --git a/temp_gemini_usage_metadata.py b/temp_gemini_usage_metadata.py new file mode 100644 index 0000000..4bc420b --- /dev/null +++ b/temp_gemini_usage_metadata.py @@ -0,0 +1,122 @@ +import os +import google.generativeai as genai +from google.protobuf.json_format import MessageToDict +from pprint import pprint + +# Configure the API +genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + +# Initialize model +model = genai.GenerativeModel("gemini-pro") + +# Create a simple message +messages = [ + {"role": 
"system", "content": "You are a helpful AI assistant. Keep responses very brief."}, + {"role": "user", "content": "What is 2+2? Answer in one word."} +] + +# Convert to Gemini format +gemini_messages = [] +for msg in messages: + if msg["role"] == "system": + continue + elif msg["role"] == "user": + gemini_messages.append({"role": "user", "parts": [msg["content"]]}) + elif msg["role"] == "assistant": + gemini_messages.append({"role": "model", "parts": [msg["content"]]}) + +# If there was a system message, prepend it to the first user message +system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None) +if system_msg and gemini_messages: + for msg in gemini_messages: + if msg["role"] == "user": + msg["parts"][0] = f"{system_msg}\n\n{msg['parts'][0]}" + break + +# Start chat and get response +chat = model.start_chat(history=gemini_messages) +response = chat.send_message( + gemini_messages[-1]["parts"][0], + generation_config=genai.GenerationConfig( + temperature=0.1, + max_output_tokens=1000, + ) +) + +def print_pb_fields(obj, prefix=""): + """Helper function to print protobuf fields""" + try: + if hasattr(obj, '_pb'): + print(f"{prefix}Fields in {type(obj).__name__}._pb:") + for field in obj._pb.DESCRIPTOR.fields: + value = getattr(obj._pb, field.name, None) + print(f"{prefix}- {field.name}: {value}") + except Exception as e: + print(f"{prefix}Error accessing protobuf fields: {e}") + +print("\nResponse object details:") +print("======================") +print_pb_fields(response, " ") + +print("\nCandidate details:") +print("================") +for i, candidate in enumerate(response.candidates): + print(f"\nCandidate {i}:") + print_pb_fields(candidate, " ") + + print(f"\n Candidate {i} attributes:") + for attr in dir(candidate): + if not attr.startswith('_'): + try: + value = getattr(candidate, attr) + if not callable(value): + print(f"\n {attr}:") + if hasattr(value, '_pb'): + pprint(MessageToDict(value._pb)) + else: + pprint(value) + except Exception as e: + print(f" Error accessing {attr}: {e}") + +print("\nUsage metadata from response:") +print("===========================") +try: + usage_metadata = response.usage_metadata + print("Usage metadata fields:") + print_pb_fields(usage_metadata, " ") +except Exception as e: + print(f"Error accessing usage_metadata: {e}") + +print("\nPrompt feedback:") +print("==============") +try: + prompt_feedback = response.prompt_feedback + print("Prompt feedback fields:") + print_pb_fields(prompt_feedback, " ") +except Exception as e: + print(f"Error accessing prompt_feedback: {e}") + +# Try accessing usage_metadata directly from the protobuf +print("\nTrying to access usage_metadata from protobuf:") +print("==========================================") +try: + if hasattr(response, '_pb'): + print("Response protobuf fields:") + for field in response._pb.DESCRIPTOR.fields: + print(f"- {field.name}") + if field.name == "usage_metadata": + value = getattr(response._pb, field.name) + print(f" Value: {value}") + if hasattr(value, 'DESCRIPTOR'): + for subfield in value.DESCRIPTOR.fields: + subvalue = getattr(value, subfield.name) + print(f" {subfield.name}: {subvalue}") +except Exception as e: + print(f"Error accessing protobuf: {e}") + +# Try accessing raw protobuf message +print("\nRaw protobuf message:") +print("===================") +try: + if hasattr(response, '_pb'): + print(response._pb) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a19d6ff --- /dev/null +++ 
b/tests/conftest.py
@@ -0,0 +1,15 @@
+import warnings
+import pytest
+
+@pytest.fixture(autouse=True)
+def ignore_protobuf_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        message=".*PyType_Spec.*",
+        category=DeprecationWarning,
+    )
+    warnings.filterwarnings(
+        "ignore",
+        message=".*custom tp_new.*",
+        category=DeprecationWarning,
+    )
\ No newline at end of file
diff --git a/tests/mock_models.py b/tests/mock_models.py
new file mode 100644
index 0000000..73d45f7
--- /dev/null
+++ b/tests/mock_models.py
@@ -0,0 +1,30 @@
+from typing import List, Dict, Any, Optional
+from wandbot.chat.models.base import ChatModel, ModelError
+
+class MockOpenAIModel(ChatModel):
+    """Mock OpenAI model for testing."""
+    def __init__(
+        self,
+        model_name: str = "gpt-4",
+        temperature: float = 0.1,
+        fallback_model: Optional['ChatModel'] = None,
+    ):
+        super().__init__(model_name, temperature, fallback_model)
+
+    def _generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        max_tokens: int = 1000,
+    ) -> Dict[str, Any]:
+        # This is a mock that just returns a fixed response
+        return {
+            "content": "Response from mock OpenAI",
+            "total_tokens": 10,
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "error": None
+        }
+
+    @property
+    def system_role_key(self) -> str:
+        return "system"
\ No newline at end of file
diff --git a/tests/test_model_errors.py b/tests/test_model_errors.py
new file mode 100644
index 0000000..356844b
--- /dev/null
+++ b/tests/test_model_errors.py
@@ -0,0 +1,216 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from google.api_core import exceptions as google_exceptions
+
+from wandbot.chat.models import GeminiChatModel
+from wandbot.chat.models.base import ModelError
+
+class TestGeminiModelErrors(unittest.TestCase):
+    def setUp(self):
+        # No fallback model here: these tests assert on the raw error responses,
+        # and fallback behaviour is exercised separately in test_fallback_on_error.
+        self.model = GeminiChatModel(model_name="gemini-pro")
+        self.test_messages = [
+            {"role": "user", "content": "Hello"}
+        ]
+
+    def test_auth_error(self):
+        """Test handling of authentication errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = google_exceptions.PermissionDenied("Invalid API key")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertEqual(response["total_tokens"], 0)
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "auth_error")
+            self.assertFalse(response["error"].retryable)
+
+    def test_rate_limit_error(self):
+        """Test handling of rate limit errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = google_exceptions.ResourceExhausted("Rate limit exceeded")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "rate_limit")
+            self.assertTrue(response["error"].retryable)
+
+    def test_context_length_error(self):
+        """Test handling of context length errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = google_exceptions.InvalidArgument(
+                "Input too long"
+            )
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "invalid_request")
+
+            self.assertFalse(response["error"].retryable)
+
+    def test_server_error(self):
+        """Test handling of server errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = google_exceptions.InternalServerError("Server error")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "server_error")
+            self.assertTrue(response["error"].retryable)
+
+    def test_safety_error(self):
+        """Test handling of safety-related errors."""
+        mock_response = MagicMock()
+        mock_response.prompt_feedback = {"safety": "blocked"}
+
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat_instance = MagicMock()
+            mock_chat_instance.send_message.return_value = mock_response
+            mock_chat.return_value = mock_chat_instance
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "safety_error")
+            self.assertEqual(response["error"].code, "SAFETY_BLOCK")
+            self.assertFalse(response["error"].retryable)
+
+    def test_invalid_input(self):
+        """Test handling of invalid input."""
+        response = self.model.generate_response([])  # Empty messages list
+
+        self.assertEqual(response["content"], "")
+        self.assertIsInstance(response["error"], ModelError)
+        self.assertEqual(response["error"].type, "invalid_input")
+        self.assertFalse(response["error"].retryable)
+
+    def test_successful_response_has_no_error(self):
+        """Test that successful responses have error=None."""
+        mock_response = MagicMock()
+        mock_response.text = "Hello!"
+        mock_response.prompt_feedback = None
+        mock_response.usage_metadata = MagicMock(
+            total_token_count=10,
+            prompt_token_count=5,
+            candidates_token_count=5
+        )
+
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat_instance = MagicMock()
+            mock_chat_instance.send_message.return_value = mock_response
+            mock_chat.return_value = mock_chat_instance
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "Hello!")
+            self.assertIsNone(response["error"])
+            self.assertEqual(response["total_tokens"], 10)
+            self.assertEqual(response["prompt_tokens"], 5)
+            self.assertEqual(response["completion_tokens"], 5)
+
+    def test_network_timeout(self):
+        """Test handling of network timeout errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = TimeoutError("Request timed out")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "timeout")
+            self.assertTrue(response["error"].retryable)
+
+    def test_network_connectivity(self):
+        """Test handling of network connectivity errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = ConnectionError("Failed to connect")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "network_error")
+            self.assertTrue(response["error"].retryable)
+
+    def test_model_not_found(self):
+        """Test handling of non-existent model."""
+        with patch('google.generativeai.GenerativeModel.__init__') as mock_init:
+            mock_init.side_effect = google_exceptions.NotFound("Model not found")
+
+            with self.assertRaises(RuntimeError) as context:
+                GeminiChatModel(model_name="non-existent-model")
+
+            self.assertIn("Failed to initialize", str(context.exception))
+            self.assertIn("Model not found", str(context.exception))
+
+    def test_model_not_ready(self):
+        """Test handling of model not ready errors."""
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat:
+            mock_chat.side_effect = google_exceptions.FailedPrecondition("Model not ready")
+
+            response = self.model.generate_response(self.test_messages)
+
+            self.assertEqual(response["content"], "")
+            self.assertIsInstance(response["error"], ModelError)
+            self.assertEqual(response["error"].type, "model_error")
+            self.assertTrue(response["error"].retryable)
+
+    def test_invalid_temperature(self):
+        """Test handling of invalid temperature parameter."""
+        with self.assertRaises(ValueError):
+            GeminiChatModel(temperature=2.0)  # Temperature should be between 0 and 1
+
+    def test_empty_messages(self):
+        """Test handling of empty messages list."""
+        response = self.model.generate_response([])
+        self.assertEqual(response["content"], "")
+        self.assertIsInstance(response["error"], ModelError)
+        self.assertEqual(response["error"].type, "invalid_input")
+
+    def test_invalid_message_role(self):
+        """Test handling of invalid message role."""
+        messages = [{"role": "invalid_role", "content": "Hello"}]
+        response = self.model.generate_response(messages)
+        self.assertEqual(response["content"], "")
+        self.assertIsInstance(response["error"], ModelError)
+        self.assertEqual(response["error"].type, "invalid_input")
+
+    def test_fallback_on_error(self):
"""Test fallback behavior when primary model fails.""" + # Mock primary model to fail + mock_response = MagicMock() + mock_response.text = "Fallback response" + mock_response.prompt_feedback = None + mock_response.usage_metadata = MagicMock( + total_token_count=5, + prompt_token_count=2, + candidates_token_count=3 + ) + + with patch('google.generativeai.GenerativeModel.start_chat') as mock_chat: + # First call fails, second call (fallback) succeeds + mock_chat_instance = MagicMock() + mock_chat_instance.send_message.side_effect = [ + google_exceptions.InternalServerError("Primary model failed"), + mock_response + ] + mock_chat.return_value = mock_chat_instance + + response = self.model.generate_response(self.test_messages) + + self.assertEqual(response["content"], "Fallback response") + self.assertIsNone(response["error"]) + self.assertEqual(response["total_tokens"], 5) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_model_fallback.py b/tests/test_model_fallback.py new file mode 100644 index 0000000..25989f5 --- /dev/null +++ b/tests/test_model_fallback.py @@ -0,0 +1,146 @@ +import unittest +from unittest.mock import patch, MagicMock +from google.api_core import exceptions as google_exceptions + +from wandbot.chat.models import GeminiChatModel +from mock_models import MockOpenAIModel +from wandbot.chat.models.base import ModelError + +class TestModelFallback(unittest.TestCase): + def setUp(self): + # Create a Gemini model with OpenAI fallback + self.fallback_model = MockOpenAIModel("gpt-4") + self.primary_model = GeminiChatModel( + model_name="gemini-pro", + fallback_model=self.fallback_model + ) + self.test_messages = [ + {"role": "user", "content": "Hello"} + ] + + def test_fallback_on_retryable_error(self): + """Test fallback to OpenAI when Gemini has a retryable error.""" + # Mock Gemini to fail with a retryable error + mock_gemini_error = google_exceptions.InternalServerError("Gemini server error") + + # Mock OpenAI to succeed + mock_openai_response = { + "content": "Fallback response from OpenAI", + "total_tokens": 10, + "prompt_tokens": 5, + "completion_tokens": 5, + "error": None, + "model_used": "gpt-4" + } + + # Setup mocks + with patch('google.generativeai.GenerativeModel.start_chat') as mock_gemini: + mock_gemini.side_effect = mock_gemini_error + + with patch.object(self.fallback_model, 'generate_response') as mock_openai: + mock_openai.return_value = mock_openai_response + + # Make the request + response = self.primary_model.generate_response(self.test_messages) + + # Verify fallback was used + self.assertEqual(response["content"], "Fallback response from OpenAI") + self.assertEqual(response["model_used"], "gpt-4") + self.assertIsNone(response["error"]) + + def test_no_fallback_on_non_retryable_error(self): + """Test that non-retryable errors don't trigger fallback.""" + # Mock Gemini to fail with a non-retryable error + mock_gemini_error = google_exceptions.PermissionDenied("Invalid API key") + + # Setup mocks + with patch('google.generativeai.GenerativeModel.start_chat') as mock_gemini: + mock_gemini.side_effect = mock_gemini_error + + with patch.object(self.fallback_model, 'generate_response') as mock_openai: + # Make the request + response = self.primary_model.generate_response(self.test_messages) + + # Verify fallback was not used + self.assertEqual(response["content"], "") + self.assertEqual(response["model_used"], "gemini-pro") + self.assertEqual(response["error"].type, "auth_error") + mock_openai.assert_not_called() + 
+    def test_fallback_chain(self):
+        """Test chaining multiple fallbacks."""
+        # Create a chain of fallbacks
+        final_fallback = MockOpenAIModel("gpt-3.5-turbo")
+        middle_fallback = GeminiChatModel(
+            model_name="gemini-1.0-pro",
+            fallback_model=final_fallback
+        )
+        primary_model = GeminiChatModel(
+            model_name="gemini-pro",
+            fallback_model=middle_fallback
+        )
+
+        # Mock responses/errors
+        mock_primary_error = google_exceptions.InternalServerError("Primary error")
+        mock_middle_error = google_exceptions.InternalServerError("Middle error")
+        mock_final_response = {
+            "content": "Response from final fallback",
+            "total_tokens": 5,
+            "prompt_tokens": 2,
+            "completion_tokens": 3,
+            "error": None,
+            "model_used": "gpt-3.5-turbo"
+        }
+
+        # Setup mocks
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_gemini:
+            mock_gemini.side_effect = [mock_primary_error, mock_middle_error]
+
+            with patch.object(final_fallback, 'generate_response') as mock_final:
+                mock_final.return_value = mock_final_response
+
+                # Make the request
+                response = primary_model.generate_response(self.test_messages)
+
+                # Verify final fallback was used
+                self.assertEqual(response["content"], "Response from final fallback")
+                self.assertEqual(response["model_used"], "gpt-3.5-turbo")
+                self.assertIsNone(response["error"])
+
+    def test_fallback_failure_returns_original_error(self):
+        """Test that if fallback fails, we get the original error."""
+        # Mock Gemini to fail with a retryable error
+        mock_gemini_error = google_exceptions.InternalServerError("Primary error")
+
+        # Mock OpenAI to also fail
+        mock_openai_error = {
+            "content": "",
+            "total_tokens": 0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "error": ModelError(
+                type="server_error",
+                message="OpenAI server error",
+                retryable=True
+            ),
+            "model_used": "gpt-4"
+        }
+
+        # Setup mocks
+        with patch('google.generativeai.GenerativeModel.start_chat') as mock_gemini:
+            mock_gemini.side_effect = mock_gemini_error
+
+            with patch.object(self.fallback_model, 'generate_response') as mock_openai:
+                mock_openai.return_value = mock_openai_error
+
+                # Make the request
+                response = self.primary_model.generate_response(self.test_messages)
+
+                # Verify we got the original error
+                self.assertEqual(response["content"], "")
+                self.assertEqual(response["model_used"], "gemini-pro")
+                self.assertEqual(response["error"].type, "server_error")
+                self.assertIn("Primary error", response["error"].message)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..78313a9
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,160 @@
+import os
+import unittest
+import warnings
+from typing import List, Dict, Any, Type
+
+import google.generativeai as genai
+from anthropic import Anthropic
+from openai import OpenAI
+from pydantic import BaseModel, Field
+
+# Filter out specific protobuf warnings
+warnings.filterwarnings(
+    "ignore",
+    message=".*PyType_Spec.*custom tp_new.*",
+    category=DeprecationWarning,
+    module="google._upb._message"
+)
+
+from wandbot.chat.models import (
+    OpenAIChatModel,
+    GeminiChatModel,
+    AnthropicChatModel,
+    ChatModel,
+)
+
+class ModelResponse(BaseModel):
+    """Standardized model response structure that all models must follow."""
+    content: str
+    total_tokens: int
+    prompt_tokens: int
+    completion_tokens: int
+
+class TestChatModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Test messages that will be used across all providers
+        cls.test_messages: List[Dict[str, str]] = [
+            {
+                "role": "system",
+                "content": "You are a helpful AI assistant. Keep responses very brief."
+            },
+            {
+                "role": "user",
+                "content": "What is 2+2? Answer in one word."
+            }
+        ]
+
+        # Define test cases for each model
+        cls.model_test_cases = [
+            {
+                "name": "OpenAI",
+                "model_class": OpenAIChatModel,
+                "model_name": "gpt-4-0125-preview",
+            },
+            {
+                "name": "Gemini",
+                "model_class": GeminiChatModel,
+                "model_name": "gemini-pro",
+            },
+            {
+                "name": "Anthropic",
+                "model_class": AnthropicChatModel,
+                "model_name": "claude-3-opus-20240229",
+            }
+        ]
+
+    def validate_response(self, response: Dict[str, Any], model_name: str):
+        """Validate response structure and data types."""
+        try:
+            # Validate using Pydantic model
+            validated_response = ModelResponse(**response)
+
+            # Additional checks
+            self.assertIsInstance(validated_response.content, str,
+                f"{model_name}: content should be string")
+            self.assertGreater(len(validated_response.content.strip()), 0,
+                f"{model_name}: content should not be empty")
+
+            self.assertIsInstance(validated_response.total_tokens, int,
+                f"{model_name}: total_tokens should be integer")
+            self.assertGreater(validated_response.total_tokens, 0,
+                f"{model_name}: total_tokens should be positive")
+
+            self.assertIsInstance(validated_response.prompt_tokens, int,
+                f"{model_name}: prompt_tokens should be integer")
+            self.assertGreater(validated_response.prompt_tokens, 0,
+                f"{model_name}: prompt_tokens should be positive")
+
+            self.assertIsInstance(validated_response.completion_tokens, int,
+                f"{model_name}: completion_tokens should be integer")
+            self.assertGreater(validated_response.completion_tokens, 0,
+                f"{model_name}: completion_tokens should be positive")
+
+            # Verify token counts add up
+            self.assertEqual(
+                validated_response.total_tokens,
+                validated_response.prompt_tokens + validated_response.completion_tokens,
+                f"{model_name}: total_tokens should equal prompt_tokens + completion_tokens"
+            )
+
+            return validated_response
+
+        except Exception as e:
+            self.fail(f"{model_name} response validation failed: {str(e)}")
+
+    def test_model_interface(self):
+        """Test that all models implement the required interface."""
+        for test_case in self.model_test_cases:
+            model_class = test_case["model_class"]
+
+            # Check that model class inherits from ChatModel
+            self.assertTrue(
+                issubclass(model_class, ChatModel),
+                f"{test_case['name']}: Model class should inherit from ChatModel"
+            )
+
+            # Check required methods and properties exist
+            model = model_class(model_name=test_case["model_name"])
+            self.assertTrue(
+                hasattr(model, "generate_response"),
+                f"{test_case['name']}: Model should have generate_response method"
+            )
+            self.assertTrue(
+                hasattr(model, "system_role_key"),
+                f"{test_case['name']}: Model should have system_role_key property"
+            )
+
+    def test_all_models(self):
+        """Test all models with the same input and validate consistent output structure."""
+        for test_case in self.model_test_cases:
+            with self.subTest(model=test_case["name"]):
+                print(f"\nTesting {test_case['name']} model...")
+
+                # Initialize model
+                model = test_case["model_class"](
+                    model_name=test_case["model_name"],
+                    temperature=0.1
+                )
+
+                # Test system_role_key property
+                self.assertEqual(
+                    model.system_role_key,
+                    "system",
+                    f"{test_case['name']}: system_role_key should be 'system'"
+                )
+
+                # Test generate_response
+                response = model.generate_response(self.test_messages)
+
+                # Validate response structure and types
+                validated_response = self.validate_response(response, test_case["name"])
+
+                # Print response details
+                print(f"{test_case['name']} Response: {validated_response.content}")
+                print(f"Token usage: {validated_response.total_tokens} total, "
+                      f"{validated_response.prompt_tokens} prompt, "
+                      f"{validated_response.completion_tokens} completion")
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
\ No newline at end of file