From 4b1598658f6dfc93bb162bb50c1ac0e56badc818 Mon Sep 17 00:00:00 2001 From: jamescalam Date: Tue, 7 Feb 2023 12:14:32 +0400 Subject: [PATCH 1/6] support for special token params --- langchain/text_splitter.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 71f656e854a38..3a8550f4601ef 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -17,7 +17,7 @@ def __init__( self, chunk_size: int = 4000, chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, + length_function: Callable[[str], int] = len ): """Create a new TextSplitter.""" if chunk_overlap > chunk_size: @@ -114,7 +114,7 @@ def _huggingface_tokenizer_length(text: str) -> int: @classmethod def from_tiktoken_encoder( - cls, encoding_name: str = "gpt2", **kwargs: Any + cls, encoding_name: str = "gpt2", allowed_special=set(), disallowed_special="all", **kwargs: Any ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -125,11 +125,12 @@ def from_tiktoken_encoder( "This is needed in order to calculate max_tokens_for_prompt. " "Please it install it with `pip install tiktoken`." ) + # create a GPT-3 encoder instance enc = tiktoken.get_encoding(encoding_name) - def _tiktoken_encoder(text: str) -> int: - return len(enc.encode(text)) + def _tiktoken_encoder(text: str, **kwargs) -> int: + return len(enc.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special, **kwargs)) return cls(length_function=_tiktoken_encoder, **kwargs) @@ -169,10 +170,14 @@ def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) - def split_text(self, text: str) -> List[str]: + def split_text(self, text: str, allowed_special=set(), disallowed_special="all") -> List[str]: """Split incoming text and return chunks.""" splits = [] - input_ids = self._tokenizer.encode(text) + input_ids = self._tokenizer.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special + ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] From 354f6bbe1135673324390171e3747e8824ef58d0 Mon Sep 17 00:00:00 2001 From: jamescalam Date: Tue, 7 Feb 2023 12:21:14 +0400 Subject: [PATCH 2/6] black reformat --- langchain/text_splitter.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 3a8550f4601ef..e4140e56accd8 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -17,7 +17,7 @@ def __init__( self, chunk_size: int = 4000, chunk_overlap: int = 200, - length_function: Callable[[str], int] = len + length_function: Callable[[str], int] = len, ): """Create a new TextSplitter.""" if chunk_overlap > chunk_size: @@ -114,7 +114,11 @@ def _huggingface_tokenizer_length(text: str) -> int: @classmethod def from_tiktoken_encoder( - cls, encoding_name: str = "gpt2", allowed_special=set(), disallowed_special="all", **kwargs: Any + cls, + encoding_name: str = "gpt2", + allowed_special=set(), + disallowed_special="all", + **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -130,7 +134,14 @@ def from_tiktoken_encoder( enc = tiktoken.get_encoding(encoding_name) def _tiktoken_encoder(text: str, **kwargs) -> int: - return len(enc.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special, **kwargs)) + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + **kwargs, + ) + ) return cls(length_function=_tiktoken_encoder, **kwargs) @@ -170,13 +181,13 @@ def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) - def split_text(self, text: str, allowed_special=set(), disallowed_special="all") -> List[str]: + def split_text( + self, text: str, allowed_special=set(), disallowed_special="all" + ) -> List[str]: """Split incoming text and return chunks.""" splits = [] input_ids = self._tokenizer.encode( - text, - allowed_special=allowed_special, - disallowed_special=disallowed_special + text, allowed_special=allowed_special, disallowed_special=disallowed_special ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids)) From 1941d9fc2650f431c549d3db5d4a5fe7fae019d4 Mon Sep 17 00:00:00 2001 From: jamescalam Date: Tue, 7 Feb 2023 12:27:22 +0400 Subject: [PATCH 3/6] type annotations and reformat --- langchain/text_splitter.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index e4140e56accd8..74a80daa6aa43 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -3,7 +3,18 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, List, Optional +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + AbstractSet, + Collection, + Literal, + Optional, + Union, +) from langchain.docstore.document import Document @@ -116,8 +127,8 @@ def _huggingface_tokenizer_length(text: str) -> int: def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", - allowed_special=set(), - disallowed_special="all", + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" @@ -133,7 +144,7 @@ def from_tiktoken_encoder( # create a GPT-3 encoder instance enc = tiktoken.get_encoding(encoding_name) - def _tiktoken_encoder(text: str, **kwargs) -> int: + def _tiktoken_encoder(text: str, **kwargs: Any) -> int: return len( enc.encode( text, @@ -182,7 +193,10 @@ def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): self._tokenizer = tiktoken.get_encoding(encoding_name) def split_text( - self, text: str, allowed_special=set(), disallowed_special="all" + self, + text: str, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", ) -> List[str]: """Split incoming text and return chunks.""" splits = [] From 0e3477ba34be51caa7be21d1d1ebbbf949e96809 Mon Sep 17 00:00:00 2001 From: jamescalam Date: Tue, 7 Feb 2023 12:49:16 +0400 Subject: [PATCH 4/6] reformat imports --- langchain/text_splitter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 74a80daa6aa43..682241983389c 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -4,13 +4,12 @@ import logging from abc import ABC, abstractmethod from typing import ( + AbstractSet, Any, Callable, + Collection, Iterable, List, - Optional, - AbstractSet, - Collection, Literal, Optional, Union, From f99472497f85488bd7a8f382bf782e6150ee8fe8 Mon Sep 17 00:00:00 2001 From: jamescalam Date: Fri, 10 Feb 2023 10:28:32 +0400 Subject: [PATCH 5/6] switched to add special token params in __init__ --- langchain/text_splitter.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 682241983389c..bd81b8c79e846 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -177,7 +177,13 @@ def split_text(self, text: str) -> List[str]: class TokenTextSplitter(TextSplitter): """Implementation of splitting text that looks at tokens.""" - def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): + def __init__( + self, + encoding_name: str = "gpt2", + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ): """Create a new TextSplitter.""" super().__init__(**kwargs) try: @@ -188,19 +194,22 @@ def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): "This is needed in order to for TokenTextSplitter. " "Please it install it with `pip install tiktoken`." ) + # get special token settings + self.allowed_special = allowed_special + self.disallowed_special = disallowed_special # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) def split_text( self, text: str, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", ) -> List[str]: """Split incoming text and return chunks.""" splits = [] input_ids = self._tokenizer.encode( - text, allowed_special=allowed_special, disallowed_special=disallowed_special + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids)) From 42a7143a2bf536f8796c1fbd411198b4cf5e01c9 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:03:32 +0400 Subject: [PATCH 6/6] Revert "switched to add special token params in __init__" This reverts commit f99472497f85488bd7a8f382bf782e6150ee8fe8. --- langchain/text_splitter.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index bd81b8c79e846..682241983389c 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -177,13 +177,7 @@ def split_text(self, text: str) -> List[str]: class TokenTextSplitter(TextSplitter): """Implementation of splitting text that looks at tokens.""" - def __init__( - self, - encoding_name: str = "gpt2", - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, - ): + def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) try: @@ -194,22 +188,19 @@ def __init__( "This is needed in order to for TokenTextSplitter. " "Please it install it with `pip install tiktoken`." ) - # get special token settings - self.allowed_special = allowed_special - self.disallowed_special = disallowed_special # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) def split_text( self, text: str, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", ) -> List[str]: """Split incoming text and return chunks.""" splits = [] input_ids = self._tokenizer.encode( - text, - allowed_special=self.allowed_special, - disallowed_special=self.disallowed_special, + text, allowed_special=allowed_special, disallowed_special=disallowed_special ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids))