refactor: default for max_new_tokens to 512 in Hugging Face generators #7370

Merged
1 change: 1 addition & 0 deletions haystack/components/generators/chat/hugging_face_tgi.py
@@ -123,6 +123,7 @@ def __init__(
check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
+generation_kwargs.setdefault("max_new_tokens", 512)

self.model = model
self.url = url
1 change: 1 addition & 0 deletions haystack/components/generators/hugging_face_local.py
@@ -106,6 +106,7 @@ def __init__(
"Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
"Please specify only one of them."
)
+generation_kwargs.setdefault("max_new_tokens", 512)

self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs
1 change: 1 addition & 0 deletions haystack/components/generators/hugging_face_tgi.py
@@ -111,6 +111,7 @@ def __init__(
check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
+generation_kwargs.setdefault("max_new_tokens", 512)

self.model = model
self.url = url
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Set max_new_tokens default to 512 in Hugging Face generators.
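
For context on the one-line change repeated in the three generators above: dict.setdefault writes the default only when the key is absent, so a caller-supplied max_new_tokens is never overwritten. A minimal standalone sketch of that behavior (apply_default is an illustrative helper, not part of the PR):

def apply_default(generation_kwargs: dict) -> dict:
    # Mirrors the line added in each generator's __init__: the default is
    # written into the dict in place, and only when the key is missing.
    generation_kwargs.setdefault("max_new_tokens", 512)
    return generation_kwargs

assert apply_default({}) == {"max_new_tokens": 512}
assert apply_default({"max_new_tokens": 64}) == {"max_new_tokens": 64}  # caller's value wins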
28 changes: 15 additions & 13 deletions test/components/generators/chat/test_hugging_face_tgi.py
@@ -1,14 +1,12 @@
-from unittest.mock import patch, MagicMock, Mock

-from haystack.utils.auth import Secret
+from unittest.mock import MagicMock, Mock, patch

import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators.chat import HuggingFaceTGIChatGenerator

-from haystack.dataclasses import StreamingChunk, ChatMessage
+from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.utils.auth import Secret


@pytest.fixture
@@ -70,7 +68,11 @@ def test_initialize_with_valid_model_and_generation_parameters(
)
generator.warm_up()

-assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+assert generator.generation_kwargs == {
+    **generation_kwargs,
+    **{"stop_sequences": ["stop"]},
+    **{"max_new_tokens": 512},
+}
assert generator.tokenizer is not None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
@@ -92,7 +94,7 @@ def test_to_dict(self, mock_check_valid_model):
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIChatGenerator(
@@ -106,7 +108,7 @@ def test_from_dict(self, mock_check_valid_model):

generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
-assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler

def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
@@ -205,7 +207,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
# check kwargs passed to text_generation
# note how n, since it is not a text_generation parameter, was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -240,7 +242,7 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

# note how n caused n replies to be generated
assert isinstance(response, dict)
@@ -268,7 +270,7 @@ def test_generate_text_with_stop_words(
# check kwargs passed to text_generation
# we translate stop_words to stop_sequences
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

# Assert that the response contains the generated replies
assert "replies" in response
@@ -343,7 +345,7 @@ def mock_iter(self):

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

# Assert that the streaming callback was called twice
assert streaming_call_count == 2
@@ -4,10 +4,10 @@
import pytest
import torch
from transformers import PreTrainedTokenizerFast
-from haystack.utils.auth import Secret

from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
from haystack.utils import ComponentDevice
+from haystack.utils.auth import Secret


class TestHuggingFaceLocalGenerator:
@@ -23,7 +23,7 @@ def test_init_default(self, model_info_mock, monkeypatch):
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
-assert generator.generation_kwargs == {}
+assert generator.generation_kwargs == {"max_new_tokens": 512}
assert generator.pipeline is None

def test_init_custom_token(self):
@@ -124,7 +124,7 @@ def test_init_set_return_full_text(self):
"""
generator = HuggingFaceLocalGenerator(task="text-generation")

-assert generator.generation_kwargs == {"return_full_text": False}
+assert generator.generation_kwargs == {"max_new_tokens": 512, "return_full_text": False}

def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
with pytest.raises(
@@ -153,7 +153,7 @@ def test_to_dict_default(self, model_info_mock):
"task": "text2text-generation",
"device": ComponentDevice.resolve_device(None).to_hf(),
},
-"generation_kwargs": {},
+"generation_kwargs": {"max_new_tokens": 512},
"stop_words": None,
},
}
22 changes: 13 additions & 9 deletions test/components/generators/test_hugging_face_tgi.py
@@ -1,7 +1,7 @@
-from unittest.mock import patch, MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch

import pytest
-from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
+from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators import HuggingFaceTGIGenerator
@@ -63,7 +63,11 @@ def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_
)

assert generator.model == model
-assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+assert generator.generation_kwargs == {
+    **generation_kwargs,
+    **{"stop_sequences": ["stop"]},
+    **{"max_new_tokens": 512},
+}
assert generator.tokenizer is None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
@@ -84,7 +88,7 @@ def test_to_dict(self, mock_check_valid_model):
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIGenerator(
@@ -99,7 +103,7 @@ def test_from_dict(self, mock_check_valid_model):
# now deserialize, call from_dict
generator_2 = HuggingFaceTGIGenerator.from_dict(result)
assert generator_2.model == "mistralai/Mistral-7B-v0.1"
-assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler

def test_initialize_with_invalid_url(self, mock_check_valid_model):
@@ -135,7 +139,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
# check kwargs passed to text_generation
# note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -168,7 +172,7 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param
# check kwargs passed to text_generation
# note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -208,7 +212,7 @@ def test_generate_text_with_stop_words(

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

# Assert that the response contains the generated replies
assert "replies" in response
@@ -283,7 +287,7 @@ def mock_iter(self):

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

# Assert that the streaming callback was called twice
assert streaming_call_count == 2
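
A side note on the assertion style used in the updated tests: the expected dict is built by merging **-unpacked dicts, and in Python the right-most unpacking wins on key collisions, so appending **{"max_new_tokens": 512} extends the old expectation without disturbing the other keys. A quick standalone illustration (the kwarg values here are invented, not taken from the test fixtures):

generation_kwargs = {"n": 5, "temperature": 0.6}
expected = {**generation_kwargs, **{"stop_sequences": ["stop"]}, **{"max_new_tokens": 512}}
assert expected == {"n": 5, "temperature": 0.6, "stop_sequences": ["stop"], "max_new_tokens": 512}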