From 73bf3edf094ab963af69aa07c01c4932617562d7 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sat, 11 Jan 2025 22:30:03 -0800 Subject: [PATCH 01/22] fix: standardize inference with vllm --- .../text-generation/inference_api.py | 147 +++++------------- .../tests/test_inference_api.py | 43 +---- 2 files changed, 46 insertions(+), 144 deletions(-) diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py index 23de55487..d5781a88d 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -116,14 +116,13 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]: model_args = asdict(args) model_args["local_files_only"] = not model_args.pop('allow_remote_files') -model_pipeline = model_args.pop('pipeline') combination_type = model_args.pop('combination_type') app = FastAPI() -resovled_chat_template = load_chat_template(model_args.pop('chat_template')) +resolved_chat_template = load_chat_template(model_args.pop('chat_template')) tokenizer = AutoTokenizer.from_pretrained(**model_args) -if resovled_chat_template is not None: - tokenizer.chat_template = resovled_chat_template +if resolved_chat_template is not None: + tokenizer.chat_template = resolved_chat_template base_model = AutoModelForCausalLM.from_pretrained(**model_args) if not os.path.exists(ADAPTERS_DIR): @@ -177,12 +176,16 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]: if args.torch_dtype: pipeline_kwargs["torch_dtype"] = args.torch_dtype -pipeline = transformers.pipeline( - task="text-generation", - model=model, - tokenizer=tokenizer, - **pipeline_kwargs -) +try: + pipeline = transformers.pipeline( + task="text-generation", + model=model, + tokenizer=tokenizer, + **pipeline_kwargs + ) +except Exception as e: + logger.critical(f"Failed to initialize the pipeline: {e}") + raise RuntimeError("Pipeline initialization failed. Check logs for details.") try: # Attempt to load the generation configuration @@ -240,54 +243,42 @@ class HealthStatus(BaseModel): def health_check(): if not model: logger.error("Model not initialized") - raise HTTPException(status_code=500, detail="Model not initialized") + raise HTTPException(status_code=500, detail="Model loading failed. Check configuration or weights.") if not pipeline: logger.error("Pipeline not initialized") - raise HTTPException(status_code=500, detail="Pipeline not initialized") + raise HTTPException(status_code=500, detail="Pipeline setup failed. Check transformer configurations.") return {"status": "Healthy"} -class GenerateKwargs(BaseModel): - max_length: int = 200 # Length of input prompt+max_new_tokens +class UnifiedRequestModel(BaseModel): + # Fields for text generation + prompt: Optional[str] = Field(None, description="Prompt for text generation. 
Required for text-generation pipeline.") + return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") + clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") + prefix: Optional[str] = Field(None, description="Prefix added to prompt") + handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") + + # Common Generation parameters (formerly in generate_kwargs) + max_length: int = 200 # Length of input prompt + max_new_tokens min_length: int = 0 do_sample: bool = True - early_stopping: bool = False num_beams: int = 1 temperature: float = 1.0 top_k: int = 10 top_p: float = 1 - typical_p: float = 1 - repetition_penalty: float = 1 - pad_token_id: Optional[int] = tokenizer.pad_token_id - eos_token_id: Optional[int] = tokenizer.eos_token_id + class Config: - extra = 'allow' # Allows for additional fields not explicitly defined + extra = "allow" # Allows for additional, unspecified fields json_schema_extra = { "example": { + "prompt": "Tell me a joke", "max_length": 200, "temperature": 0.7, "top_p": 0.9, + "return_full_text": True, + "clean_up_tokenization_spaces": False, "additional_param": "Example value" } } - -class Message(BaseModel): - role: str - content: str - -class UnifiedRequestModel(BaseModel): - # Fields for text generation - prompt: Optional[str] = Field(None, description="Prompt for text generation. Required for text-generation pipeline. Do not use with 'messages'.") - return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") - clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") - prefix: Optional[str] = Field(None, description="Prefix added to prompt") - handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") - generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method") - - # Field for conversational model - messages: Optional[List[Message]] = Field(None, description="Messages for conversational model. Required for conversational pipeline. Do not use with 'prompt'.") - def messages_to_dict_list(self): - return [message.dict() for message in self.messages] if self.messages else [] - class ErrorResponse(BaseModel): detail: str @@ -305,12 +296,6 @@ class ErrorResponse(BaseModel): "value": { "Result": "Generated text based on the prompt." } - }, - "conversation": { - "summary": "Conversation Response", - "value": { - "Result": "Response to the last message in the conversation." - } } } } @@ -325,10 +310,6 @@ class ErrorResponse(BaseModel): "missing_prompt": { "summary": "Missing Prompt", "value": {"detail": "Text generation parameter prompt required"} - }, - "missing_messages": { - "summary": "Missing Messages", - "value": {"detail": "Conversational parameter messages required"} } } } @@ -354,23 +335,7 @@ def generate_text( "clean_up_tokenization_spaces": False, "prefix": None, "handle_long_generation": None, - "generate_kwargs": GenerateKwargs().dict(), - }, - }, - "conversation_example": { - "summary": "Conversation Example", - "description": "An example of a conversational request.", - "value": { - "messages": [ - {"role": "user", "content": "What is your favourite condiment?"}, - {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. 
It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, - {"role": "user", "content": "Do you have mayonnaise recipes?"} - ], - "return_full_text": True, - "clean_up_tokenization_spaces": False, - "prefix": None, - "handle_long_generation": None, - "generate_kwargs": GenerateKwargs().dict(), + "temperature": 1.0, }, }, }, @@ -378,49 +343,21 @@ def generate_text( ], ): """ - Processes chat requests, generating text based on the specified pipeline (text generation or conversational). - Validates required parameters based on the pipeline and returns the generated text. + Processes chat requests, generating text based on the specified pipeline. """ - user_generate_kwargs = request_model.generate_kwargs.dict() if request_model.generate_kwargs else {} - generate_kwargs = {**default_generate_config, **user_generate_kwargs} - - if args.pipeline == "text-generation": - if not request_model.prompt: - logger.error("Text generation parameter prompt required") - raise HTTPException(status_code=400, detail="Text generation parameter prompt required") - sequences = pipeline( - request_model.prompt, - # return_tensors=request_model.return_tensors, - # return_text=request_model.return_text, - return_full_text=request_model.return_full_text, - clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, - prefix=request_model.prefix, - handle_long_generation=request_model.handle_long_generation, - **generate_kwargs - ) + user_generate_kwargs = request_model.dict(exclude_unset=True) - result = "" - for seq in sequences: - logger.debug(f"Result: {seq['generated_text']}") - result += seq['generated_text'] + if not request_model.prompt: + logger.error("Text generation parameter prompt required") + raise HTTPException(status_code=400, detail="Text generation parameter prompt required") + try: + sequences = pipeline(request_model.prompt, **user_generate_kwargs) + result = "".join(seq["generated_text"] for seq in sequences) return {"Result": result} - - elif args.pipeline == "conversational": - if not request_model.messages: - logger.error("Conversational parameter messages required") - raise HTTPException(status_code=400, detail="Conversational parameter messages required") - - response = pipeline( - request_model.messages_to_dict_list(), - clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, - **generate_kwargs - ) - return {"Result": str(response[-1])} - - else: - logger.error("Invalid pipeline type") - raise HTTPException(status_code=400, detail="Invalid pipeline type") + except Exception as e: + logger.error(f"Error during text generation: {e}") + raise HTTPException(status_code=500, detail="Error during text generation") class MemoryInfo(BaseModel): used: str diff --git a/presets/workspace/inference/text-generation/tests/test_inference_api.py b/presets/workspace/inference/text-generation/tests/test_inference_api.py index 480f1b480..9b27be9e8 100644 --- a/presets/workspace/inference/text-generation/tests/test_inference_api.py +++ b/presets/workspace/inference/text-generation/tests/test_inference_api.py @@ -18,7 +18,6 @@ @pytest.fixture(params=[ {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, - {"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, ]) def configured_app(request): original_argv = sys.argv.copy() @@ -43,37 +42,6 @@ def configured_app(request): sys.argv = original_argv -def test_conversational(configured_app): - if 
configured_app.test_config['pipeline'] != 'conversational': - pytest.skip("Skipping non-conversational tests") - client = TestClient(configured_app) - messages = [ - {"role": "user", "content": "What is your favourite condiment?"}, - {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, - {"role": "user", "content": "Do you have mayonnaise recipes?"} - ] - request_data = { - "messages": messages, - "generate_kwargs": {"max_new_tokens": 20, "do_sample": True} - } - response = client.post("/chat", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert "Result" in data - assert len(data["Result"]) > 0 # Check if the conversation result is not empty - -def test_missing_messages_for_conversation(configured_app): - if configured_app.test_config['pipeline'] != 'conversational': - pytest.skip("Skipping non-conversational tests") - client = TestClient(configured_app) - request_data = { - # "messages" is missing for conversational pipeline - } - response = client.post("/chat", json=request_data) - assert response.status_code == 400 # Expecting a Bad Request response due to missing messages - assert "Conversational parameter messages required" in response.json().get("detail", "") - def test_text_generation(configured_app): if configured_app.test_config['pipeline'] != 'text-generation': pytest.skip("Skipping non-text-generation tests") @@ -82,7 +50,8 @@ def test_text_generation(configured_app): "prompt": "Hello, world!", "return_full_text": True, "clean_up_tokenization_spaces": False, - "generate_kwargs": {"max_length": 50, "min_length": 10} # Example generate_kwargs + "max_length": 50, + "temperature": 0.7 } response = client.post("/chat", json=request_data) assert response.status_code == 200 @@ -98,7 +67,7 @@ def test_missing_prompt(configured_app): # "prompt" is missing "return_full_text": True, "clean_up_tokenization_spaces": False, - "generate_kwargs": {"max_length": 50} + "max_length": 50 } response = client.post("/chat", json=request_data) assert response.status_code == 400 # Expecting a Bad Request response due to missing prompt @@ -195,7 +164,6 @@ def test_default_generation_params(configured_app): "prompt": "Test default params", "return_full_text": True, "clean_up_tokenization_spaces": False - # Note: generate_kwargs is not provided, so defaults should be used } with patch('inference_api.pipeline') as mock_pipeline: @@ -213,13 +181,10 @@ def test_default_generation_params(configured_app): assert kwargs['max_length'] == 200 assert kwargs['min_length'] == 0 assert kwargs['do_sample'] is True + assert kwargs['num_beams'] == 1 assert kwargs['temperature'] == 1.0 assert kwargs['top_k'] == 10 assert kwargs['top_p'] == 1 - assert kwargs['typical_p'] == 1 - assert kwargs['repetition_penalty'] == 1 - assert kwargs['num_beams'] == 1 - assert kwargs['early_stopping'] is False def test_generation_with_max_length(configured_app): if configured_app.test_config['pipeline'] != 'text-generation': From 0425a775f562d1e9638d74a6d5a5edbb89c3d996 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sat, 11 Jan 2025 22:34:20 -0800 Subject: [PATCH 02/22] fix: standardize inference with vllm --- .../inference/text-generation/inference_api.py | 4 ++-- .../text-generation/tests/test_inference_api.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/presets/workspace/inference/text-generation/inference_api.py 
b/presets/workspace/inference/text-generation/inference_api.py index d5781a88d..51e8bcf19 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -283,7 +283,7 @@ class ErrorResponse(BaseModel): detail: str @app.post( - "/chat", + "/v1/completions", summary="Chat Endpoint", responses={ 200: { @@ -381,7 +381,7 @@ class MetricsResponse(BaseModel): cpu_info: Optional[CPUInfo] = None @app.get( - "/metrics", + "/v1/metrics", response_model=MetricsResponse, summary="Metrics Endpoint", responses={ diff --git a/presets/workspace/inference/text-generation/tests/test_inference_api.py b/presets/workspace/inference/text-generation/tests/test_inference_api.py index 9b27be9e8..cb4ca7c43 100644 --- a/presets/workspace/inference/text-generation/tests/test_inference_api.py +++ b/presets/workspace/inference/text-generation/tests/test_inference_api.py @@ -53,7 +53,7 @@ def test_text_generation(configured_app): "max_length": 50, "temperature": 0.7 } - response = client.post("/chat", json=request_data) + response = client.post("/v1/completions", json=request_data) assert response.status_code == 200 data = response.json() assert "Result" in data @@ -69,7 +69,7 @@ def test_missing_prompt(configured_app): "clean_up_tokenization_spaces": False, "max_length": 50 } - response = client.post("/chat", json=request_data) + response = client.post("/v1/completions", json=request_data) assert response.status_code == 400 # Expecting a Bad Request response due to missing prompt assert "Text generation parameter prompt required" in response.json().get("detail", "") @@ -87,7 +87,7 @@ def test_health_check(configured_app): def test_get_metrics(configured_app): client = TestClient(configured_app) - response = client.get("/metrics") + response = client.get("/v1/metrics") assert response.status_code == 200 assert "gpu_info" in response.json() @@ -117,7 +117,7 @@ def __init__(self, id, name, load, temperature, memoryUsed, memoryTotal): # Mock GPUtil.getGPUs to return a list containing the mock GPU object with patch('torch.cuda.is_available', return_value=True), \ patch('GPUtil.getGPUs', return_value=[mock_gpu]): - response = client.get("/metrics") + response = client.get("/v1/metrics") assert response.status_code == 200 data = response.json() @@ -143,7 +143,7 @@ def test_get_metrics_no_gpus(configured_app): patch('psutil.virtual_memory') as mock_virtual_memory: mock_virtual_memory.return_value.used = 4 * (1024 ** 3) # 4 GB mock_virtual_memory.return_value.total = 16 * (1024 ** 3) # 16 GB - response = client.get("/metrics") + response = client.get("/v1/metrics") assert response.status_code == 200 data = response.json() assert data["gpu_info"] is None # No GPUs available @@ -169,7 +169,7 @@ def test_default_generation_params(configured_app): with patch('inference_api.pipeline') as mock_pipeline: mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function - response = client.post("/chat", json=request_data) + response = client.post("/v1/completions", json=request_data) assert response.status_code == 200 data = response.json() @@ -202,7 +202,7 @@ def test_generation_with_max_length(configured_app): "generate_kwargs": {"max_length": max_length} } - response = client.post("/chat", json=request_data) + response = client.post("/v1/completions", json=request_data) assert response.status_code == 200 data = response.json() @@ -237,7 +237,7 @@ def test_generation_with_min_length(configured_app): 
"generate_kwargs": {"min_length": min_length, "max_length": max_length} } - response = client.post("/chat", json=request_data) + response = client.post("/v1/completions", json=request_data) assert response.status_code == 200 data = response.json() From ab3db034902125fdf2a1eee359395641f5cda1d9 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sat, 11 Jan 2025 22:58:06 -0800 Subject: [PATCH 03/22] feat: Updates models format --- presets/ragengine/inference/inference.py | 87 +++++++++++++++++++----- presets/ragengine/main.py | 14 ++-- presets/ragengine/models.py | 32 +++++++-- 3 files changed, 108 insertions(+), 25 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index f48248463..54a7c75cf 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -2,17 +2,20 @@ # Licensed under the MIT license. from typing import Any +from dataclasses import field from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen from llama_index.llms.openai import OpenAI from llama_index.core.llms.callbacks import llm_completion_callback import requests +from urllib.parse import urlparse, urljoin from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD OPENAI_URL_PREFIX = "https://api.openai.com" HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co" class Inference(CustomLLM): - params: dict = {} + params: dict = field(default_factory=dict) + _default_model: str = None def set_params(self, params: dict) -> None: self.params = params @@ -38,30 +41,80 @@ def complete(self, prompt: str, **kwargs) -> CompletionResponse: self.params = {} def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - llm = OpenAI( - api_key=LLM_ACCESS_SECRET, - **kwargs # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc. 
- ) - return llm.complete(prompt) + return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt) def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} - data = {"messages": [{"role": "user", "content": prompt}]} - response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) - response_data = response.json() - return CompletionResponse(text=str(response_data)) + return self._post_request( + {"messages": [{"role": "user", "content": prompt}]}, + headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + ) def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + model = kwargs.pop("model", self._get_default_model()) data = {"prompt": prompt, **kwargs} + if model: + data["model"] = model # Include the model only if it is not None + + # For Debugging Purposes + # import json + # # Construct curl command + # curl_command = ( + # f"curl -X POST {LLM_INFERENCE_URL} " + # + " ".join([f'-H "{key}: {value}"' for key, value in {"Authorization": f"Bearer {LLM_ACCESS_SECRET}", "Content-Type": "application/json"}.items()]) + # + f" -d '{json.dumps(data)}'" + # ) + # print("Equivalent curl command:") + # print(curl_command) + + return self._post_request( + data, + headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", "Content-Type": "application/json"} + ) - response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) - response_data = response.json() + def _get_models_endpoint(self) -> str: + """ + Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL. + """ + parsed = urlparse(LLM_INFERENCE_URL) + return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models") - # Dynamically extract the field from the response based on the specified response_field - # completion_text = response_data.get(RESPONSE_FIELD, "No response field found") # not necessary for now - return CompletionResponse(text=str(response_data)) + def _fetch_default_model(self) -> str: + """ + Fetch the default model from the /v1/models endpoint. + """ + try: + models_url = self._get_models_endpoint() + headers = { + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" + } + + response = requests.get(models_url, headers=headers) + response.raise_for_status() # Raise an exception for HTTP errors + + models = response.json().get("data", []) + return models[0].get("id") if models else None + except Exception as e: + print(f"Error fetching default model from {models_url}: {e}") + return None + def _get_default_model(self) -> str: + """ + Returns the cached default model if available; otherwise fetches and caches it. 
+ """ + if not self._default_model: + self._default_model = self._fetch_default_model() + return self._default_model + + def _post_request(self, data: dict, headers: dict) -> CompletionResponse: + try: + response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) + response.raise_for_status() # Raise exception for HTTP errors + response_data = response.json() + return CompletionResponse(text=str(response_data)) + except requests.RequestException as e: + print(f"Error during POST request to {LLM_INFERENCE_URL}: {e}") + raise @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" diff --git a/presets/ragengine/main.py b/presets/ragengine/main.py index 56f891178..fd2d26efc 100644 --- a/presets/ragengine/main.py +++ b/presets/ragengine/main.py @@ -60,11 +60,17 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh @app.post("/query", response_model=QueryResponse) async def query_index(request: QueryRequest): try: - llm_params = request.llm_params or {} # Default to empty dict if no params provided - rerank_params = request.rerank_params or {} # Default to empty dict if no params provided - return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params) + llm_params = request.llm_params or {} # Default to empty dict if no params provided + rerank_params = request.rerank_params or {} # Default to empty dict if no params provided + return rag_ops.query( + request.index_name, request.query, request.top_k, llm_params, rerank_params + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) # Validation issue except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {str(e)}" + ) @app.get("/indexed-documents", response_model=ListDocumentsResponse) async def list_all_indexed_documents(): diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index a1b2ff529..d9e7d1f60 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, root_validator, ValidationError class Document(BaseModel): text: str @@ -22,8 +22,32 @@ class QueryRequest(BaseModel): index_name: str query: str top_k: int = 10 - llm_params: Optional[Dict] = None # Accept a dictionary for parameters - rerank_params: Optional[Dict] = None # Accept a dictionary for parameters + # Accept a dictionary for our LLM parameters + llm_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Optional parameters for the language model, e.g., temperature, top_p", + ) + # Accept a dictionary for rerank parameters + rerank_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Optional parameters for reranking, e.g., top_n, batch_size", + ) + + @root_validator(pre=True) + def validate_params(cls, values): + llm_params = values.get("llm_params", {}) + rerank_params = values.get("rerank_params", {}) + + # Validate LLM parameters + if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0): + raise ValueError("Temperature must be between 0.0 and 1.0.") + + # Validate rerank parameters + top_k = values["top_k"] + if "top_n" in rerank_params and rerank_params["top_n"] > top_k: + raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.") + + return values class ListDocumentsResponse(BaseModel): documents: Dict[str, Dict[str, Dict[str, str]]] From 4ebb5a3d331e16551e791c4809fd69614cb578aa Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 02:03:51 -0800 Subject: [PATCH 04/22] fix: nit param --- presets/ragengine/models.py | 9 +++++---- .../workspace/inference/text-generation/inference_api.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index d9e7d1f60..5b18e6175 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, root_validator, ValidationError +from pydantic import BaseModel, Field, model_validator + class Document(BaseModel): text: str @@ -33,8 +34,8 @@ class QueryRequest(BaseModel): description="Optional parameters for reranking, e.g., top_n, batch_size", ) - @root_validator(pre=True) - def validate_params(cls, values): + @model_validator(mode="before") + def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: llm_params = values.get("llm_params", {}) rerank_params = values.get("rerank_params", {}) @@ -43,7 +44,7 @@ def validate_params(cls, values): raise ValueError("Temperature must be between 0.0 and 1.0.") # Validate rerank parameters - top_k = values["top_k"] + top_k = values.get("top_k") if "top_n" in rerank_params and rerank_params["top_n"] > top_k: raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.") diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py index 51e8bcf19..11b690e0d 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -36,7 +36,6 @@ class ModelConfig: """ Transformers Model Configuration Parameters """ - pipeline: Optional[str] = field(default="text-generation", metadata={"help": "The model pipeline for the pre-trained 
model"}) pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"}) combination_type: Optional[str]=field(default="svd", metadata={"help": "The combination type of multi adapters"}) state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"}) From 8e6acf98d256541bbf2afdcbd4716f409f71d84a Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 02:15:31 -0800 Subject: [PATCH 05/22] fix: nit param --- .../inference/text-generation/tests/test_inference_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/presets/workspace/inference/text-generation/tests/test_inference_api.py b/presets/workspace/inference/text-generation/tests/test_inference_api.py index cb4ca7c43..f1f033530 100644 --- a/presets/workspace/inference/text-generation/tests/test_inference_api.py +++ b/presets/workspace/inference/text-generation/tests/test_inference_api.py @@ -24,7 +24,6 @@ def configured_app(request): # Use request.param to set correct test arguments for each configuration test_args = [ 'program_name', - '--pipeline', request.param['pipeline'], '--pretrained_model_name_or_path', request.param['model_path'], '--device_map', request.param['device'], '--allow_remote_files', 'True', From 54c1165cd04ef5b79fe911f95c34a572d4e799fb Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 02:23:30 -0800 Subject: [PATCH 06/22] fix: nit param --- presets/ragengine/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index 5b18e6175..a1e89a21f 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -42,7 +42,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: # Validate LLM parameters if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0): raise ValueError("Temperature must be between 0.0 and 1.0.") - + # TODO: More LLM Param Validations here # Validate rerank parameters top_k = values.get("top_k") if "top_n" in rerank_params and rerank_params["top_n"] > top_k: From 77490d0f69ce8305bf07703a508f0d109dce44cd Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 03:06:21 -0800 Subject: [PATCH 07/22] fix: Test cases --- .../text-generation/inference_api.py | 10 +++---- .../tests/test_inference_api.py | 29 ++++++++----------- .../tests/test_model_config.py | 14 ++------- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py index 11b690e0d..9b4cdf4a2 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -86,10 +86,6 @@ def __post_init__(self): # validate parameters else: self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None - supported_pipelines = {"conversational", "text-generation"} - if self.pipeline not in supported_pipelines: - raise ValueError(f"Unsupported pipeline: {self.pipeline}") - def load_chat_template(chat_template: Optional[str]) -> Optional[str]: logger.info(chat_template) if chat_template is None: @@ -342,11 +338,13 @@ def generate_text( ], ): """ - Processes chat requests, generating text based on the specified pipeline. 
+ Processes chat requests, generating text based on the specified text-generation pipeline. """ user_generate_kwargs = request_model.dict(exclude_unset=True) - if not request_model.prompt: + # Extract the prompt separately and remove it from model kwargs + prompt = user_generate_kwargs.pop("prompt", None) + if not prompt: logger.error("Text generation parameter prompt required") raise HTTPException(status_code=400, detail="Text generation parameter prompt required") diff --git a/presets/workspace/inference/text-generation/tests/test_inference_api.py b/presets/workspace/inference/text-generation/tests/test_inference_api.py index f1f033530..cf1eb91c5 100644 --- a/presets/workspace/inference/text-generation/tests/test_inference_api.py +++ b/presets/workspace/inference/text-generation/tests/test_inference_api.py @@ -17,7 +17,7 @@ "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}") @pytest.fixture(params=[ - {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, + {"model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, ]) def configured_app(request): original_argv = sys.argv.copy() @@ -42,8 +42,6 @@ def configured_app(request): sys.argv = original_argv def test_text_generation(configured_app): - if configured_app.test_config['pipeline'] != 'text-generation': - pytest.skip("Skipping non-text-generation tests") client = TestClient(configured_app) request_data = { "prompt": "Hello, world!", @@ -59,8 +57,6 @@ def test_text_generation(configured_app): assert len(data["Result"]) > 0 # Check if the result text is not empty def test_missing_prompt(configured_app): - if configured_app.test_config['pipeline'] != 'text-generation': - pytest.skip("Skipping non-text-generation tests") client = TestClient(configured_app) request_data = { # "prompt" is missing @@ -154,15 +150,19 @@ def test_get_metrics_no_gpus(configured_app): assert data["cpu_info"]["memory"]["total"] == "16.00 GB" def test_default_generation_params(configured_app): - if configured_app.test_config['pipeline'] != 'text-generation': - pytest.skip("Skipping non-text-generation tests") - client = TestClient(configured_app) request_data = { "prompt": "Test default params", "return_full_text": True, - "clean_up_tokenization_spaces": False + "clean_up_tokenization_spaces": False, + "max_length": 200, + "min_length": 0, + "do_sample": True, + "num_beams": 1, + "temperature": 1.0, + "top_k": 10, + "top_p": 1, } with patch('inference_api.pipeline') as mock_pipeline: @@ -186,9 +186,6 @@ def test_default_generation_params(configured_app): assert kwargs['top_p'] == 1 def test_generation_with_max_length(configured_app): - if configured_app.test_config['pipeline'] != 'text-generation': - pytest.skip("Skipping non-text-generation tests") - client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." 
avg_res_len = 15 @@ -198,7 +195,7 @@ def test_generation_with_max_length(configured_app): "prompt": prompt, "return_full_text": True, "clean_up_tokenization_spaces": False, - "generate_kwargs": {"max_length": max_length} + "max_length": max_length, } response = client.post("/v1/completions", json=request_data) @@ -221,9 +218,6 @@ def test_generation_with_max_length(configured_app): assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length" def test_generation_with_min_length(configured_app): - if configured_app.test_config['pipeline'] != 'text-generation': - pytest.skip("Skipping non-text-generation tests") - client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." min_length = 30 @@ -233,7 +227,8 @@ def test_generation_with_min_length(configured_app): "prompt": prompt, "return_full_text": True, "clean_up_tokenization_spaces": False, - "generate_kwargs": {"min_length": min_length, "max_length": max_length} + "min_length": min_length, + "max_length": max_length, } response = client.post("/v1/completions", json=request_data) diff --git a/presets/workspace/inference/text-generation/tests/test_model_config.py b/presets/workspace/inference/text-generation/tests/test_model_config.py index df5b98e8d..718890e9c 100644 --- a/presets/workspace/inference/text-generation/tests/test_model_config.py +++ b/presets/workspace/inference/text-generation/tests/test_model_config.py @@ -10,15 +10,13 @@ sys.path.append(parent_dir) @pytest.fixture(params=[ - {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21"}, - {"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21"}, + {"model_path": "stanford-crfm/alias-gpt2-small-x21"}, ]) def configured_model_config(request): original_argv = sys.argv.copy() sys.argv = [ 'program_name', - '--pipeline', request.param['pipeline'], '--pretrained_model_name_or_path', request.param['model_path'], '--allow_remote_files', 'True' ] @@ -29,7 +27,6 @@ def configured_model_config(request): # Create and configure the ModelConfig instance model_config = ModelConfig( - pipeline=request.param['pipeline'], pretrained_model_name_or_path=request.param['model_path'], ) @@ -73,16 +70,9 @@ def test_ignore_double_dash_arguments(configured_model_config): assert getattr(config, "new_arg2", None) is True assert getattr(config, "new_arg3", None) == "correct_value" -# Test case to verify handling unsupported pipeline values -def test_unsupported_pipeline_raises_value_error(configured_model_config): - with pytest.raises(ValueError) as excinfo: - from inference_api import ModelConfig - ModelConfig(pipeline="unsupported_pipeline") - assert "Unsupported pipeline" in str(excinfo.value) - # Test case for validating torch_dtype def test_invalid_torch_dtype_raises_value_error(configured_model_config): with pytest.raises(ValueError) as excinfo: from inference_api import ModelConfig - ModelConfig(pipeline="text-generation", torch_dtype="unsupported_dtype") + ModelConfig(torch_dtype="unsupported_dtype") assert "Invalid torch dtype" in str(excinfo.value) \ No newline at end of file From b10cfc5df87e5652803d97773f427a71942c5406 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 03:24:27 -0800 Subject: [PATCH 08/22] fix: Test cases --- presets/ragengine/tests/api/test_main.py | 2 +- presets/ragengine/tests/vector_store/test_base_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/presets/ragengine/tests/api/test_main.py b/presets/ragengine/tests/api/test_main.py index fee67dd7b..102936d9d 100644 --- a/presets/ragengine/tests/api/test_main.py +++ b/presets/ragengine/tests/api/test_main.py @@ -175,7 +175,7 @@ def test_query_index_failure(): } response = client.post("/query", json=request_data) - assert response.status_code == 500 + assert response.status_code == 400 assert response.json()["detail"] == "No such index: 'non_existent_index' exists." diff --git a/presets/ragengine/tests/vector_store/test_base_store.py b/presets/ragengine/tests/vector_store/test_base_store.py index d3f49848f..88f0bedeb 100644 --- a/presets/ragengine/tests/vector_store/test_base_store.py +++ b/presets/ragengine/tests/vector_store/test_base_store.py @@ -90,7 +90,7 @@ def test_query_documents(self, mock_post, vector_store_manager): mock_post.assert_called_once_with( LLM_INFERENCE_URL, json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7}, - headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'} ) def test_add_document(self, vector_store_manager): From 6d5e0f2d5b49f805d440db75f12e5cd2b019eeb2 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 03:40:36 -0800 Subject: [PATCH 09/22] fix: remove pipeline param --- README.md | 5 ++--- .../custom-model-integration/custom-deployment-template.yaml | 2 -- .../custom-model-integration/reference-image-deployment.yaml | 2 -- presets/workspace/models/falcon/model.go | 1 - presets/workspace/models/mistral/model.go | 1 - presets/workspace/models/phi2/model.go | 1 - presets/workspace/models/phi3/model.go | 1 - presets/workspace/models/qwen/model.go | 1 - .../falcon-40b-instruct/falcon-40b-instruct_hf.yaml | 2 +- .../workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml | 2 +- .../test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml | 2 +- .../manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml | 2 +- presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml | 2 +- .../mistral-7b-instruct/mistral-7b-instruct_hf.yaml | 2 +- .../workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml | 2 +- presets/workspace/test/manifests/phi-2/phi-2_hf.yaml | 2 +- .../phi-3-medium-128k-instruct_hf.yaml | 2 +- .../phi-3-medium-4k-instruct_hf.yaml | 2 +- .../phi-3-mini-128k-instruct_hf.yaml | 2 +- .../phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml | 2 +- .../phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml | 2 +- .../phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml | 2 +- .../qwen2-5-coder-7b-instruct_hf.yaml | 2 +- 23 files changed, 17 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 54d208fab..d97a14c0e 100644 --- a/README.md +++ b/README.md @@ -146,18 +146,17 @@ Within the deployment specification, locate and modify the command field. 
#### Original ```sh -accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 +accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --torch_dtype bfloat16 ``` #### Modify to enable 4-bit Quantization ```sh -accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 --load_in_4bit +accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --torch_dtype bfloat16 --load_in_4bit ``` Currently, we allow users to change the following paramenters manually: -- `pipeline`: For text-generation models this can be either `text-generation` or `conversational`. - `load_in_4bit` or `load_in_8bit`: Model quantization resolution. Should you need to customize other parameters, kindly file an issue for potential future inclusion. diff --git a/docs/custom-model-integration/custom-deployment-template.yaml b/docs/custom-model-integration/custom-deployment-template.yaml index fe9c2c4ad..c369c9cdb 100644 --- a/docs/custom-model-integration/custom-deployment-template.yaml +++ b/docs/custom-model-integration/custom-deployment-template.yaml @@ -23,8 +23,6 @@ inference: - "--gpu_ids" - "all" - "tfs/inference_api.py" - - "--pipeline" - - "text-generation" - "--torch_dtype" - "float16" # Set to "float16" for compatibility with V100 GPUs; use "bfloat16" for A100, H100 or newer GPUs volumeMounts: diff --git a/docs/custom-model-integration/reference-image-deployment.yaml b/docs/custom-model-integration/reference-image-deployment.yaml index 3a77dba08..36d518638 100644 --- a/docs/custom-model-integration/reference-image-deployment.yaml +++ b/docs/custom-model-integration/reference-image-deployment.yaml @@ -23,8 +23,6 @@ inference: - "--gpu_ids" - "all" - "inference_api.py" - - "--pipeline" - - "text-generation" - "--trust_remote_code" - "--allow_remote_files" - "--pretrained_model_name_or_path" diff --git a/presets/workspace/models/falcon/model.go b/presets/workspace/models/falcon/model.go index 34aac7824..5cf07e221 100644 --- a/presets/workspace/models/falcon/model.go +++ b/presets/workspace/models/falcon/model.go @@ -48,7 +48,6 @@ var ( baseCommandPresetFalconTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" falconRunParams = map[string]string{ "torch_dtype": "bfloat16", - "pipeline": "text-generation", "chat_template": "/workspace/chat_templates/falcon-instruct.jinja", } falconRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/mistral/model.go b/presets/workspace/models/mistral/model.go index b3b8497f0..54b2604e6 100644 --- a/presets/workspace/models/mistral/model.go +++ b/presets/workspace/models/mistral/model.go @@ -35,7 +35,6 @@ var ( baseCommandPresetMistralTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" mistralRunParams = map[string]string{ "torch_dtype": "bfloat16", - "pipeline": "text-generation", "chat_template": "/workspace/chat_templates/mistral-instruct.jinja", } mistralRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/phi2/model.go b/presets/workspace/models/phi2/model.go index bb7989df9..8afa66f19 100644 --- a/presets/workspace/models/phi2/model.go +++ b/presets/workspace/models/phi2/model.go @@ -29,7 +29,6 @@ var ( baseCommandPresetPhiTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" phiRunParams = 
map[string]string{ "torch_dtype": "float16", - "pipeline": "text-generation", } phiRunParamsVLLM = map[string]string{ "dtype": "float16", diff --git a/presets/workspace/models/phi3/model.go b/presets/workspace/models/phi3/model.go index c8c40e4d1..84eb4d544 100644 --- a/presets/workspace/models/phi3/model.go +++ b/presets/workspace/models/phi3/model.go @@ -53,7 +53,6 @@ var ( baseCommandPresetPhiTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" phiRunParams = map[string]string{ "torch_dtype": "auto", - "pipeline": "text-generation", "trust_remote_code": "", } phiRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/qwen/model.go b/presets/workspace/models/qwen/model.go index 20a09df74..f03dfdde0 100644 --- a/presets/workspace/models/qwen/model.go +++ b/presets/workspace/models/qwen/model.go @@ -29,7 +29,6 @@ var ( baseCommandPresetQwenTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" qwenRunParams = map[string]string{ "torch_dtype": "bfloat16", - "pipeline": "text-generation", } qwenRunParamsVLLM = map[string]string{ "dtype": "float16", diff --git a/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml b/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml index a44043894..25d8cd96a 100644 --- a/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml b/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml index 514d12e60..446d0b00c 100644 --- a/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml +++ b/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml b/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml index c48a1c2cf..7a02f0633 100644 --- a/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml +++ b/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml @@ -30,7 +30,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml 
b/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml index 1b2092b36..02b3bbb86 100644 --- a/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml b/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml index 56a775fff..36a97b2a8 100644 --- a/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml +++ b/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml b/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml index 75179683f..9980fcf42 100644 --- a/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml b/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml index 3eff5594f..7b810a353 100644 --- a/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml +++ b/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml b/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml index cbc6f94e7..9d382d96d 100644 --- a/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml +++ b/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 resources: requests: 
nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml index 0adb122e4..5be9d124f 100644 --- a/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml index 1d0d64e47..800b80886 100644 --- a/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml index cf8898015..5f3759534 100644 --- a/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml index 1d7069a38..fb7619d75 100644 --- a/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml 
b/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml index 1827155f4..680d4efbe 100644 --- a/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml +++ b/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml b/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml index 1f515cc6a..c693b97af 100644 --- a/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml +++ b/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml b/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml index e92d906d7..81f096b05 100644 --- a/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 2 From d8b061ea240bbcc0d262191f6fbe5005bd688eed Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 17:48:38 -0800 Subject: [PATCH 10/22] fix: Removing faulty kwarg --- presets/ragengine/config.py | 2 +- presets/ragengine/inference/inference.py | 35 ++++++++++++++++-------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/presets/ragengine/config.py b/presets/ragengine/config.py index 0eae6413b..8adde1fbd 100644 --- a/presets/ragengine/config.py +++ b/presets/ragengine/config.py @@ -38,7 +38,7 @@ """ # LLM (Large Language Model) configuration -LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat") +LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions") LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret") # LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 54a7c75cf..9ef26ae46 100644 --- 
a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -29,6 +29,9 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: @llm_completion_callback() def complete(self, prompt: str, **kwargs) -> CompletionResponse: + # The `llm_completion_callback` from llama_index adds a `formatted` parameter by default. + # We remove it here as it is unnecessary and errors as an unrecognized param in downstream API calls. + kwargs.pop("formatted", None) try: if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX): return self._openai_complete(prompt, **kwargs, **self.params) @@ -55,16 +58,8 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse if model: data["model"] = model # Include the model only if it is not None - # For Debugging Purposes - # import json - # # Construct curl command - # curl_command = ( - # f"curl -X POST {LLM_INFERENCE_URL} " - # + " ".join([f'-H "{key}: {value}"' for key, value in {"Authorization": f"Bearer {LLM_ACCESS_SECRET}", "Content-Type": "application/json"}.items()]) - # + f" -d '{json.dumps(data)}'" - # ) - # print("Equivalent curl command:") - # print(curl_command) + # DEBUG: Call the debugging function + # self._debug_curl_command(data) return self._post_request( data, @@ -95,7 +90,7 @@ def _fetch_default_model(self) -> str: models = response.json().get("data", []) return models[0].get("id") if models else None except Exception as e: - print(f"Error fetching default model from {models_url}: {e}") + print(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.") return None def _get_default_model(self) -> str: @@ -115,6 +110,24 @@ def _post_request(self, data: dict, headers: dict) -> CompletionResponse: except requests.RequestException as e: print(f"Error during POST request to {LLM_INFERENCE_URL}: {e}") raise + + def _debug_curl_command(self, data: dict) -> None: + """ + Constructs and prints the equivalent curl command for debugging purposes. + """ + import json + # Construct curl command + curl_command = ( + f"curl -X POST {LLM_INFERENCE_URL} " + + " ".join([f'-H "{key}: {value}"' for key, value in { + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" + }.items()]) + + f" -d '{json.dumps(data)}'" + ) + print("Equivalent curl command:") + print(curl_command) + @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" From e552e98b4d23e9f6a20f16ea8ba506da5ebe175d Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 12 Jan 2025 18:46:38 -0800 Subject: [PATCH 11/22] feat: update function signature --- presets/ragengine/inference/inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 9ef26ae46..9e2643dc1 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -28,10 +28,7 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: pass @llm_completion_callback() - def complete(self, prompt: str, **kwargs) -> CompletionResponse: - # The `llm_completion_callback` from llama_index adds a `formatted` parameter by default. - # We remove it here as it is unnecessary and errors as an unrecognized param in downstream API calls. 
- kwargs.pop("formatted", None) + def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse: try: if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX): return self._openai_complete(prompt, **kwargs, **self.params) From 45ba8cb4c5483474e79d6991deecc19bb8e5bd4a Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 15:30:14 -0800 Subject: [PATCH 12/22] fix: seperate out changes --- README.md | 5 +- .../custom-deployment-template.yaml | 2 + .../reference-image-deployment.yaml | 2 + .../text-generation/inference_api.py | 158 +++++++++++++----- .../tests/test_inference_api.py | 89 +++++++--- .../tests/test_model_config.py | 14 +- presets/workspace/models/falcon/model.go | 1 + presets/workspace/models/mistral/model.go | 1 + presets/workspace/models/phi2/model.go | 1 + presets/workspace/models/phi3/model.go | 1 + presets/workspace/models/qwen/model.go | 1 + .../falcon-40b-instruct_hf.yaml | 2 +- .../manifests/falcon-40b/falcon-40b_hf.yaml | 2 +- .../falcon-7b-adapter/falcon-7b-adapter.yaml | 2 +- .../falcon-7b-instruct_hf.yaml | 2 +- .../manifests/falcon-7b/falcon-7b_hf.yaml | 2 +- .../mistral-7b-instruct_hf.yaml | 2 +- .../manifests/mistral-7b/mistral-7b_hf.yaml | 2 +- .../test/manifests/phi-2/phi-2_hf.yaml | 2 +- .../phi-3-medium-128k-instruct_hf.yaml | 2 +- .../phi-3-medium-4k-instruct_hf.yaml | 2 +- .../phi-3-mini-128k-instruct_hf.yaml | 2 +- .../phi-3-mini-4k-instruct_hf.yaml | 2 +- .../phi-3-small-128k-instruct.yaml | 2 +- .../phi-3-small-8k-instruct.yaml | 2 +- .../qwen2-5-coder-7b-instruct_hf.yaml | 2 +- 26 files changed, 216 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index d97a14c0e..54d208fab 100644 --- a/README.md +++ b/README.md @@ -146,17 +146,18 @@ Within the deployment specification, locate and modify the command field. #### Original ```sh -accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --torch_dtype bfloat16 +accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 ``` #### Modify to enable 4-bit Quantization ```sh -accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --torch_dtype bfloat16 --load_in_4bit +accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16 --load_in_4bit ``` Currently, we allow users to change the following paramenters manually: +- `pipeline`: For text-generation models this can be either `text-generation` or `conversational`. - `load_in_4bit` or `load_in_8bit`: Model quantization resolution. Should you need to customize other parameters, kindly file an issue for potential future inclusion. 
diff --git a/docs/custom-model-integration/custom-deployment-template.yaml b/docs/custom-model-integration/custom-deployment-template.yaml index c369c9cdb..fe9c2c4ad 100644 --- a/docs/custom-model-integration/custom-deployment-template.yaml +++ b/docs/custom-model-integration/custom-deployment-template.yaml @@ -23,6 +23,8 @@ inference: - "--gpu_ids" - "all" - "tfs/inference_api.py" + - "--pipeline" + - "text-generation" - "--torch_dtype" - "float16" # Set to "float16" for compatibility with V100 GPUs; use "bfloat16" for A100, H100 or newer GPUs volumeMounts: diff --git a/docs/custom-model-integration/reference-image-deployment.yaml b/docs/custom-model-integration/reference-image-deployment.yaml index 36d518638..3a77dba08 100644 --- a/docs/custom-model-integration/reference-image-deployment.yaml +++ b/docs/custom-model-integration/reference-image-deployment.yaml @@ -23,6 +23,8 @@ inference: - "--gpu_ids" - "all" - "inference_api.py" + - "--pipeline" + - "text-generation" - "--trust_remote_code" - "--allow_remote_files" - "--pretrained_model_name_or_path" diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py index 9b4cdf4a2..23de55487 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -36,6 +36,7 @@ class ModelConfig: """ Transformers Model Configuration Parameters """ + pipeline: Optional[str] = field(default="text-generation", metadata={"help": "The model pipeline for the pre-trained model"}) pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"}) combination_type: Optional[str]=field(default="svd", metadata={"help": "The combination type of multi adapters"}) state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"}) @@ -86,6 +87,10 @@ def __post_init__(self): # validate parameters else: self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None + supported_pipelines = {"conversational", "text-generation"} + if self.pipeline not in supported_pipelines: + raise ValueError(f"Unsupported pipeline: {self.pipeline}") + def load_chat_template(chat_template: Optional[str]) -> Optional[str]: logger.info(chat_template) if chat_template is None: @@ -111,13 +116,14 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]: model_args = asdict(args) model_args["local_files_only"] = not model_args.pop('allow_remote_files') +model_pipeline = model_args.pop('pipeline') combination_type = model_args.pop('combination_type') app = FastAPI() -resolved_chat_template = load_chat_template(model_args.pop('chat_template')) +resovled_chat_template = load_chat_template(model_args.pop('chat_template')) tokenizer = AutoTokenizer.from_pretrained(**model_args) -if resolved_chat_template is not None: - tokenizer.chat_template = resolved_chat_template +if resovled_chat_template is not None: + tokenizer.chat_template = resovled_chat_template base_model = AutoModelForCausalLM.from_pretrained(**model_args) if not os.path.exists(ADAPTERS_DIR): @@ -171,16 +177,12 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]: if args.torch_dtype: pipeline_kwargs["torch_dtype"] = args.torch_dtype -try: - pipeline = transformers.pipeline( - task="text-generation", - model=model, - tokenizer=tokenizer, - **pipeline_kwargs - ) 
-except Exception as e: - logger.critical(f"Failed to initialize the pipeline: {e}") - raise RuntimeError("Pipeline initialization failed. Check logs for details.") +pipeline = transformers.pipeline( + task="text-generation", + model=model, + tokenizer=tokenizer, + **pipeline_kwargs +) try: # Attempt to load the generation configuration @@ -238,47 +240,59 @@ class HealthStatus(BaseModel): def health_check(): if not model: logger.error("Model not initialized") - raise HTTPException(status_code=500, detail="Model loading failed. Check configuration or weights.") + raise HTTPException(status_code=500, detail="Model not initialized") if not pipeline: logger.error("Pipeline not initialized") - raise HTTPException(status_code=500, detail="Pipeline setup failed. Check transformer configurations.") + raise HTTPException(status_code=500, detail="Pipeline not initialized") return {"status": "Healthy"} -class UnifiedRequestModel(BaseModel): - # Fields for text generation - prompt: Optional[str] = Field(None, description="Prompt for text generation. Required for text-generation pipeline.") - return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") - clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") - prefix: Optional[str] = Field(None, description="Prefix added to prompt") - handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") - - # Common Generation parameters (formerly in generate_kwargs) - max_length: int = 200 # Length of input prompt + max_new_tokens +class GenerateKwargs(BaseModel): + max_length: int = 200 # Length of input prompt+max_new_tokens min_length: int = 0 do_sample: bool = True + early_stopping: bool = False num_beams: int = 1 temperature: float = 1.0 top_k: int = 10 top_p: float = 1 - + typical_p: float = 1 + repetition_penalty: float = 1 + pad_token_id: Optional[int] = tokenizer.pad_token_id + eos_token_id: Optional[int] = tokenizer.eos_token_id class Config: - extra = "allow" # Allows for additional, unspecified fields + extra = 'allow' # Allows for additional fields not explicitly defined json_schema_extra = { "example": { - "prompt": "Tell me a joke", "max_length": 200, "temperature": 0.7, "top_p": 0.9, - "return_full_text": True, - "clean_up_tokenization_spaces": False, "additional_param": "Example value" } } + +class Message(BaseModel): + role: str + content: str + +class UnifiedRequestModel(BaseModel): + # Fields for text generation + prompt: Optional[str] = Field(None, description="Prompt for text generation. Required for text-generation pipeline. Do not use with 'messages'.") + return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") + clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") + prefix: Optional[str] = Field(None, description="Prefix added to prompt") + handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") + generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method") + + # Field for conversational model + messages: Optional[List[Message]] = Field(None, description="Messages for conversational model. Required for conversational pipeline. 
Do not use with 'prompt'.") + def messages_to_dict_list(self): + return [message.dict() for message in self.messages] if self.messages else [] + class ErrorResponse(BaseModel): detail: str @app.post( - "/v1/completions", + "/chat", summary="Chat Endpoint", responses={ 200: { @@ -291,6 +305,12 @@ class ErrorResponse(BaseModel): "value": { "Result": "Generated text based on the prompt." } + }, + "conversation": { + "summary": "Conversation Response", + "value": { + "Result": "Response to the last message in the conversation." + } } } } @@ -305,6 +325,10 @@ class ErrorResponse(BaseModel): "missing_prompt": { "summary": "Missing Prompt", "value": {"detail": "Text generation parameter prompt required"} + }, + "missing_messages": { + "summary": "Missing Messages", + "value": {"detail": "Conversational parameter messages required"} } } } @@ -330,7 +354,23 @@ def generate_text( "clean_up_tokenization_spaces": False, "prefix": None, "handle_long_generation": None, - "temperature": 1.0, + "generate_kwargs": GenerateKwargs().dict(), + }, + }, + "conversation_example": { + "summary": "Conversation Example", + "description": "An example of a conversational request.", + "value": { + "messages": [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} + ], + "return_full_text": True, + "clean_up_tokenization_spaces": False, + "prefix": None, + "handle_long_generation": None, + "generate_kwargs": GenerateKwargs().dict(), }, }, }, @@ -338,23 +378,49 @@ def generate_text( ], ): """ - Processes chat requests, generating text based on the specified text-generation pipeline. + Processes chat requests, generating text based on the specified pipeline (text generation or conversational). + Validates required parameters based on the pipeline and returns the generated text. 
""" - user_generate_kwargs = request_model.dict(exclude_unset=True) + user_generate_kwargs = request_model.generate_kwargs.dict() if request_model.generate_kwargs else {} + generate_kwargs = {**default_generate_config, **user_generate_kwargs} + + if args.pipeline == "text-generation": + if not request_model.prompt: + logger.error("Text generation parameter prompt required") + raise HTTPException(status_code=400, detail="Text generation parameter prompt required") + sequences = pipeline( + request_model.prompt, + # return_tensors=request_model.return_tensors, + # return_text=request_model.return_text, + return_full_text=request_model.return_full_text, + clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, + prefix=request_model.prefix, + handle_long_generation=request_model.handle_long_generation, + **generate_kwargs + ) - # Extract the prompt separately and remove it from model kwargs - prompt = user_generate_kwargs.pop("prompt", None) - if not prompt: - logger.error("Text generation parameter prompt required") - raise HTTPException(status_code=400, detail="Text generation parameter prompt required") + result = "" + for seq in sequences: + logger.debug(f"Result: {seq['generated_text']}") + result += seq['generated_text'] - try: - sequences = pipeline(request_model.prompt, **user_generate_kwargs) - result = "".join(seq["generated_text"] for seq in sequences) return {"Result": result} - except Exception as e: - logger.error(f"Error during text generation: {e}") - raise HTTPException(status_code=500, detail="Error during text generation") + + elif args.pipeline == "conversational": + if not request_model.messages: + logger.error("Conversational parameter messages required") + raise HTTPException(status_code=400, detail="Conversational parameter messages required") + + response = pipeline( + request_model.messages_to_dict_list(), + clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, + **generate_kwargs + ) + return {"Result": str(response[-1])} + + else: + logger.error("Invalid pipeline type") + raise HTTPException(status_code=400, detail="Invalid pipeline type") class MemoryInfo(BaseModel): used: str @@ -378,7 +444,7 @@ class MetricsResponse(BaseModel): cpu_info: Optional[CPUInfo] = None @app.get( - "/v1/metrics", + "/metrics", response_model=MetricsResponse, summary="Metrics Endpoint", responses={ diff --git a/presets/workspace/inference/text-generation/tests/test_inference_api.py b/presets/workspace/inference/text-generation/tests/test_inference_api.py index cf1eb91c5..480f1b480 100644 --- a/presets/workspace/inference/text-generation/tests/test_inference_api.py +++ b/presets/workspace/inference/text-generation/tests/test_inference_api.py @@ -17,13 +17,15 @@ "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}") @pytest.fixture(params=[ - {"model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, + {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, + {"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"}, ]) def configured_app(request): original_argv = sys.argv.copy() # Use request.param to set correct test arguments for each configuration test_args = [ 'program_name', + '--pipeline', request.param['pipeline'], '--pretrained_model_name_or_path', request.param['model_path'], '--device_map', request.param['device'], '--allow_remote_files', 'True', @@ -41,30 +43,64 @@ def 
configured_app(request): sys.argv = original_argv +def test_conversational(configured_app): + if configured_app.test_config['pipeline'] != 'conversational': + pytest.skip("Skipping non-conversational tests") + client = TestClient(configured_app) + messages = [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} + ] + request_data = { + "messages": messages, + "generate_kwargs": {"max_new_tokens": 20, "do_sample": True} + } + response = client.post("/chat", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "Result" in data + assert len(data["Result"]) > 0 # Check if the conversation result is not empty + +def test_missing_messages_for_conversation(configured_app): + if configured_app.test_config['pipeline'] != 'conversational': + pytest.skip("Skipping non-conversational tests") + client = TestClient(configured_app) + request_data = { + # "messages" is missing for conversational pipeline + } + response = client.post("/chat", json=request_data) + assert response.status_code == 400 # Expecting a Bad Request response due to missing messages + assert "Conversational parameter messages required" in response.json().get("detail", "") + def test_text_generation(configured_app): + if configured_app.test_config['pipeline'] != 'text-generation': + pytest.skip("Skipping non-text-generation tests") client = TestClient(configured_app) request_data = { "prompt": "Hello, world!", "return_full_text": True, "clean_up_tokenization_spaces": False, - "max_length": 50, - "temperature": 0.7 + "generate_kwargs": {"max_length": 50, "min_length": 10} # Example generate_kwargs } - response = client.post("/v1/completions", json=request_data) + response = client.post("/chat", json=request_data) assert response.status_code == 200 data = response.json() assert "Result" in data assert len(data["Result"]) > 0 # Check if the result text is not empty def test_missing_prompt(configured_app): + if configured_app.test_config['pipeline'] != 'text-generation': + pytest.skip("Skipping non-text-generation tests") client = TestClient(configured_app) request_data = { # "prompt" is missing "return_full_text": True, "clean_up_tokenization_spaces": False, - "max_length": 50 + "generate_kwargs": {"max_length": 50} } - response = client.post("/v1/completions", json=request_data) + response = client.post("/chat", json=request_data) assert response.status_code == 400 # Expecting a Bad Request response due to missing prompt assert "Text generation parameter prompt required" in response.json().get("detail", "") @@ -82,7 +118,7 @@ def test_health_check(configured_app): def test_get_metrics(configured_app): client = TestClient(configured_app) - response = client.get("/v1/metrics") + response = client.get("/metrics") assert response.status_code == 200 assert "gpu_info" in response.json() @@ -112,7 +148,7 @@ def __init__(self, id, name, load, temperature, memoryUsed, memoryTotal): # Mock GPUtil.getGPUs to return a list containing the mock GPU object with patch('torch.cuda.is_available', return_value=True), \ patch('GPUtil.getGPUs', return_value=[mock_gpu]): - response = client.get("/v1/metrics") + response = client.get("/metrics") assert response.status_code == 200 data = response.json() @@ -138,7 +174,7 @@ def 
test_get_metrics_no_gpus(configured_app): patch('psutil.virtual_memory') as mock_virtual_memory: mock_virtual_memory.return_value.used = 4 * (1024 ** 3) # 4 GB mock_virtual_memory.return_value.total = 16 * (1024 ** 3) # 16 GB - response = client.get("/v1/metrics") + response = client.get("/metrics") assert response.status_code == 200 data = response.json() assert data["gpu_info"] is None # No GPUs available @@ -150,25 +186,22 @@ def test_get_metrics_no_gpus(configured_app): assert data["cpu_info"]["memory"]["total"] == "16.00 GB" def test_default_generation_params(configured_app): + if configured_app.test_config['pipeline'] != 'text-generation': + pytest.skip("Skipping non-text-generation tests") + client = TestClient(configured_app) request_data = { "prompt": "Test default params", "return_full_text": True, - "clean_up_tokenization_spaces": False, - "max_length": 200, - "min_length": 0, - "do_sample": True, - "num_beams": 1, - "temperature": 1.0, - "top_k": 10, - "top_p": 1, + "clean_up_tokenization_spaces": False + # Note: generate_kwargs is not provided, so defaults should be used } with patch('inference_api.pipeline') as mock_pipeline: mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function - response = client.post("/v1/completions", json=request_data) + response = client.post("/chat", json=request_data) assert response.status_code == 200 data = response.json() @@ -180,12 +213,18 @@ def test_default_generation_params(configured_app): assert kwargs['max_length'] == 200 assert kwargs['min_length'] == 0 assert kwargs['do_sample'] is True - assert kwargs['num_beams'] == 1 assert kwargs['temperature'] == 1.0 assert kwargs['top_k'] == 10 assert kwargs['top_p'] == 1 + assert kwargs['typical_p'] == 1 + assert kwargs['repetition_penalty'] == 1 + assert kwargs['num_beams'] == 1 + assert kwargs['early_stopping'] is False def test_generation_with_max_length(configured_app): + if configured_app.test_config['pipeline'] != 'text-generation': + pytest.skip("Skipping non-text-generation tests") + client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." avg_res_len = 15 @@ -195,10 +234,10 @@ def test_generation_with_max_length(configured_app): "prompt": prompt, "return_full_text": True, "clean_up_tokenization_spaces": False, - "max_length": max_length, + "generate_kwargs": {"max_length": max_length} } - response = client.post("/v1/completions", json=request_data) + response = client.post("/chat", json=request_data) assert response.status_code == 200 data = response.json() @@ -218,6 +257,9 @@ def test_generation_with_max_length(configured_app): assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length" def test_generation_with_min_length(configured_app): + if configured_app.test_config['pipeline'] != 'text-generation': + pytest.skip("Skipping non-text-generation tests") + client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." 
min_length = 30 @@ -227,11 +269,10 @@ def test_generation_with_min_length(configured_app): "prompt": prompt, "return_full_text": True, "clean_up_tokenization_spaces": False, - "min_length": min_length, - "max_length": max_length, + "generate_kwargs": {"min_length": min_length, "max_length": max_length} } - response = client.post("/v1/completions", json=request_data) + response = client.post("/chat", json=request_data) assert response.status_code == 200 data = response.json() diff --git a/presets/workspace/inference/text-generation/tests/test_model_config.py b/presets/workspace/inference/text-generation/tests/test_model_config.py index 718890e9c..df5b98e8d 100644 --- a/presets/workspace/inference/text-generation/tests/test_model_config.py +++ b/presets/workspace/inference/text-generation/tests/test_model_config.py @@ -10,13 +10,15 @@ sys.path.append(parent_dir) @pytest.fixture(params=[ - {"model_path": "stanford-crfm/alias-gpt2-small-x21"}, + {"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21"}, + {"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21"}, ]) def configured_model_config(request): original_argv = sys.argv.copy() sys.argv = [ 'program_name', + '--pipeline', request.param['pipeline'], '--pretrained_model_name_or_path', request.param['model_path'], '--allow_remote_files', 'True' ] @@ -27,6 +29,7 @@ def configured_model_config(request): # Create and configure the ModelConfig instance model_config = ModelConfig( + pipeline=request.param['pipeline'], pretrained_model_name_or_path=request.param['model_path'], ) @@ -70,9 +73,16 @@ def test_ignore_double_dash_arguments(configured_model_config): assert getattr(config, "new_arg2", None) is True assert getattr(config, "new_arg3", None) == "correct_value" +# Test case to verify handling unsupported pipeline values +def test_unsupported_pipeline_raises_value_error(configured_model_config): + with pytest.raises(ValueError) as excinfo: + from inference_api import ModelConfig + ModelConfig(pipeline="unsupported_pipeline") + assert "Unsupported pipeline" in str(excinfo.value) + # Test case for validating torch_dtype def test_invalid_torch_dtype_raises_value_error(configured_model_config): with pytest.raises(ValueError) as excinfo: from inference_api import ModelConfig - ModelConfig(torch_dtype="unsupported_dtype") + ModelConfig(pipeline="text-generation", torch_dtype="unsupported_dtype") assert "Invalid torch dtype" in str(excinfo.value) \ No newline at end of file diff --git a/presets/workspace/models/falcon/model.go b/presets/workspace/models/falcon/model.go index 5cf07e221..34aac7824 100644 --- a/presets/workspace/models/falcon/model.go +++ b/presets/workspace/models/falcon/model.go @@ -48,6 +48,7 @@ var ( baseCommandPresetFalconTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" falconRunParams = map[string]string{ "torch_dtype": "bfloat16", + "pipeline": "text-generation", "chat_template": "/workspace/chat_templates/falcon-instruct.jinja", } falconRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/mistral/model.go b/presets/workspace/models/mistral/model.go index 54b2604e6..b3b8497f0 100644 --- a/presets/workspace/models/mistral/model.go +++ b/presets/workspace/models/mistral/model.go @@ -35,6 +35,7 @@ var ( baseCommandPresetMistralTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" mistralRunParams = map[string]string{ "torch_dtype": "bfloat16", + "pipeline": "text-generation", "chat_template": 
"/workspace/chat_templates/mistral-instruct.jinja", } mistralRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/phi2/model.go b/presets/workspace/models/phi2/model.go index 8afa66f19..bb7989df9 100644 --- a/presets/workspace/models/phi2/model.go +++ b/presets/workspace/models/phi2/model.go @@ -29,6 +29,7 @@ var ( baseCommandPresetPhiTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" phiRunParams = map[string]string{ "torch_dtype": "float16", + "pipeline": "text-generation", } phiRunParamsVLLM = map[string]string{ "dtype": "float16", diff --git a/presets/workspace/models/phi3/model.go b/presets/workspace/models/phi3/model.go index 84eb4d544..c8c40e4d1 100644 --- a/presets/workspace/models/phi3/model.go +++ b/presets/workspace/models/phi3/model.go @@ -53,6 +53,7 @@ var ( baseCommandPresetPhiTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" phiRunParams = map[string]string{ "torch_dtype": "auto", + "pipeline": "text-generation", "trust_remote_code": "", } phiRunParamsVLLM = map[string]string{ diff --git a/presets/workspace/models/qwen/model.go b/presets/workspace/models/qwen/model.go index f03dfdde0..20a09df74 100644 --- a/presets/workspace/models/qwen/model.go +++ b/presets/workspace/models/qwen/model.go @@ -29,6 +29,7 @@ var ( baseCommandPresetQwenTuning = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch" qwenRunParams = map[string]string{ "torch_dtype": "bfloat16", + "pipeline": "text-generation", } qwenRunParamsVLLM = map[string]string{ "dtype": "float16", diff --git a/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml b/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml index 25d8cd96a..a44043894 100644 --- a/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/falcon-40b-instruct/falcon-40b-instruct_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml b/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml index 446d0b00c..514d12e60 100644 --- a/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml +++ b/presets/workspace/test/manifests/falcon-40b/falcon-40b_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml b/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml index 7a02f0633..c48a1c2cf 100644 --- a/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml +++ b/presets/workspace/test/manifests/falcon-7b-adapter/falcon-7b-adapter.yaml @@ -30,7 +30,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 
--machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 2 diff --git a/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml b/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml index 02b3bbb86..1b2092b36 100644 --- a/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/falcon-7b-instruct/falcon-7b-instruct_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml b/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml index 36a97b2a8..56a775fff 100644 --- a/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml +++ b/presets/workspace/test/manifests/falcon-7b/falcon-7b_hf.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml b/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml index 9980fcf42..75179683f 100644 --- a/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/mistral-7b-instruct/mistral-7b-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml b/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml index 7b810a353..3eff5594f 100644 --- a/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml +++ b/presets/workspace/test/manifests/mistral-7b/mistral-7b_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml b/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml index 9d382d96d..cbc6f94e7 100644 --- a/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml +++ b/presets/workspace/test/manifests/phi-2/phi-2_hf.yaml @@ 
-18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype bfloat16 + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype bfloat16 resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml index 5be9d124f..0adb122e4 100644 --- a/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-medium-128k-instruct/phi-3-medium-128k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml index 800b80886..1d0d64e47 100644 --- a/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-medium-4k-instruct/phi-3-medium-4k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml index 5f3759534..cf8898015 100644 --- a/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-mini-128k-instruct/phi-3-mini-128k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml b/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml index fb7619d75..1d7069a38 100644 --- a/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml +++ b/presets/workspace/test/manifests/phi-3-mini-4k-instruct/phi-3-mini-4k-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all 
/workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml b/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml index 680d4efbe..1827155f4 100644 --- a/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml +++ b/presets/workspace/test/manifests/phi-3-small-128k-instruct/phi-3-small-128k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml b/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml index c693b97af..1f515cc6a 100644 --- a/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml +++ b/presets/workspace/test/manifests/phi-3-small-8k-instruct/phi-3-small-8k-instruct.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 1 diff --git a/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml b/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml index 81f096b05..e92d906d7 100644 --- a/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml +++ b/presets/workspace/test/manifests/qwen2-5-coder-7b-instruct/qwen2-5-coder-7b-instruct_hf.yaml @@ -18,7 +18,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --torch_dtype auto --trust_remote_code + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all /workspace/tfs/inference_api.py --pipeline text-generation --torch_dtype auto --trust_remote_code resources: requests: nvidia.com/gpu: 2 From fc6f0bab054ef04dbeba899e5ab79d39558ec96c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 19:26:36 -0800 Subject: [PATCH 13/22] fix: Better handling of vllm vs non-vllm --- presets/ragengine/inference/inference.py | 41 ++++++++++++++++++------ 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 9e2643dc1..bcff51f1b 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import logging from typing import Any from dataclasses import field from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen @@ -10,12 +11,17 @@ from urllib.parse import urlparse, urljoin from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + OPENAI_URL_PREFIX = "https://api.openai.com" HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co" class Inference(CustomLLM): params: dict = field(default_factory=dict) _default_model: str = None + _custom_api_endpoint_type: str = None # "vllm", "non-vllm", or None def set_params(self, params: dict) -> None: self.params = params @@ -57,11 +63,24 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse # DEBUG: Call the debugging function # self._debug_curl_command(data) + try: + return self._post_request(data, headers={ + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" + }) + except Exception as e: + # Check for vLLM-specific missing model error + if "missing" in str(e) and "model" in str(e): + logger.warning("Detected missing 'model' parameter. Fetching default model and retrying...") + self._default_model = self._fetch_default_model() # Fetch default model dynamically + if self._default_model: + data["model"] = self._default_model + return self._post_request(data, headers={ + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" + }) + raise # Re-raise the exception if not recoverable - return self._post_request( - data, - headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", "Content-Type": "application/json"} - ) def _get_models_endpoint(self) -> str: """ @@ -82,17 +101,19 @@ def _fetch_default_model(self) -> str: } response = requests.get(models_url, headers=headers) - response.raise_for_status() # Raise an exception for HTTP errors + response.raise_for_status() # Raise an exception for HTTP errors (includes 404) models = response.json().get("data", []) + self._custom_api_endpoint_type = "vllm" return models[0].get("id") if models else None except Exception as e: - print(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.") + logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.") + self._custom_api_endpoint_type = "non-vllm" return None def _get_default_model(self) -> str: """ - Returns the cached default model if available; otherwise fetches and caches it. + Returns the cached default model if available, otherwise fetches and caches it. 
""" if not self._default_model: self._default_model = self._fetch_default_model() @@ -105,7 +126,7 @@ def _post_request(self, data: dict, headers: dict) -> CompletionResponse: response_data = response.json() return CompletionResponse(text=str(response_data)) except requests.RequestException as e: - print(f"Error during POST request to {LLM_INFERENCE_URL}: {e}") + logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}") raise def _debug_curl_command(self, data: dict) -> None: @@ -122,8 +143,8 @@ def _debug_curl_command(self, data: dict) -> None: }.items()]) + f" -d '{json.dumps(data)}'" ) - print("Equivalent curl command:") - print(curl_command) + logger.info("Equivalent curl command:") + logger.info(curl_command) @property def metadata(self) -> LLMMetadata: From dc3a0bef31b15ebe1e7f609466d459d0f0b207c6 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 20:14:07 -0800 Subject: [PATCH 14/22] feat: add more logging --- presets/ragengine/inference/inference.py | 33 ++++++++++++------------ 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index bcff51f1b..efede9bad 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -17,6 +17,10 @@ OPENAI_URL_PREFIX = "https://api.openai.com" HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co" +DEFAULT_HEADERS = { + "Authorization": f"Bearer {LLM_ACCESS_SECRET}", + "Content-Type": "application/json" +} class Inference(CustomLLM): params: dict = field(default_factory=dict) @@ -64,24 +68,24 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse # DEBUG: Call the debugging function # self._debug_curl_command(data) try: - return self._post_request(data, headers={ - "Authorization": f"Bearer {LLM_ACCESS_SECRET}", - "Content-Type": "application/json" - }) + return self._post_request(data, headers=DEFAULT_HEADERS) except Exception as e: + err_msg = str(e) # Check for vLLM-specific missing model error - if "missing" in str(e) and "model" in str(e): - logger.warning("Detected missing 'model' parameter. Fetching default model and retrying...") + if "missing" in err_msg and "model" in err_msg and "Field required" in err_msg: + logger.warning( + f"Detected missing 'model' parameter in API response. " + f"Response: {err_msg}. Attempting to fetch the default model..." + ) self._default_model = self._fetch_default_model() # Fetch default model dynamically if self._default_model: + logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") data["model"] = self._default_model - return self._post_request(data, headers={ - "Authorization": f"Bearer {LLM_ACCESS_SECRET}", - "Content-Type": "application/json" - }) + return self._post_request(data, headers=DEFAULT_HEADERS) + else: + logger.error("Failed to fetch a default model. Aborting retry.") raise # Re-raise the exception if not recoverable - def _get_models_endpoint(self) -> str: """ Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL. 
@@ -95,12 +99,7 @@ def _fetch_default_model(self) -> str: """ try: models_url = self._get_models_endpoint() - headers = { - "Authorization": f"Bearer {LLM_ACCESS_SECRET}", - "Content-Type": "application/json" - } - - response = requests.get(models_url, headers=headers) + response = requests.get(models_url, headers=DEFAULT_HEADERS) response.raise_for_status() # Raise an exception for HTTP errors (includes 404) models = response.json().get("data", []) From 13abd0173bf203b633c9e91bc65c80f1d6c48e42 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 21:18:21 -0800 Subject: [PATCH 15/22] fix: simplify logic --- presets/ragengine/inference/inference.py | 28 +++++------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index efede9bad..8103a3ff2 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -23,9 +23,9 @@ } class Inference(CustomLLM): - params: dict = field(default_factory=dict) + params: dict = {} _default_model: str = None - _custom_api_endpoint_type: str = None # "vllm", "non-vllm", or None + _model_retrieval_attempted: bool = False def set_params(self, params: dict) -> None: self.params = params @@ -67,24 +67,7 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse # DEBUG: Call the debugging function # self._debug_curl_command(data) - try: - return self._post_request(data, headers=DEFAULT_HEADERS) - except Exception as e: - err_msg = str(e) - # Check for vLLM-specific missing model error - if "missing" in err_msg and "model" in err_msg and "Field required" in err_msg: - logger.warning( - f"Detected missing 'model' parameter in API response. " - f"Response: {err_msg}. Attempting to fetch the default model..." - ) - self._default_model = self._fetch_default_model() # Fetch default model dynamically - if self._default_model: - logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") - data["model"] = self._default_model - return self._post_request(data, headers=DEFAULT_HEADERS) - else: - logger.error("Failed to fetch a default model. Aborting retry.") - raise # Re-raise the exception if not recoverable + return self._post_request(data, headers=DEFAULT_HEADERS) def _get_models_endpoint(self) -> str: """ @@ -103,18 +86,17 @@ def _fetch_default_model(self) -> str: response.raise_for_status() # Raise an exception for HTTP errors (includes 404) models = response.json().get("data", []) - self._custom_api_endpoint_type = "vllm" return models[0].get("id") if models else None except Exception as e: logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.") - self._custom_api_endpoint_type = "non-vllm" return None def _get_default_model(self) -> str: """ Returns the cached default model if available, otherwise fetches and caches it. 
""" - if not self._default_model: + if not self._default_model and not self._model_retrieval_attempted: + self._model_retrieval_attempted = True self._default_model = self._fetch_default_model() return self._default_model From 2c85ffea72dd6a4290e924d61f81fef827c3221e Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 21:44:35 -0800 Subject: [PATCH 16/22] fix --- presets/ragengine/inference/inference.py | 29 ++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 8103a3ff2..821ed77f8 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -8,6 +8,7 @@ from llama_index.llms.openai import OpenAI from llama_index.core.llms.callbacks import llm_completion_callback import requests +from requests.exceptions import HTTPError from urllib.parse import urlparse, urljoin from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD @@ -60,14 +61,38 @@ def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> Completion ) def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - model = kwargs.pop("model", self._get_default_model()) + model = kwargs.pop("model", self.get_default_model()) data = {"prompt": prompt, **kwargs} if model: data["model"] = model # Include the model only if it is not None # DEBUG: Call the debugging function # self._debug_curl_command(data) - return self._post_request(data, headers=DEFAULT_HEADERS) + try: + return self._post_request(data, headers=DEFAULT_HEADERS) + except HTTPError as e: + if e.response.status_code == 400: + err_msg = str(e) + # Check for vLLM-specific missing model error + if "missing" in err_msg and "model" in err_msg and "Field required" in err_msg: + self._model_retrieval_attempted = False + logger.warning( + f"Detected missing 'model' parameter in API response. " + f"Response: {err_msg}. Attempting to fetch the default model..." + ) + self._default_model = self._fetch_default_model() # Fetch default model dynamically + if self._default_model: + logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") + data["model"] = self._default_model + return self._post_request(data, headers=DEFAULT_HEADERS) + else: + logger.error("Failed to fetch a default model. 
Aborting retry.") + else: + logger.error(f"HTTP 400 error occurred: {err_msg}") + raise # Re-raise the exception if not recoverable + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + raise def _get_models_endpoint(self) -> str: """ From b3f8a5c8c07e327b9ebf4302c2caf61f753d1f5d Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 21:47:34 -0800 Subject: [PATCH 17/22] remove --- presets/ragengine/inference/inference.py | 25 +++++++++++------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 821ed77f8..e466f176d 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -74,21 +74,18 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse if e.response.status_code == 400: err_msg = str(e) # Check for vLLM-specific missing model error - if "missing" in err_msg and "model" in err_msg and "Field required" in err_msg: - self._model_retrieval_attempted = False - logger.warning( - f"Detected missing 'model' parameter in API response. " - f"Response: {err_msg}. Attempting to fetch the default model..." - ) - self._default_model = self._fetch_default_model() # Fetch default model dynamically - if self._default_model: - logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") - data["model"] = self._default_model - return self._post_request(data, headers=DEFAULT_HEADERS) - else: - logger.error("Failed to fetch a default model. Aborting retry.") + self._model_retrieval_attempted = False + logger.warning( + f"Detected missing 'model' parameter in API response. " + f"Response: {err_msg}. Attempting to fetch the default model..." + ) + self._default_model = self._fetch_default_model() # Fetch default model dynamically + if self._default_model: + logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...") + data["model"] = self._default_model + return self._post_request(data, headers=DEFAULT_HEADERS) else: - logger.error(f"HTTP 400 error occurred: {err_msg}") + logger.error("Failed to fetch a default model. Aborting retry.") raise # Re-raise the exception if not recoverable except Exception as e: logger.error(f"An unexpected error occurred: {e}") From 27557d0e0b6b4d31eda2e0d7e2ea4e44bd495275 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 21:51:44 -0800 Subject: [PATCH 18/22] fix --- presets/ragengine/inference/inference.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index e466f176d..ec9de9d81 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -72,12 +72,10 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse return self._post_request(data, headers=DEFAULT_HEADERS) except HTTPError as e: if e.response.status_code == 400: - err_msg = str(e) - # Check for vLLM-specific missing model error self._model_retrieval_attempted = False logger.warning( f"Detected missing 'model' parameter in API response. " - f"Response: {err_msg}. Attempting to fetch the default model..." + f"Response: {str(e)}. Attempting to fetch the default model..." 
) self._default_model = self._fetch_default_model() # Fetch default model dynamically if self._default_model: From 5cf21d31da8177376b4d4565b14e97f24beada26 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 21:55:47 -0800 Subject: [PATCH 19/22] fix --- presets/ragengine/inference/inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index ec9de9d81..945844dc8 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -72,7 +72,6 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse return self._post_request(data, headers=DEFAULT_HEADERS) except HTTPError as e: if e.response.status_code == 400: - self._model_retrieval_attempted = False logger.warning( f"Detected missing 'model' parameter in API response. " f"Response: {str(e)}. Attempting to fetch the default model..." From 7b10cd628732a4064b8a43ee21f9d6a33ee77a3a Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 22:01:32 -0800 Subject: [PATCH 20/22] fix --- presets/ragengine/inference/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index 945844dc8..c5d4655e3 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -73,8 +73,8 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse except HTTPError as e: if e.response.status_code == 400: logger.warning( - f"Detected missing 'model' parameter in API response. " - f"Response: {str(e)}. Attempting to fetch the default model..." + f"Potential issue with 'model' parameter in API response. " + f"Response: {str(e)}. Attempting to update the model name as a mitigation..." 
) self._default_model = self._fetch_default_model() # Fetch default model dynamically if self._default_model: From 50fc8bc3dc8f23384bd2b7effdf6f27af38896ef Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 22:15:31 -0800 Subject: [PATCH 21/22] Update inference.py --- presets/ragengine/inference/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index c5d4655e3..7728c7ab3 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -61,7 +61,7 @@ def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> Completion ) def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - model = kwargs.pop("model", self.get_default_model()) + model = kwargs.pop("model", self._get_default_model()) data = {"prompt": prompt, **kwargs} if model: data["model"] = model # Include the model only if it is not None From 93e4f20142131d5403b2b768d34b387b003c1153 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 13 Jan 2025 22:20:48 -0800 Subject: [PATCH 22/22] nit remove --- presets/ragengine/tests/vector_store/test_base_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presets/ragengine/tests/vector_store/test_base_store.py b/presets/ragengine/tests/vector_store/test_base_store.py index 88f0bedeb..9f55bad95 100644 --- a/presets/ragengine/tests/vector_store/test_base_store.py +++ b/presets/ragengine/tests/vector_store/test_base_store.py @@ -89,7 +89,7 @@ def test_query_documents(self, mock_post, vector_store_manager): mock_post.assert_called_once_with( LLM_INFERENCE_URL, - json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7}, + json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", 'temperature': 0.7}, headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'} )
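
Taken together, patches 16-21 converge on one behavior for the RAG engine's custom inference client: resolve a default model at most once, POST the completion request, and if the endpoint answers HTTP 400, refresh the default model from the server and retry a single time. Below is a minimal, self-contained sketch of that control flow, not the repository's actual module: the class name InferenceClientSketch, the complete() wrapper, the empty headers, and the mocked transport in the demo are illustrative stand-ins for the real _custom_api_complete, DEFAULT_HEADERS, and HTTP calls.

# Sketch of the retry-on-HTTP-400 flow the patch series converges on.
# Illustrative only: class and method names are stand-ins, not the real module.
import logging
from typing import Any, Dict, Optional
from unittest.mock import MagicMock

from requests.exceptions import HTTPError

logger = logging.getLogger(__name__)


class InferenceClientSketch:
    def __init__(self) -> None:
        self._default_model: Optional[str] = None
        self._model_retrieval_attempted = False

    # Transport pieces; the real client POSTs to LLM_INFERENCE_URL and
    # discovers models from the server's model-listing endpoint.
    def _post_request(self, data: Dict[str, Any], headers: Dict[str, str]) -> Dict[str, Any]:
        raise NotImplementedError

    def _fetch_default_model(self) -> Optional[str]:
        raise NotImplementedError

    def _get_default_model(self) -> Optional[str]:
        # Resolve the default model at most once, then cache the result.
        if not self._default_model and not self._model_retrieval_attempted:
            self._model_retrieval_attempted = True
            self._default_model = self._fetch_default_model()
        return self._default_model

    def complete(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        model = kwargs.pop("model", self._get_default_model())
        data = {"prompt": prompt, **kwargs}
        if model:
            data["model"] = model  # include the model only if one is known
        try:
            return self._post_request(data, headers={})
        except HTTPError as e:
            if e.response is not None and e.response.status_code == 400:
                logger.warning("HTTP 400 from inference endpoint (%s); refreshing default model and retrying once", e)
                self._default_model = self._fetch_default_model()
                if self._default_model:
                    data["model"] = self._default_model
                    return self._post_request(data, headers={})
                logger.error("Failed to fetch a default model. Aborting retry.")
            raise  # not recoverable


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = InferenceClientSketch()
    # First POST fails with 400 (e.g. vLLM rejecting a missing model name),
    # the second succeeds after the default model is refreshed.
    client._post_request = MagicMock(
        side_effect=[HTTPError(response=MagicMock(status_code=400)), {"Result": "ok"}]
    )
    client._fetch_default_model = MagicMock(return_value="example-default-model")
    print(client.complete("Tell me a joke", temperature=0.7))
    assert client._post_request.call_count == 2

Retrying only once and caching the resolved name via _model_retrieval_attempted keeps a misbehaving endpoint from turning every completion into an extra round trip to the model-listing endpoint.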
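
The final test change (patch 22) shows the wire format the RAG engine now emits: a plain vLLM/OpenAI-style completion body with no "formatted" flag. A hand-rolled equivalent of that request might look like the snippet below; the localhost URL, port, /v1/completions path, and bearer token are placeholders standing in for LLM_INFERENCE_URL and LLM_ACCESS_SECRET, not values taken from the repository.

# Hypothetical manual request mirroring the payload asserted in test_query_documents;
# the URL and token below are placeholders, not repository configuration.
import requests

payload = {
    "prompt": (
        "Context information is below.\n---------------------\n"
        "type: text\n\nFirst document\n---------------------\n"
        "Given the context information and not prior knowledge, answer the query.\n"
        "Query: First\nAnswer: "
    ),
    "temperature": 0.7,
    # "model": "<served-model-name>",  # added by the client when a default model is known
}
response = requests.post(
    "http://localhost:5000/v1/completions",  # stand-in for LLM_INFERENCE_URL
    json=payload,
    headers={"Authorization": "Bearer <LLM_ACCESS_SECRET>", "Content-Type": "application/json"},
    timeout=30,
)
print(response.json())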