Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding API to fetch tokenizer config for model #1052

Merged
merged 21 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
707e96c
Adding default chat_template api
kumar-shivam-ranjan Feb 3, 2025
30592c8
Merge branch 'main' into feature/default-chat-template-api
kumar-shivam-ranjan Feb 3, 2025
3cc7e8a
Adding chat template api
kumar-shivam-ranjan Feb 3, 2025
41a305c
Adding chat template api
kumar-shivam-ranjan Feb 3, 2025
b95aeb6
Adding chat template api
kumar-shivam-ranjan Feb 3, 2025
74601c0
Merge branch 'main' into feature/default-chat-template-api
kumar-shivam-ranjan Feb 4, 2025
8997594
Addressing review comments
kumar-shivam-ranjan Feb 4, 2025
dbed1dd
Formatting
kumar-shivam-ranjan Feb 4, 2025
aa0ee81
Merge branch 'main' into feature/default-chat-template-api
kumar-shivam-ranjan Feb 4, 2025
d6b728e
Updating path
kumar-shivam-ranjan Feb 4, 2025
43f94b0
Merge branch 'feature/default-chat-template-api' of https://github.co…
kumar-shivam-ranjan Feb 4, 2025
035c4ee
Resolving conflicts
kumar-shivam-ranjan Feb 4, 2025
a5cb572
Addressing review comments
kumar-shivam-ranjan Feb 5, 2025
766b201
Merge branch 'main' into feature/default-chat-template-api
kumar-shivam-ranjan Feb 5, 2025
8356379
Fixing tests
kumar-shivam-ranjan Feb 5, 2025
e449f0a
Addressing review comments
kumar-shivam-ranjan Feb 5, 2025
ef2c40e
Addressing review comments
kumar-shivam-ranjan Feb 5, 2025
88fa1cc
Addressing review comments
kumar-shivam-ranjan Feb 5, 2025
c66bdd6
Addressing review comments
kumar-shivam-ranjan Feb 6, 2025
a2f0ebe
Adding model_id in MD response
kumar-shivam-ranjan Feb 6, 2025
be96c70
Fixing UTs
kumar-shivam-ranjan Feb 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import os
import traceback
from dataclasses import fields
from typing import Dict, Union
from typing import Dict, Optional, Union

import oci
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails

from ads import set_auth
from ads.aqua import logger
from ads.aqua.common.enums import Tags
from ads.aqua.common.enums import ConfigFolder, Tags
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import (
_is_valid_mvs,
Expand Down Expand Up @@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
logger.info(f"Artifact not found in model {model_id}.")
return False

def get_config(self, model_id: str, config_file_name: str) -> Dict:
def get_config(
self,
model_id: str,
config_file_name: str,
config_folder: Optional[str] = ConfigFolder.CONFIG,
) -> Dict:
"""Gets the config for the given Aqua model.

Parameters
Expand All @@ -277,12 +282,17 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
The OCID of the Aqua model.
config_file_name: str
name of the config file
config_folder: str, optional
Subfolder in which `config_file_name` is searched.
Defaults to `ConfigFolder.CONFIG`.
When searching inside the model artifact directory, the value is `ConfigFolder.ARTIFACT`.

Returns
-------
Dict:
A dict of allowed configs.
"""
config_folder = config_folder or ConfigFolder.CONFIG
oci_model = self.ds_client.get_model(model_id).data
oci_aqua = (
(
Expand All @@ -304,22 +314,25 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
f"Base model found for the model: {oci_model.id}. "
f"Loading {config_file_name} for base model {base_model_ocid}."
)
base_model = self.ds_client.get_model(base_model_ocid).data
artifact_path = get_artifact_path(base_model.custom_metadata_list)
if config_folder == ConfigFolder.ARTIFACT:
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
else:
base_model = self.ds_client.get_model(base_model_ocid).data
artifact_path = get_artifact_path(base_model.custom_metadata_list)
else:
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
artifact_path = get_artifact_path(oci_model.custom_metadata_list)

if not artifact_path:
logger.debug(
f"Failed to get artifact path from custom metadata for the model: {model_id}"
)
return config

config_path = f"{os.path.dirname(artifact_path)}/config/"
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
if not is_path_exists(config_path):
config_path = f"{artifact_path.rstrip('/')}/config/"

config_path = os.path.join(artifact_path.rstrip("/"), config_folder)
if not is_path_exists(config_path):
config_path = f"{artifact_path.rstrip('/')}/"
config_file_path = f"{config_path}{config_file_name}"
if is_path_exists(config_file_path):
try:
Expand Down
5 changes: 5 additions & 0 deletions ads/aqua/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,8 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):

MODEL_ID = "model-id"
PORT = "port"


class ConfigFolder(ExtendedEnum):
    """Sub-folder names used when looking up a model's config files.

    Passed as the `config_folder` argument of `get_config` to select where
    the requested config file is searched for.
    """

    # Default lookup folder for config files.
    CONFIG = "config"
    # Used when the file has to be searched inside the model artifact directory.
    ARTIFACT = "artifact"
1 change: 1 addition & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
README = "README.md"
LICENSE_TXT = "config/LICENSE.txt"
DEPLOYMENT_CONFIG = "deployment_config.json"
AQUA_MODEL_TOKENIZER_CONFIG = "tokenizer_config.json"
COMPARTMENT_MAPPING_KEY = "service-model-compartment"
CONTAINER_INDEX = "container_index.json"
EVALUATION_REPORT_JSON = "report.json"
Expand Down
19 changes: 19 additions & 0 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import (
get_hf_model_info,
is_valid_ocid,
list_hf_models,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
Expand Down Expand Up @@ -316,8 +317,26 @@ def post(self, *args, **kwargs): # noqa: ARG002
)


class AquaModelTokenizerConfigHandler(AquaAPIhandler):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add the pydocs to the class?

Handles requests for retrieving the Hugging Face tokenizer configuration 
    of a specified model.
    
    Expected request format:
        GET /aqua/models/<model-ocid>/tokenizer

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

def get(self, model_id):
url_parse = urlparse(self.request.path)
paths = url_parse.path.strip("/")
path_list = paths.split("/")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Looks like we only use path_list further in this method, then probably can do

path_list = urlparse(self.request.path).path.strip("/").split("/")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

# Path should be /aqua/model/ocid1.iad.ahdxxx/tokenizer
# path_list=['aqua','model','<model-ocid>','tokenizer']
if (
len(path_list) == 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is 4 and 3 in this code? :) Let's move them either to constants or add some meaningful comments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The request path here is: /aqua/models/ocid1.iad.ahdxxx/tokenizer
path_list=['aqua','models','ocid1.iad.ahdxxx','tokenizer']

Added comments

and is_valid_ocid(path_list[2])
and path_list[3] == "tokenizer"
):
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: looks like else is not required here. I think it would be more clear to do something like this:

 is_valid_path = (
    len(path_list) == 4
    and path_list[3] == "tokenizer"
    and is_valid_ocid(path_list[2])
)

if not is_valid_path:
       raise HTTPError(400, f"The request {self.request.path} is invalid.")

return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

raise HTTPError(400, f"The request {self.request.path} is invalid.")


# Route table for the model endpoints: each entry pairs a URL regex with the
# handler class that serves it. The `([^/]*)` group captures one path segment
# (no slashes) -- presumably forwarded to the handler method as its positional
# argument (e.g. `model_id`); TODO confirm against the extension's router.
__handlers__ = [
    ("model/?([^/]*)", AquaModelHandler),
    ("model/?([^/]*)/license", AquaModelLicenseHandler),
    ("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
    ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
]
22 changes: 22 additions & 0 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
from ads.aqua.app import AquaApp
from ads.aqua.common.enums import (
ConfigFolder,
CustomInferenceContainerTypeFamily,
FineTuningContainerTypeFamily,
InferenceContainerTypeFamily,
Expand Down Expand Up @@ -44,6 +45,7 @@
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
AQUA_MODEL_ARTIFACT_FILE,
AQUA_MODEL_TOKENIZER_CONFIG,
AQUA_MODEL_TYPE_CUSTOM,
HF_METADATA_FOLDER,
LICENSE_TXT,
Expand Down Expand Up @@ -568,6 +570,26 @@ def _build_ft_metrics(
training_final,
]

def get_hf_tokenizer_config(self, model_id: str) -> dict:
    """Gets the Hugging Face tokenizer config for the given Aqua model.

    Loads the file named by `AQUA_MODEL_TOKENIZER_CONFIG` through
    `get_config`, searching the `ConfigFolder.ARTIFACT` folder instead of
    the default config folder.

    Parameters
    ----------
    model_id: str
        The OCID of the Aqua model.

    Returns
    -------
    Dict:
        The tokenizer config exactly as returned by `get_config`. When
        nothing could be loaded, the falsy value from `get_config` is
        returned unchanged (presumably an empty dict -- confirm against
        `get_config`) and a debug message is logged.
    """
    config = self.get_config(
        model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
    )
    # A missing tokenizer config is not treated as an error: it is only
    # logged, and the caller decides how to handle the empty result.
    if not config:
        logger.debug(f"Tokenizer config for model: {model_id} is not available.")
    return config

@staticmethod
def to_aqua_model(
model: Union[
Expand Down
38 changes: 37 additions & 1 deletion tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from huggingface_hub.hf_api import HfApi, ModelInfo
from huggingface_hub.utils import GatedRepoError
from notebook.base.handlers import IPythonHandler
from notebook.base.handlers import IPythonHandler, HTTPError
from parameterized import parameterized

from ads.aqua.common.errors import AquaRuntimeError
Expand All @@ -18,6 +18,7 @@
AquaHuggingFaceHandler,
AquaModelHandler,
AquaModelLicenseHandler,
AquaModelTokenizerConfigHandler,
)
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
Expand Down Expand Up @@ -250,6 +251,41 @@ def test_get(self, mock_load_license):
mock_load_license.assert_called_with("test_model_id")


class ModelTokenizerConfigHandlerTestCase(TestCase):
    """Unit tests for `AquaModelTokenizerConfigHandler.get`."""

    @patch.object(IPythonHandler, "__init__")
    def setUp(self, ipython_init_mock) -> None:
        """Build a handler from mocks, with `finish` and `request` replaced."""
        # IPythonHandler.__init__ is patched out so the handler can be
        # constructed from plain MagicMock arguments.
        ipython_init_mock.return_value = None
        self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler(
            MagicMock(), MagicMock()
        )
        self.model_tokenizer_config_handler.finish = MagicMock()
        self.model_tokenizer_config_handler.request = MagicMock()

    @patch.object(AquaModelApp, "get_hf_tokenizer_config")
    @patch("ads.aqua.extension.model_handler.urlparse")
    def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config):
        """A tokenizer request path forwards the app result to `finish`."""
        # `urlparse` is mocked, so the handler only ever sees this `.path`.
        request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer")
        mock_urlparse.return_value = request_path
        self.model_tokenizer_config_handler.get(model_id="test_model_id")
        self.model_tokenizer_config_handler.finish.assert_called_with(
            mock_get_hf_tokenizer_config.return_value
        )
        mock_get_hf_tokenizer_config.assert_called_with("test_model_id")

    @patch.object(AquaModelApp, "get_hf_tokenizer_config")
    @patch("ads.aqua.extension.model_handler.urlparse")
    def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config):
        """Test invalid request path should raise HTTPError(400)"""
        request_path = MagicMock(path="/invalid/path")
        mock_urlparse.return_value = request_path

        with self.assertRaises(HTTPError) as context:
            self.model_tokenizer_config_handler.get(model_id="test_model_id")
        # These checks sit after the `with` block so they run once the
        # expected HTTPError has been captured in `context`.
        self.assertEqual(context.exception.status_code, 400)
        self.model_tokenizer_config_handler.finish.assert_not_called()
        mock_get_hf_tokenizer_config.assert_not_called()


class TestAquaHuggingFaceHandler:
def setup_method(self):
with patch.object(IPythonHandler, "__init__"):
Expand Down