Commit

add llm connector changes

raspawar committed Feb 7, 2025
1 parent e4a7ecc commit 47aba0d
Showing 7 changed files with 704 additions and 119 deletions.
@@ -120,6 +120,13 @@ def __init__(
)

self._is_hosted = self.base_url in KNOWN_URLS
if not is_hosted:
base_url = self._validate_url(base_url)

if is_hosted and api_key == "NO_API_KEY_PROVIDED":
warnings.warn(
"An API key is required for the hosted NIM. This will become an error in 0.2.0.",
)
if self._is_hosted: # hosted on API Catalog (build.nvidia.com)
if api_key == "NO_API_KEY_PROVIDED":
raise ValueError("An API key is required for hosted NIM.")
@@ -177,8 +184,8 @@ def _validate_url(self, base_url):
"""
validate the base_url.
if the base_url is not a url, raise an error
if the base_url does not end in /v1, e.g. /embeddings, /completions, /rankings,
or /reranking, emit a warning. old documentation told users to pass in the full
if the base_url does not end in /v1, e.g. /embeddings,
emit a warning. old documentation told users to pass in the full
inference url, which is incorrect and prevents model listing from working.
normalize base_url to end in /v1.
"""
@@ -221,7 +228,7 @@ def _validate_model(self, model_name: str) -> None:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
f"Please check `{self.class_name()}.available_models`."
)
if model and model.endpoint:
self.base_url = model.endpoint
@@ -2,61 +2,48 @@
from deprecated import deprecated
import warnings
import json
import os

from llama_index.core.bridge.pydantic import PrivateAttr, BaseModel
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.base.llms.generic_utils import (
get_from_param_or_env,
)

from llama_index.llms.nvidia.utils import (
is_nvidia_function_calling_model,
is_chat_model,
ALL_MODELS,
)

from llama_index.llms.openai_like import OpenAILike
from llama_index.core.llms.function_calling import FunctionCallingLLM
from urllib.parse import urlparse, urlunparse

from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
MessageRole,
)

from llama_index.core.llms.llm import ToolSelection
from .utils import (
BASE_URL,
DEFAULT_MODEL,
Model,
determine_model,
)

if TYPE_CHECKING:
from llama_index.core.tools.types import BaseTool

DEFAULT_MODEL = "meta/llama3-8b-instruct"
BASE_URL = "https://integrate.api.nvidia.com/v1/"

KNOWN_URLS = [
BASE_URL,
"https://integrate.api.nvidia.com/v1",
]


class Model(BaseModel):
id: str
base_model: Optional[str]
is_function_calling_model: Optional[bool] = False
is_chat_model: Optional[bool] = False


class NVIDIA(OpenAILike, FunctionCallingLLM):
"""NVIDIA's API Catalog Connector."""

_is_hosted: bool = PrivateAttr(True)
_mode: str = PrivateAttr(default="nvidia")
_client: Any = PrivateAttr()
_aclient: Any = PrivateAttr()
_is_hosted: bool = PrivateAttr(True)

def __init__(
self,
model: Optional[str] = None,
nvidia_api_key: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = BASE_URL,
base_url: Optional[str] = os.getenv("NVIDIA_BASE_URL", BASE_URL),
max_tokens: Optional[int] = 1024,
**kwargs: Any,
) -> None:
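The new base_url default reads the NVIDIA_BASE_URL environment variable before falling back to the catalog URL. A minimal sketch of relying on that (the local endpoint is illustrative; note the os.getenv default above is evaluated once, when the module is imported):

import os

# Assumption: a self-hosted NIM is listening on this address.
os.environ["NVIDIA_BASE_URL"] = "http://localhost:8000/v1"

# Import after setting the variable, since the default in the signature
# above is captured at import time.
from llama_index.llms.nvidia import NVIDIA

llm = NVIDIA()  # talks to the local endpoint instead of integrate.api.nvidia.com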
@@ -87,11 +74,12 @@ def __init__(
"NO_API_KEY_PROVIDED",
)

is_hosted = base_url in KNOWN_URLS
if base_url not in KNOWN_URLS:
base_url = base_url or BASE_URL
self._is_hosted = base_url == BASE_URL
if not self._is_hosted:
base_url = self._validate_url(base_url)

if is_hosted and api_key == "NO_API_KEY_PROVIDED":
if self._is_hosted and api_key == "NO_API_KEY_PROVIDED":
warnings.warn(
"An API key is required for the hosted NIM. This will become an error in 0.2.0.",
)
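How the key handling above looks from the caller's side; a hedged sketch (the key placeholder and local model name are illustrative, and NVIDIA_API_KEY as the environment variable is an assumption, since the get_from_param_or_env call is collapsed in this view):

# Hosted catalog (build.nvidia.com): supply a key directly or (assumption)
# via NVIDIA_API_KEY; otherwise the warning above fires and becomes an
# error in 0.2.0.
llm = NVIDIA(model="meta/llama3-8b-instruct", api_key="nvapi-...")

# Locally hosted NIM: no key is needed; a non-catalog base_url switches off hosted mode.
local_llm = NVIDIA(base_url="http://localhost:8000/v1", model="my-local-model")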
@@ -100,12 +88,11 @@ def __init__(
api_key=api_key,
api_base=base_url,
max_tokens=max_tokens,
is_chat_model=is_chat_model(model),
is_chat_model=self._is_chat_model(model),
default_headers={"User-Agent": "llama-index-llms-nvidia"},
is_function_calling_model=is_nvidia_function_calling_model(model),
is_function_calling_model=self._is_function_calling_model(model),
**kwargs,
)
self._is_hosted = base_url in KNOWN_URLS

if self._is_hosted and api_key == "NO_API_KEY_PROVIDED":
warnings.warn(
@@ -144,24 +131,35 @@ def __get_default_model(self):

def _validate_url(self, base_url):
"""
Base URL Validation.
ValueError : url which does not have a valid scheme and netloc.
Warning : v1/chat/completions routes.
ValueError : any route other than the above.
validate the base_url.
if the base_url is not a url, raise an error
if the base_url does not end in /v1, e.g. /completions, /chat/completions,
emit a warning. old documentation told users to pass in the full
inference url, which is incorrect and prevents model listing from working.
normalize base_url to end in /v1.
"""
expected_format = "Expected format is 'http://host:port'."
result = urlparse(base_url)
if not (result.scheme and result.netloc):
raise ValueError(f"Invalid base_url, {expected_format}")
if result.path:
normalized_path = result.path.strip("/")
if normalized_path == "v1":
pass
elif normalized_path == "v1/chat/completions":
warnings.warn(f"{expected_format} Rest is Ignored.")
else:
raise ValueError(f"Invalid base_url, {expected_format}")
return urlunparse((result.scheme, result.netloc, "v1", "", "", ""))
if base_url is not None:
parsed = urlparse(base_url)

# Ensure scheme and netloc (domain name) are present
if not (parsed.scheme and parsed.netloc):
expected_format = "Expected format is: http://host:port"
raise ValueError(
f"Invalid base_url format. {expected_format} Got: {base_url}"
)

normalized_path = parsed.path.rstrip("/")
if not normalized_path.endswith("/v1"):
warnings.warn(
f"{base_url} does not end in /v1, you may "
"have inference and listing issues"
)
normalized_path += "/v1"

base_url = urlunparse(
(parsed.scheme, parsed.netloc, normalized_path, None, None, None)
)
return base_url
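To make the effect of the rewritten validation concrete, here is a standalone sketch that mirrors the logic above on a few illustrative URLs (a copy for demonstration, not the connector's own API):

from urllib.parse import urlparse, urlunparse
import warnings

def normalize_base_url(base_url: str) -> str:
    # Mirror of the checks above: require scheme and netloc, warn when the
    # path does not already end in /v1, then append /v1.
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid base_url format. Expected format is: http://host:port Got: {base_url}")
    path = parsed.path.rstrip("/")
    if not path.endswith("/v1"):
        warnings.warn(f"{base_url} does not end in /v1, you may have inference and listing issues")
        path += "/v1"
    return urlunparse((parsed.scheme, parsed.netloc, path, None, None, None))

print(normalize_base_url("http://localhost:8000"))      # http://localhost:8000/v1 (warns)
print(normalize_base_url("http://localhost:8000/v1/"))  # http://localhost:8000/v1
print(normalize_base_url("http://localhost:8000/v1/chat/completions"))  # .../v1/chat/completions/v1 (warns)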

def _validate_model(self, model_name: str) -> None:
"""
@@ -174,29 +172,28 @@ def _validate_model(self, model_name: str) -> None:
ValueError: If the model is incompatible with the client.
"""
if self._is_hosted:
if model_name not in ALL_MODELS:
if model_name in [model.id for model in self.available_models]:
warnings.warn(f"Unable to determine validity of {model_name}")
else:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
model = determine_model(model_name)
if model_name not in [model.id for model in self.available_models]:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models`."
)
if model and model.endpoint:
self.base_url = model.endpoint
else:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(f"No locally hosted {model_name} was found.")

@property
def available_models(self) -> List[Model]:
models = [
Model(
id=model.id,
base_model=getattr(model, "params", {}).get("root", None),
is_function_calling_model=is_nvidia_function_calling_model(model.id),
is_chat_model=is_chat_model(model.id),
)
for model in self._get_client().models.list().data
]
models = []
for element in self._get_client().models.list().data:
if not (model := determine_model(element.id)):
model = Model(
id=element.id,
base_model=getattr(element, "params", {}).get("root", None),
)
models.append(model)
# only exclude models in hosted mode. in non-hosted mode, the administrator has control
# over the model name and may deploy an excluded name that will work.
if self._is_hosted:
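A short usage sketch for the property above (the printed values are illustrative):

llm = NVIDIA(api_key="nvapi-...")   # hosted mode
for m in llm.available_models:
    print(m.id, m.base_model)       # e.g. "meta/llama3-8b-instruct", None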
@@ -248,7 +245,11 @@ def mode(

@property
def _is_chat_model(self) -> bool:
return is_chat_model(self.model)
return model.supports_tools if (model := determine_model(self.model)) else False

@property
def _is_function_calling_model(self) -> bool:
return model.supports_tools if (model := determine_model(self.model)) else False
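Finally, a usage sketch for the function-calling path these flags gate (the tool and prompt are illustrative, not part of this commit):

from llama_index.core.tools import FunctionTool

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

llm = NVIDIA(model="meta/llama3-8b-instruct")
if llm.metadata.is_function_calling_model:
    response = llm.chat_with_tools(
        tools=[FunctionTool.from_defaults(fn=add)],
        user_msg="What is 2 + 3?",
    )
    print(response)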

def _prepare_chat_with_tools(
self,
