From 304f9958ec68950b29f26ff047db140b40d17abb Mon Sep 17 00:00:00 2001 From: Ian Johnson Date: Wed, 15 May 2024 09:48:38 -0400 Subject: [PATCH] add llama-3 instruct. add snowflake and bge-m3 embedding models. fix device usage for embeds. improve labeling prompt. upgrade tiktoken and other requirements --- latentscope/models/chat_models.json | 32 +- latentscope/models/embedding_models.json | 41 ++- latentscope/models/providers/openai.py | 7 +- latentscope/models/providers/transformers.py | 31 +- latentscope/scripts/label_clusters.py | 6 +- requirements.txt | 338 +++++++++---------- 6 files changed, 255 insertions(+), 200 deletions(-) diff --git a/latentscope/models/chat_models.json b/latentscope/models/chat_models.json index 0a1b2c5..dee8774 100644 --- a/latentscope/models/chat_models.json +++ b/latentscope/models/chat_models.json @@ -10,10 +10,10 @@ }, { "provider": "openai", - "name": "gpt-3.5-turbo-0125", - "id": "openai-gpt-3.5-turbo", + "name": "gpt-4o", + "id": "openai-gpt-4o", "params": { - "max_tokens": 4096 + "max_tokens": 128000 } }, { @@ -24,6 +24,14 @@ "max_tokens": 128000 } }, + { + "provider": "openai", + "name": "gpt-4-turbo", + "id": "openai-gpt-4-turbo", + "params": { + "max_tokens": 128000 + } + }, { "provider": "openai", "name": "gpt-4", @@ -58,21 +66,11 @@ }, { "provider": "transformers", - "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "sanitized_name": "TinyLlama___TinyLlama-1.1B-Chat-v1.0", - "id": "transformers-TinyLlama___TinyLlama-1.1B-Chat-v1.0", + "name": "meta-llama/Meta-Llama-3-8B-Instruct", + "sanitized_name": "meta-llama___Meta-Llama-3-8B-Instruct", + "id": "transformers-meta-llama___Meta-Llama-3-8B-Instruct", "params": { - "max_tokens": 2048 + "max_tokens": 8192 } - }, - { - "provider": "transformers", - "name": "HuggingFaceH4/zephyr-7b-beta", - "sanitized_name": "HuggingFaceH4___zephyr-7b-beta", - "id": "transformers-HuggingFaceH4___zephyr-7b-beta", - "params": { - "max_tokens": 4096 - } } - ] \ No newline at end of file diff --git a/latentscope/models/embedding_models.json b/latentscope/models/embedding_models.json index 45e89b6..160a7f4 100644 --- a/latentscope/models/embedding_models.json +++ b/latentscope/models/embedding_models.json @@ -55,6 +55,7 @@ "modality": "text", "params": { "max_tokens": 8192, + "rps": true, "truncation": true, "padding": true, "pooling": "mean" @@ -68,12 +69,39 @@ "modality": "text", "params": { "max_tokens": 8192, + "rps": true, "truncation": true, "padding": true, "pooling": "mean", "dimensions": [768, 512, 256, 128, 64] } }, + { + "provider": "transformers", + "name": "Snowflake/snowflake-arctic-embed-s", + "sanitized_name": "Snowflake___snowflake-arctic-embed-s", + "id": "transformers-Snowflake___snowflake-arctic-embed-s", + "modality": "text", + "params": { + "truncation": true, + "padding": true, + "pooling": "cls" + } + }, + { + "provider": "transformers", + "name": "Snowflake/snowflake-arctic-embed-m-long", + "sanitized_name": "Snowflake___snowflake-arctic-embed-m-long", + "id": "transformers-Snowflake___snowflake-arctic-embed-m-long", + "modality": "text", + "params": { + "max_tokens": 8192, + "rps": true, + "truncation": true, + "padding": true, + "pooling": "cls" + } + }, { "provider": "transformers", "name": "intfloat/e5-large-v2", @@ -97,7 +125,18 @@ "padding": true, "pooling": "mean" } -}, +},{ + "provider": "transformers", + "name": "BAAI/bge-m3", + "sanitized_name": "BAAI___bge-m3", + "id": "transformers-BAAI___bge-m3", + "modality": "text", + "params": { + "truncation": true, + "padding": true, + "pooling": 
"cls" + } + }, { "provider": "transformers", "name": "BAAI/bge-large-en-v1.5", diff --git a/latentscope/models/providers/openai.py b/latentscope/models/providers/openai.py index 4881789..75fc658 100644 --- a/latentscope/models/providers/openai.py +++ b/latentscope/models/providers/openai.py @@ -13,11 +13,7 @@ def load_model(self): print("ERROR: No API key found for OpenAI") print("Missing 'OPENAI_API_KEY' variable in:", f"{os.getcwd()}/.env") self.client = OpenAI(api_key=api_key) - # special case for the new embedding models - if self.name in ["text-embedding-3-small", "text-embedding-3-large"]: - self.encoder = tiktoken.get_encoding("cl100k_base") - else: - self.encoder = tiktoken.encoding_for_model(self.name) + self.encoder = tiktoken.encoding_for_model(self.name) def embed(self, inputs, dimensions=None): time.sleep(0.01) # TODO proper rate limiting @@ -46,6 +42,7 @@ def load_model(self): self.client = OpenAI(api_key=get_key("OPENAI_API_KEY")) self.encoder = tiktoken.encoding_for_model(self.name) + def chat(self, messages): response = self.client.chat.completions.create( model=self.name, diff --git a/latentscope/models/providers/transformers.py b/latentscope/models/providers/transformers.py index 13a6365..e075a95 100644 --- a/latentscope/models/providers/transformers.py +++ b/latentscope/models/providers/transformers.py @@ -5,6 +5,7 @@ def __init__(self, name, params): super().__init__(name, params) import torch self.torch = torch + self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") def cls_pooling(self, model_output): return model_output[0][:, 0] @@ -19,17 +20,26 @@ def mean_pooling(self, model_output, attention_mask): return self.torch.sum(token_embeddings * input_mask_expanded, 1) / self.torch.clamp(input_mask_expanded.sum(1), min=1e-9) def load_model(self): - from transformers import AutoTokenizer, AutoModel, pipeline + from transformers import AutoTokenizer, AutoModel + + if "rps" in self.params and self.params["rps"]: + self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True, safe_serialization=True, rotary_scaling_factor=2 ) + else: + self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True) + + print("CONFIG", self.model.config) + if self.name == "nomic-ai/nomic-embed-text-v1" or self.name == "nomic-ai/nomic-embed-text-v1.5": self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=self.params["max_tokens"]) - self.model = AutoModel.from_pretrained("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, rotary_scaling_factor=2 ) else: self.tokenizer = AutoTokenizer.from_pretrained(self.name) - self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True) + + self.model.to(self.device) self.model.eval() def embed(self, inputs, dimensions=None): encoded_input = self.tokenizer(inputs, padding=self.params["padding"], truncation=self.params["truncation"], return_tensors='pt') + encoded_input = {key: value.to(self.device) for key, value in encoded_input.items()} pool = self.params["pooling"] # Compute token embeddings with self.torch.no_grad(): @@ -52,11 +62,18 @@ def embed(self, inputs, dimensions=None): class TransformersChatProvider(ChatModelProvider): + def __init__(self, name, params): + super().__init__(name, params) + import torch + from transformers import pipeline + self.torch = torch + self.pipeline = pipeline + def load_model(self): # self.pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", 
torch_dtype=torch.bfloat16, device_map="auto")
         # self.pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="cpu")
         # TODO: support bfloat16 for non mac environments
-        self.pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
+        self.pipe = self.pipeline("text-generation", model=self.name, torch_dtype=self.torch.float16, device_map="auto")
         self.encoder = self.pipe.tokenizer

     def chat(self, messages, max_new_tokens=24):
@@ -64,4 +81,8 @@ def chat(self, messages, max_new_tokens=24):
         outputs = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
         generated_text = outputs[0]["generated_text"]
         print("GENERATED TEXT", generated_text)
-        return generated_text.split("<|assistant|>")[1].strip()
\ No newline at end of file
+        if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
+            generated_text = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")[1].strip()
+        elif "<|assistant|>" in generated_text:
+            generated_text = generated_text.split("<|assistant|>")[1].strip()
+        return generated_text
\ No newline at end of file
diff --git a/latentscope/scripts/label_clusters.py b/latentscope/scripts/label_clusters.py
index ba24064..b949706 100644
--- a/latentscope/scripts/label_clusters.py
+++ b/latentscope/scripts/label_clusters.py
@@ -87,10 +87,10 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
     model.load_model()
     enc = model.encoder

-    system_prompt = {"role":"system", "content": f"""You're job is to summarize lists of items with a short label of no more than 4 words.
+    system_prompt = {"role":"system", "content": f"""Your job is to summarize lists of items with a short label of no more than 4 words.
The items are part of a cluster and the label will be used to distinguish this cluster from others, so pay attention to what makes this group of similar items distinct. {context} The user will submit a bulleted list of items and you should choose a label that best summarizes the theme of the list so that someone browsing the labels will have a good idea of what is in the list.
-Do not use punctuation, just return a few words that summarize the list."""}
+Do not use punctuation and do not explain yourself; respond with only a few words that summarize the list."""}

     # TODO: why the extra 10 for openai? 
max_tokens = model.params["max_tokens"] - len(enc.encode(system_prompt["content"])) - 10 @@ -129,7 +129,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id=" try: time.sleep(0.01) messages=[ - system_prompt, {"role":"user", "content": batch[0]} # TODO hardcoded batch size + system_prompt, {"role":"user", "content": "Here is a list of items, please summarize the list into a label using only a few words:\n" + batch[0]} # TODO hardcoded batch size ] label = model.chat(messages) labels.append(label) diff --git a/requirements.txt b/requirements.txt index 8ddaf0c..ba3b738 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,169 +1,169 @@ -accelerate==0.26.1 -aiohttp==3.9.1 -aiolimiter==1.1.0 -aiosignal==1.3.1 -annotated-types==0.6.0 -anyio==4.2.0 -appnope==0.1.3 -argon2-cffi==23.1.0 -argon2-cffi-bindings==21.2.0 -arrow==1.3.0 -asttokens==2.4.1 -async-lru==2.0.4 -attrs==23.2.0 -Babel==2.14.0 -backoff==2.2.1 -beautifulsoup4==4.12.3 -bleach==6.1.0 -blinker==1.7.0 -certifi==2023.11.17 -cffi==1.16.0 -charset-normalizer==3.3.2 -click==8.1.7 -cohere==4.44 -comm==0.2.1 -contourpy==1.2.0 -cycler==0.12.1 -Cython==0.29.37 -debugpy==1.8.0 -decorator==5.1.1 -defusedxml==0.7.1 -distro==1.9.0 -einops==0.7.0 -executing==2.0.1 -fastavro==1.9.3 -fastjsonschema==2.19.1 -filelock==3.13.1 -Flask==3.0.0 -Flask-Cors==4.0.0 -fonttools==4.47.2 -fqdn==1.5.1 -frozenlist==1.4.1 -fsspec==2023.10.0 -h11==0.14.0 -h5py==3.10.0 -hdbscan==0.8.33 -httpcore==1.0.2 -httpx==0.25.2 -huggingface-hub==0.20.2 -idna==3.6 -importlib-metadata==6.11.0 -ipykernel==6.29.0 -ipython==8.20.0 -ipywidgets==8.1.1 -isoduration==20.11.0 -itsdangerous==2.1.2 -jedi==0.19.1 -Jinja2==3.1.3 -joblib==1.3.2 -json5==0.9.14 -jsonpointer==2.4 -jsonschema==4.21.1 -jsonschema-specifications==2023.12.1 -jupyter==1.0.0 -jupyter-console==6.6.3 -jupyter-events==0.9.0 -jupyter-lsp==2.2.2 -jupyter_client==8.6.0 -jupyter_core==5.7.1 -jupyter_server==2.12.5 -jupyter_server_terminals==0.5.2 -jupyterlab==4.0.12 -jupyterlab-widgets==3.0.9 -jupyterlab_pygments==0.3.0 -jupyterlab_server==2.25.2 -kiwisolver==1.4.5 -llvmlite==0.41.1 -MarkupSafe==2.1.3 -matplotlib==3.8.2 -matplotlib-inline==0.1.6 -mistralai==0.0.11 -mistune==3.0.2 -mpmath==1.3.0 -multidict==6.0.4 -nbclient==0.9.0 -nbconvert==7.14.2 -nbformat==5.9.2 -nest-asyncio==1.5.9 -networkx==3.2.1 -nltk==3.8.1 -notebook==7.0.7 -notebook_shim==0.2.3 -numba==0.58.1 -numpy==1.26.3 -openai==1.12.0 -opt-einsum==3.3.0 -orjson==3.9.12 -overrides==7.7.0 -packaging==23.2 -pandas==2.1.4 -pandocfilters==1.5.1 -parso==0.8.3 -pexpect==4.9.0 -pillow==10.2.0 -platformdirs==4.1.0 -prometheus-client==0.19.0 -prompt-toolkit==3.0.43 -psutil==5.9.7 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pyarrow==14.0.2 -pycparser==2.21 -pydantic==2.5.3 -pydantic_core==2.14.6 -Pygments==2.17.2 -pynndescent==0.5.11 -pyparsing==3.1.1 -python-dateutil==2.8.2 -python-dotenv==1.0.0 -python-json-logger==2.0.7 -pytz==2023.3.post1 -PyYAML==6.0.1 -pyzmq==25.1.2 -qtconsole==5.5.1 -QtPy==2.4.1 -referencing==0.33.0 -regex==2023.12.25 -requests==2.31.0 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -rpds-py==0.17.1 -safetensors==0.4.1 -scikit-learn==1.3.2 -scipy==1.11.4 -Send2Trash==1.8.2 -six==1.16.0 -sniffio==1.3.0 -soupsieve==2.5 -sseclient-py==1.8.0 -stack-data==0.6.3 -sympy==1.12 -tabulate==0.9.0 -tenacity==8.2.3 -terminado==0.18.0 -threadpoolctl==3.2.0 -tiktoken==0.5.2 -tinycss2==1.2.1 -together==0.2.10 -tokenizers==0.15.0 -torch==2.1.2 -tornado==6.4 -tqdm==4.66.1 -traitlets==5.14.1 -transformers==4.36.2 -typer==0.9.0 
-types-python-dateutil==2.8.19.20240106 -typing_extensions==4.9.0 -tzdata==2023.4 -umap-learn==0.5.5 -uri-template==1.3.0 -urllib3==2.1.0 -voyageai==0.1.6 -wcwidth==0.2.13 -webcolors==1.13 -webencodings==0.5.1 -websocket-client==1.7.0 -Werkzeug==3.0.1 -widgetsnbextension==4.0.9 -yarl==1.9.4 -zipp==3.17.0 +accelerate~=0.26.1 +aiohttp~=3.9.1 +aiolimiter~=1.1.0 +aiosignal~=1.3.1 +annotated-types~=0.6.0 +anyio~=4.2.0 +appnope~=0.1.3 +argon2-cffi~=23.1.0 +argon2-cffi-bindings~=21.2.0 +arrow~=1.3.0 +asttokens~=2.4.1 +async-lru~=2.0.4 +attrs~=23.2.0 +Babel~=2.14.0 +backoff~=2.2.1 +beautifulsoup4~=4.12.3 +bleach~=6.1.0 +blinker~=1.7.0 +certifi~=2023.11.17 +cffi~=1.16.0 +charset-normalizer~=3.3.2 +click~=8.1.7 +cohere~=4.44 +comm~=0.2.1 +contourpy~=1.2.0 +cycler~=0.12.1 +Cython~=0.29.37 +debugpy~=1.8.0 +decorator~=5.1.1 +defusedxml~=0.7.1 +distro~=1.9.0 +einops~=0.7.0 +executing~=2.0.1 +fastavro~=1.9.3 +fastjsonschema~=2.19.1 +filelock~=3.13.1 +Flask~=3.0.0 +Flask-Cors~=4.0.0 +fonttools~=4.47.2 +fqdn~=1.5.1 +frozenlist~=1.4.1 +fsspec~=2023.10.0 +h11~=0.14.0 +h5py~=3.10.0 +hdbscan~=0.8.33 +httpcore~=1.0.2 +httpx~=0.25.2 +huggingface-hub~=0.20.2 +idna~=3.6 +importlib-metadata~=6.11.0 +ipykernel~=6.29.0 +ipython~=8.20.0 +ipywidgets~=8.1.1 +isoduration~=20.11.0 +itsdangerous~=2.1.2 +jedi~=0.19.1 +Jinja2~=3.1.3 +joblib~=1.3.2 +json5~=0.9.14 +jsonpointer~=2.4 +jsonschema~=4.21.1 +jsonschema-specifications~=2023.12.1 +jupyter~=1.0.0 +jupyter-console~=6.6.3 +jupyter-events~=0.9.0 +jupyter-lsp~=2.2.2 +jupyter_client~=8.6.0 +jupyter_core~=5.7.1 +jupyter_server~=2.12.5 +jupyter_server_terminals~=0.5.2 +jupyterlab~=4.0.12 +jupyterlab-widgets~=3.0.9 +jupyterlab_pygments~=0.3.0 +jupyterlab_server~=2.25.2 +kiwisolver~=1.4.5 +llvmlite~=0.41.1 +MarkupSafe~=2.1.3 +matplotlib~=3.8.2 +matplotlib-inline~=0.1.6 +mistralai~=0.0.11 +mistune~=3.0.2 +mpmath~=1.3.0 +multidict~=6.0.4 +nbclient~=0.9.0 +nbconvert~=7.14.2 +nbformat~=5.9.2 +nest-asyncio~=1.5.9 +networkx~=3.2.1 +nltk~=3.8.1 +notebook~=7.0.7 +notebook_shim~=0.2.3 +numba~=0.58.1 +numpy~=1.26.3 +openai~=1.12.0 +opt-einsum~=3.3.0 +orjson~=3.9.12 +overrides~=7.7.0 +packaging~=23.2 +pandas~=2.1.4 +pandocfilters~=1.5.1 +parso~=0.8.3 +pexpect~=4.9.0 +pillow~=10.2.0 +platformdirs~=4.1.0 +prometheus-client~=0.19.0 +prompt-toolkit~=3.0.43 +psutil~=5.9.7 +ptyprocess~=0.7.0 +pure-eval~=0.2.2 +pyarrow~=14.0.2 +pycparser~=2.21 +pydantic~=2.5.3 +pydantic_core~=2.14.6 +Pygments~=2.17.2 +pynndescent~=0.5.11 +pyparsing~=3.1.1 +python-dateutil~=2.8.2 +python-dotenv~=1.0.0 +python-json-logger~=2.0.7 +pytz~=2023.3.post1 +PyYAML~=6.0.1 +pyzmq~=25.1.2 +qtconsole~=5.5.1 +QtPy~=2.4.1 +referencing~=0.33.0 +regex~=2023.12.25 +requests~=2.31.0 +rfc3339-validator~=0.1.4 +rfc3986-validator~=0.1.1 +rpds-py~=0.17.1 +safetensors~=0.4.1 +scikit-learn~=1.3.2 +scipy~=1.11.4 +Send2Trash~=1.8.2 +six~=1.16.0 +sniffio~=1.3.0 +soupsieve~=2.5 +sseclient-py~=1.8.0 +stack-data~=0.6.3 +sympy~=1.12 +tabulate~=0.9.0 +tenacity~=8.2.3 +terminado~=0.18.0 +threadpoolctl~=3.2.0 +tiktoken~=0.7.0 +tinycss2~=1.2.1 +together~=0.2.10 +tokenizers~=0.15.0 +torch~=2.1.2 +tornado~=6.4 +tqdm~=4.66.1 +traitlets~=5.14.1 +transformers~=4.36.2 +typer~=0.9.0 +types-python-dateutil~=2.8.19.20240106 +typing_extensions~=4.9.0 +tzdata~=2023.4 +umap-learn~=0.5.5 +uri-template~=1.3.0 +urllib3~=2.1.0 +voyageai~=0.1.6 +wcwidth~=0.2.13 +webcolors~=1.13 +webencodings~=0.5.1 +websocket-client~=1.7.0 +Werkzeug~=3.0.1 +widgetsnbextension~=4.0.9 +yarl~=1.9.4 +zipp~=3.17.0
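
A note on the openai.py simplification above: the special-case branch for the text-embedding-3 models can be dropped because recent tiktoken releases (the requirements pin moves to tiktoken~=0.7.0) resolve those model names directly. A minimal sketch of the simplified lookup; the model name is just an example:

import tiktoken

# Newer tiktoken releases map the text-embedding-3 models in
# encoding_for_model, so the manual cl100k_base fallback is no longer needed.
enc = tiktoken.encoding_for_model("text-embedding-3-small")
print(len(enc.encode("hello world")))  # token count, as used for the prompt budget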
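The heart of the "fix device usage for embeds" change is keeping the model and its tokenized inputs on the same device. A minimal sketch of the pattern TransformersEmbedProvider now follows, assuming torch and transformers are installed; the bge-m3 model is used purely as an illustration:

import torch
from transformers import AutoModel, AutoTokenizer

# Same fallback order as the patch: CUDA, then Apple's MPS backend, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
model = AutoModel.from_pretrained("BAAI/bge-m3", trust_remote_code=True)
model.to(device)
model.eval()

encoded = tokenizer(["some text"], padding=True, truncation=True, return_tensors="pt")
# Move every input tensor to the model's device; leaving them on CPU while the
# model sits on CUDA/MPS raises a device-mismatch error -- the bug being fixed.
encoded = {k: v.to(device) for k, v in encoded.items()}

with torch.no_grad():
    output = model(**encoded)

# "cls" pooling, as configured for the new Snowflake and bge-m3 entries:
# take the first token of the last hidden state.
embedding = output[0][:, 0]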
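The new "rps" flag in embedding_models.json appears to mark long-context models whose remote code accepts a rotary_scaling_factor argument (presumably "rotary position scaling"; the patch does not expand the name). A sketch of the conditional load, mirroring load_model; the model name is one of the flagged entries:

from transformers import AutoModel

params = {"rps": True, "max_tokens": 8192}  # as declared in embedding_models.json

if params.get("rps"):
    # rotary_scaling_factor is consumed by the model's remote code
    # (hence trust_remote_code=True), not by transformers itself.
    model = AutoModel.from_pretrained(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
        rotary_scaling_factor=2,
    )
else:
    model = AutoModel.from_pretrained(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    )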
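The updated chat() return path supports both chat-template conventions: Llama-3 wraps replies in header tokens, while zephyr/TinyLlama-style templates use <|assistant|>. A small usage sketch of that extraction logic on a hypothetical Llama-3-style generation:

def extract_reply(generated_text: str) -> str:
    # Llama-3's template emits <|start_header_id|>assistant<|end_header_id|>
    # before the reply; older templates emit <|assistant|>.
    llama3_marker = "<|start_header_id|>assistant<|end_header_id|>"
    if llama3_marker in generated_text:
        return generated_text.split(llama3_marker)[1].strip()
    if "<|assistant|>" in generated_text:
        return generated_text.split("<|assistant|>")[1].strip()
    return generated_text  # no marker found: pass the text through unchanged

# Hypothetical Llama-3-style output:
text = "<|start_header_id|>user<|end_header_id|>\n- apples\n- pears<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nFruit varieties"
print(extract_reply(text))  # -> "Fruit varieties"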
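Finally, requirements.txt swaps every exact == pin for the PEP 440 compatible-release operator: name~=X.Y.Z means >=X.Y.Z combined with ==X.Y.*, so patch releases are accepted but minor and major bumps are not. A quick check using the packaging library, which is already in the list above:

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=0.7.0")  # compatible release: >=0.7.0, ==0.7.*
print("0.7.3" in spec)  # True  -- patch releases stay in range
print("0.8.0" in spec)  # False -- minor bumps are excluded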