diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 71f656e854a38..682241983389c 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -3,7 +3,17 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, List, Optional +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Iterable, + List, + Literal, + Optional, + Union, +) from langchain.docstore.document import Document @@ -114,7 +124,11 @@ def _huggingface_tokenizer_length(text: str) -> int: @classmethod def from_tiktoken_encoder( - cls, encoding_name: str = "gpt2", **kwargs: Any + cls, + encoding_name: str = "gpt2", + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -125,11 +139,19 @@ def from_tiktoken_encoder( "This is needed in order to calculate max_tokens_for_prompt. " "Please it install it with `pip install tiktoken`." ) + # create a GPT-3 encoder instance enc = tiktoken.get_encoding(encoding_name) - def _tiktoken_encoder(text: str) -> int: - return len(enc.encode(text)) + def _tiktoken_encoder(text: str, **kwargs: Any) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + **kwargs, + ) + ) return cls(length_function=_tiktoken_encoder, **kwargs) @@ -169,10 +191,17 @@ def __init__(self, encoding_name: str = "gpt2", **kwargs: Any): # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) - def split_text(self, text: str) -> List[str]: + def split_text( + self, + text: str, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> List[str]: """Split incoming text and return chunks.""" splits = [] - input_ids = self._tokenizer.encode(text) + input_ids = self._tokenizer.encode( + text, allowed_special=allowed_special, disallowed_special=disallowed_special + ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx]