Skip to content

Commit

Permalink
✨ mistralai
Browse files Browse the repository at this point in the history
  • Loading branch information
juftin committed Jan 7, 2024
1 parent 92e0161 commit 6f1fdc9
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 48 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ pipx install "llm-term[anthropic]"
llm-term --provider anthropic
```

### MistralAI

[MistralAI](https://mistral.ai/) is a European LLM provider. You can request
access to the MistralAI [here](https://console.mistral.ai/). The default model is
`mistral-small`, and you can use the `MISTRAL_API_KEY` environment variable.

```shell
pipx install "llm-term[mistralai]"
```

```shell
llm-term --provider mistralai
```

### GPT4All

GPT4All is a an open source LLM provider. These models run locally on your
Expand Down
14 changes: 14 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ pipx install "llm-term[anthropic]"
llm-term --provider anthropic
```

### MistralAI

[MistralAI](https://mistral.ai/) is a European LLM provider. You can request
access to the MistralAI [here](https://console.mistral.ai/). The default model is
`mistral-small`, and you can use the `MISTRAL_API_KEY` environment variable.

```shell
pipx install "llm-term[mistralai]"
```

```shell
llm-term --provider mistralai
```

### GPT4All

GPT4All is a an open source LLM provider. These models run locally on your
Expand Down
15 changes: 2 additions & 13 deletions llm_term/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from rich.console import Console

from llm_term.__about__ import __application__, __version__
from llm_term.utils import (
chat_session,
get_llm,
print_header,
setup_system_message,
)
from llm_term.utils import chat_session, get_llm, print_header, providers, setup_system_message

rich.traceback.install(show_locals=True)

Expand All @@ -37,13 +32,7 @@
envvar="LLM_PROVIDER",
show_envvar=True,
default="openai",
type=click.Choice(
[
"openai",
"anthropic",
"gpt4all",
]
),
type=click.Choice(providers),
)
@click.argument(
"chat",
Expand Down
28 changes: 26 additions & 2 deletions llm_term/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from typing import Iterator

from click.exceptions import ClickException
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.llms import GPT4All
from langchain.llms.base import BaseLLM
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -46,19 +44,45 @@ def print_header(console: Console, model: str) -> None:
console.print("")


providers: list[str] = [
"openai",
"anthropic",
"gpt4all",
"mistralai",
]


def get_llm(provider: str, api_key: str, model: str | None) -> tuple[BaseChatModel | BaseLLM, str]:
"""
Check the credentials
"""
if provider == "openai":
from langchain_openai import ChatOpenAI

chat_model = model or "gpt-3.5-turbo"
return ChatOpenAI(openai_api_key=api_key, model_name=chat_model), chat_model
elif provider == "anthropic":
from langchain_community.chat_models import ChatAnthropic

chat_model = model or "claude-2.1"
return ChatAnthropic(anthropic_api_key=api_key, model_name=chat_model), chat_model
elif provider == "gpt4all":
from langchain_community.llms import GPT4All

chat_model = model or "mistral-7b-openorca.Q4_0.gguf"
return GPT4All(model=chat_model, allow_download=True), chat_model
elif provider == "mistralai":
try:
from langchain_mistralai import ChatMistralAI

chat_model = model or "mistral-small"
return ChatMistralAI(mistral_api_key=api_key, model=chat_model), chat_model
except ImportError as ie:
msg = (
"The `mistralai` provider requires the `mistralai` extra to be installed: "
'pipx install "llm-term[mistralai]"'
)
raise ClickException(msg) from ie
else:
msg = f"Provider {provider} is not supported... yet"
raise ClickException(msg)
Expand Down
19 changes: 6 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ classifiers = [
]
dependencies = [
"click~=8.1.7",
"openai~=1.3.6",
"rich~=13.5.3",
"prompt-toolkit~=3.0.39",
"langchain~=0.0.343",
"langchain~=0.1.0",
"langchain-community~=0.0.9",
"langchain-openai~=0.0.2",
"numpy~=1.24.4; python_version < '3.9'",
"numpy~=1.26.1; python_version > '3.8'"
]
Expand All @@ -36,9 +37,10 @@ readme = "README.md"
requires-python = ">=3.8,<4"

[project.optional-dependencies]
all = ["anthropic~=0.7.7", "gpt4all~=2.0.2"]
all = ["anthropic~=0.7.7", "gpt4all~=2.0.2", "langchain-mistralai~=0.0.1"]
anthropic = ["anthropic~=0.7.7"]
gpt4all = ["gpt4all~=2.0.2"]
mistralai = ["langchain-mistralai~=0.0.1"]

[project.scripts]
llm-term = "llm_term.cli:cli"
Expand Down Expand Up @@ -154,17 +156,8 @@ typing = "mypy --install-types --non-interactive {args:llm_term tests}"
[tool.hatch.version]
path = "llm_term/__about__.py"

[[tool.mypy.overrides]]
[tool.mypy]
ignore_missing_imports = true
module = [
"pytest.*",
"rich.*",
"openai.*",
"prompt_toolkit.*",
"click.*",
"langchain.*",
"langchain_core.*"
]

[tool.ruff]
ignore = [
Expand Down
Loading

0 comments on commit 6f1fdc9

Please sign in to comment.