Skip to content

Commit

Permalink
Merge branch 'main' into text-router
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl authored Mar 15, 2024
2 parents 80f6734 + abda78c commit 89048f4
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}

@staticmethod
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
def get_embedding_backend(
model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
):
embedding_backend_id = f"{model}{device}{auth_token}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
embedding_backend = _SentenceTransformersEmbeddingBackend(
model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -30,10 +34,19 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""

def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
def __init__(
self,
model: str,
device: Optional[str] = None,
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
model_name_or_path=model,
device=device,
use_auth_token=auth_token.resolve_value() if auth_token else None,
trust_remote_code=trust_remote_code,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
normalize_embeddings: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersDocumentEmbedder component.
Expand All @@ -65,6 +66,9 @@ def __init__(
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param trust_remote_code:
If `False`, only Hugging Face verified model architectures are allowed.
If `True`, custom models and scripts are allowed.
"""

self.model = model
Expand All @@ -77,6 +81,7 @@ def __init__(
self.normalize_embeddings = normalize_embeddings
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.trust_remote_code = trust_remote_code

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -103,6 +108,7 @@ def to_dict(self) -> Dict[str, Any]:
normalize_embeddings=self.normalize_embeddings,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
trust_remote_code=self.trust_remote_code,
)

@classmethod
Expand All @@ -127,7 +133,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
batch_size: int = 32,
progress_bar: bool = True,
normalize_embeddings: bool = False,
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersTextEmbedder component.
Expand All @@ -59,6 +60,9 @@ def __init__(
If True shows a progress bar when running.
:param normalize_embeddings:
If True returned vectors will have length 1.
:param trust_remote_code:
If `False`, only Hugging Face verified model architectures are allowed.
If `True`, custom models and scripts are allowed.
"""

self.model = model
Expand All @@ -69,6 +73,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
self.trust_remote_code = trust_remote_code

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -93,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
trust_remote_code=self.trust_remote_code,
)

@classmethod
Expand All @@ -117,7 +123,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)

@component.output_types(embedding=List[float])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
"tqdm",
"tenacity",
"lazy-imports",
"openai==1.13.3", # unpin after fix for https://github.com/deepset-ai/haystack/issues/7358
"openai>=1.1.0",
"Jinja2",
"posthog", # telemetry
"pyyaml",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add trust_remote_code parameter to SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder for allowing custom models and scripts.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from unittest.mock import patch, MagicMock
import pytest
from unittest.mock import MagicMock, patch

import numpy as np
from haystack.utils import Secret, ComponentDevice
import pytest

from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersDocumentEmbedder:
Expand All @@ -20,6 +21,7 @@ def test_init_default(self):
assert embedder.normalize_embeddings is False
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.trust_remote_code is False

def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
Expand All @@ -33,6 +35,7 @@ def test_init_with_parameters(self):
normalize_embeddings=True,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
Expand All @@ -44,6 +47,7 @@ def test_init_with_parameters(self):
assert embedder.normalize_embeddings is True
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.trust_remote_code

def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
Expand All @@ -61,6 +65,7 @@ def test_to_dict(self):
"normalize_embeddings": False,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
"trust_remote_code": False,
},
}

Expand All @@ -76,6 +81,7 @@ def test_to_dict_with_custom_init_parameters(self):
normalize_embeddings=True,
meta_fields_to_embed=["meta_field"],
embedding_separator=" - ",
trust_remote_code=True,
)
data = component.to_dict()

Expand All @@ -91,6 +97,7 @@ def test_to_dict_with_custom_init_parameters(self):
"progress_bar": False,
"normalize_embeddings": True,
"embedding_separator": " - ",
"trust_remote_code": True,
"meta_fields_to_embed": ["meta_field"],
},
}
Expand All @@ -107,6 +114,7 @@ def test_from_dict(self):
"normalize_embeddings": True,
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
"trust_remote_code": True,
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
Expand All @@ -123,6 +131,7 @@ def test_from_dict(self):
assert component.progress_bar is False
assert component.normalize_embeddings is True
assert component.embedding_separator == " - "
assert component.trust_remote_code
assert component.meta_fields_to_embed == ["meta_field"]

@patch(
Expand All @@ -134,7 +143,9 @@ def test_warmup(self, mocked_factory):
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)

@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import patch

import pytest

from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
Expand All @@ -23,10 +25,10 @@ def test_factory_behavior(mock_sentence_transformer):
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from unittest.mock import patch, MagicMock
import pytest
from haystack.utils import Secret, ComponentDevice
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersTextEmbedder:
Expand All @@ -18,6 +18,7 @@ def test_init_default(self):
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.trust_remote_code is False

def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
Expand All @@ -29,6 +30,7 @@ def test_init_with_parameters(self):
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
Expand All @@ -38,6 +40,7 @@ def test_init_with_parameters(self):
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.trust_remote_code

def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
Expand All @@ -53,6 +56,7 @@ def test_to_dict(self):
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}

Expand All @@ -66,6 +70,7 @@ def test_to_dict_with_custom_init_parameters(self):
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
data = component.to_dict()
assert data == {
Expand All @@ -79,6 +84,7 @@ def test_to_dict_with_custom_init_parameters(self):
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"trust_remote_code": True,
},
}

Expand All @@ -99,6 +105,7 @@ def test_from_dict(self):
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
Expand All @@ -110,6 +117,7 @@ def test_from_dict(self):
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
Expand All @@ -118,7 +126,9 @@ def test_warmup(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
Expand Down
9 changes: 5 additions & 4 deletions test/components/generators/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from datetime import datetime
from typing import Iterator
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import pytest
from openai import Stream
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import ChoiceDelta, Choice
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta


@pytest.fixture
Expand Down Expand Up @@ -33,8 +33,9 @@ def mock_chat_completion_chunk():
"""

class MockStream(Stream[ChatCompletionChunk]):
def __init__(self, mock_chunk: ChatCompletionChunk, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
client = client or MagicMock()
super().__init__(client=client, *args, **kwargs)
self.mock_chunk = mock_chunk

def __stream__(self) -> Iterator[ChatCompletionChunk]:
Expand Down

0 comments on commit 89048f4

Please sign in to comment.