Skip to content

Commit

Permalink
feat: hf inference support for gated repos (#7598)
Browse files Browse the repository at this point in the history
* feat: HF inference support for gated repos

* add reno

* take use_auth_token instead of api_key

* add stop param to hf inference
  • Loading branch information
tstadel authored Apr 29, 2024
1 parent 7469850 commit 3c2a4fe
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
12 changes: 9 additions & 3 deletions haystack/nodes/prompt/invocation_layer/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod, ABC
from typing import Union, Dict
from typing import Optional, Union, Dict

from haystack.lazy_imports import LazyImport

Expand Down Expand Up @@ -61,8 +61,14 @@ class DefaultPromptHandler:
are within the model_max_length.
"""

def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
def __init__(
self,
model_name_or_path: str,
model_max_length: int,
max_length: int = 100,
use_auth_token: Optional[Union[str, bool]] = None,
):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=use_auth_token)
self.tokenizer.model_max_length = model_max_length
self.model_max_length = model_max_length
self.max_length = max_length
Expand Down
11 changes: 10 additions & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer):
"""

def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs):
def __init__(
self,
api_key: str,
model_name_or_path: str,
max_length: Optional[int] = 100,
use_auth_token: Optional[Union[str, bool]] = None,
**kwargs,
):
"""
Creates an instance of HFInferenceEndpointInvocationLayer
:param model_name_or_path: can be either:
Expand Down Expand Up @@ -76,6 +83,7 @@ def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[i
"repetition_penalty",
"return_full_text",
"seed",
"stop",
"stream",
"stream_handler",
"temperature",
Expand Down Expand Up @@ -104,6 +112,7 @@ def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[i
model_name_or_path=model_name_or_path,
model_max_length=model_max_length,
max_length=self.max_length or 100,
use_auth_token=use_auth_token,
)

def preprocess_prompt(self, prompt: str):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Support gated repos for Huggingface inference.

0 comments on commit 3c2a4fe

Please sign in to comment.