Add LM Studio Support #1538

Status: Draft. Wants to merge 6 commits into main.
6 changes: 6 additions & 0 deletions py/cli/commands/server.py
@@ -280,6 +280,12 @@ async def serve(
             "Warning: `config-name` corresponds to an existing file. If you intended a custom config, use `config-path`."
         )
 
+    if config_name in ["local_llm", "full_local_llm"]:
+        click.secho(
+            "WARNING: `local_llm` and `full_local_llm` are deprecated and support for them will be removed. Please use `local_llm_ollama`, `local_llm_lm_studio`, `full_local_llm_ollama`, or `full_local_llm_lm_studio` instead.",
+            fg="red",
+        )
+
     if build:
         click.echo(
             "`build` flag detected. Building Docker image from local repository..."
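For reference, the renames behind this warning, expressed as a hypothetical lookup table (illustrative only; the PR emits a warning but does not remap names automatically):

```python
# Hypothetical helper (not in the PR): the replacement matrix implied by the
# deprecation warning. Which successor applies depends on the local backend.
DEPRECATED_CONFIG_REPLACEMENTS = {
    "local_llm": ("local_llm_ollama", "local_llm_lm_studio"),
    "full_local_llm": ("full_local_llm_ollama", "full_local_llm_lm_studio"),
}

for old, (ollama_name, lm_studio_name) in DEPRECATED_CONFIG_REPLACEMENTS.items():
    print(f"{old} -> {ollama_name} (Ollama) or {lm_studio_name} (LM Studio)")
```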
10 changes: 7 additions & 3 deletions py/compose.full.yaml
@@ -270,7 +270,8 @@ services:
       retries: 5
 
   r2r:
-    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    # image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    image: r2r/test
     build:
       context: .
       args:
@@ -333,6 +334,9 @@ services:
       # Ollama
       - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
 
+      # LM Studio
+      - LM_STUDIO_API_BASE=${LM_STUDIO_API_BASE:-http://host.docker.internal:1234}
+
       # Unstructured
       - UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
       - UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
@@ -352,6 +356,8 @@
       '
     networks:
       - r2r-network
+    extra_hosts:
+      - host.docker.internal:host-gateway
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v2/health"]
      interval: 6s
@@ -361,8 +367,6 @@
     volumes:
       - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
       - hatchet_api_key:/hatchet_api_key:ro
-    extra_hosts:
-      - host.docker.internal:host-gateway
     depends_on:
       setup-token:
         condition: service_completed_successfully
6 changes: 5 additions & 1 deletion py/compose.yaml
@@ -36,7 +36,8 @@ services:
       -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
 
   r2r:
-    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    # image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    image: r2r/test
     build:
       context: .
       args:
@@ -99,6 +100,9 @@ services:
       # Ollama
       - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
 
+      # LM Studio
+      - LM_STUDIO_API_BASE=${LM_STUDIO_API_BASE:-http://host.docker.internal:1234}
+
     networks:
       - r2r-network
     healthcheck:
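With `LM_STUDIO_API_BASE` exported as above, LiteLLM routes `lm_studio/`-prefixed model names to LM Studio's OpenAI-compatible server. A minimal sketch of such a call, assuming a LiteLLM release that ships the `lm_studio` provider and a model already loaded in LM Studio:

```python
import os

import litellm

# Assumption: LiteLLM's lm_studio provider reads this variable to locate the
# OpenAI-compatible endpoint LM Studio serves (port 1234 by default).
os.environ.setdefault("LM_STUDIO_API_BASE", "http://localhost:1234")

response = litellm.completion(
    model="lm_studio/llama-3.2-3b-instruct",  # name as listed by LM Studio
    messages=[{"role": "user", "content": "Say hello from R2R."}],
)
print(response.choices[0].message.content)
```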
69 changes: 69 additions & 0 deletions py/core/configs/full_local_llm_lm_studio.toml
@@ -0,0 +1,69 @@
[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "lm_studio/llama-3.2-3b-instruct"

[completion]
provider = "litellm"
concurrent_request_limit = 1

[completion.generation_config]
model = "lm_studio/llama-3.2-3b-instruct"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[database]
provider = "postgres"

[database.kg_creation_settings]
kg_entity_description_prompt = "graphrag_entity_description"
kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
entity_types = [] # if empty, all entities are extracted
relation_types = [] # if empty, all relations are extracted
fragment_merge_count = 4 # number of fragments to merge into a single extraction
max_knowledge_triples = 100
max_description_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for triplet extraction

[database.kg_entity_deduplication_settings]
kg_entity_deduplication_type = "by_name"
kg_entity_deduplication_prompt = "graphrag_entity_deduplication"
max_description_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for deduplication

[database.kg_enrichment_settings]
community_reports_prompt = "graphrag_community_reports"
max_summary_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for node description and graph clustering
leiden_params = {}

[database.kg_search_settings]
entities_level = "document" # set to collection if you've run deduplication
map_system_prompt = "graphrag_map_system"
reduce_system_prompt = "graphrag_reduce_system"
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" }

[embedding]
provider = "litellm"
base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5-embedding"
base_dimension = 768
batch_size = 128
add_title_as_prefix = true
concurrent_request_limit = 2

[ingestion]
provider = "unstructured_local"
strategy = "auto"
chunking_strategy = "by_title"
new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 20

[orchestration]
provider = "hatchet"
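The `[embedding]` block above also goes through LiteLLM. Roughly the call R2R's embedding provider ends up making, sketched under the same lm_studio-provider assumption; `dimensions` mirrors `base_dimension`, and backends that reject the parameter rely on `litellm.drop_params`:

```python
import litellm

litellm.drop_params = True  # drop params the backend does not support

resp = litellm.embedding(
    model="lm_studio/text-embedding-nomic-embed-text-v1.5-embedding",
    input=["first chunk", "second chunk"],
    dimensions=768,  # must agree with base_dimension above
)
vectors = [item["embedding"] for item in resp.data]
print(len(vectors), len(vectors[0]))
```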
71 changes: 71 additions & 0 deletions py/core/configs/full_local_llm_ollama.toml
@@ -0,0 +1,71 @@
[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "ollama/llama3.1"

[completion]
provider = "litellm"
concurrent_request_limit = 1

[completion.generation_config]
model = "ollama/llama3.1"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }


[database]
provider = "postgres"

[database.kg_creation_settings]
kg_entity_description_prompt = "graphrag_entity_description"
kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
entity_types = [] # if empty, all entities are extracted
relation_types = [] # if empty, all relations are extracted
fragment_merge_count = 4 # number of fragments to merge into a single extraction
max_knowledge_triples = 100
max_description_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for triplet extraction

[database.kg_entity_deduplication_settings]
kg_entity_deduplication_type = "by_name"
kg_entity_deduplication_prompt = "graphrag_entity_deduplication"
max_description_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication

[database.kg_enrichment_settings]
community_reports_prompt = "graphrag_community_reports"
max_summary_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for node description and graph clustering
leiden_params = {}

[database.kg_search_settings]
entities_level = "document" # set to collection if you've run deduplication
map_system_prompt = "graphrag_map_system"
reduce_system_prompt = "graphrag_reduce_system"
generation_config = { model = "ollama/llama3.1" }


[embedding]
provider = "ollama"
base_model = "mxbai-embed-large"
base_dimension = 1_024
batch_size = 128
add_title_as_prefix = true
concurrent_request_limit = 2

[ingestion]
provider = "unstructured_local"
strategy = "auto"
chunking_strategy = "by_title"
new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 20

[orchestration]
provider = "hatchet"
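A quick connectivity check for the Ollama variant (a sketch, not R2R code; `api_base` is passed explicitly so the example does not depend on LiteLLM picking up `OLLAMA_API_BASE` on its own):

```python
import os

import litellm

api_base = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")

response = litellm.completion(
    model="ollama/llama3.1",
    api_base=api_base,  # where the Ollama server listens
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```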
8 changes: 0 additions & 8 deletions py/core/configs/local_llm.toml
@@ -59,11 +59,3 @@ provider = "postgres"

 [orchestration]
 provider = "simple"
-
-
-[ingestion]
-vision_img_model = "ollama/llama3.2-vision"
-vision_pdf_model = "ollama/llama3.2-vision"
-
-[ingestion.extra_parsers]
-pdf = "zerox"
62 changes: 62 additions & 0 deletions py/core/configs/local_llm_lm_studio.toml
@@ -0,0 +1,62 @@
[agent]
provider = "litellm"
system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "lm_studio/llama-3.2-3b-instruct"

[completion]
provider = "litellm"
concurrent_request_limit = 1

[completion.generation_config]
model = "lm_studio/llama-3.2-3b-instruct"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[embedding]
provider = "litellm"
base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5-embedding"
base_dimension = 768 # nomic-embed-text-v1.5 is 768-dimensional; matches full_local_llm_lm_studio.toml
batch_size = 128
add_title_as_prefix = true
concurrent_request_limit = 2

[database]
provider = "postgres"

[database.kg_creation_settings]
kg_entity_description_prompt = "graphrag_entity_description"
kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
entity_types = [] # if empty, all entities are extracted
relation_types = [] # if empty, all relations are extracted
fragment_merge_count = 4 # number of fragments to merge into a single extraction
max_knowledge_triples = 100
max_description_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for triplet extraction

[database.kg_entity_deduplication_settings]
kg_entity_deduplication_type = "by_name"
kg_entity_deduplication_prompt = "graphrag_entity_deduplication"
max_description_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for deduplication

[database.kg_enrichment_settings]
community_reports_prompt = "graphrag_community_reports"
max_summary_input_length = 65536
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" } # and other params, model used for node description and graph clustering
leiden_params = {}

[database.kg_search_settings]
entities_level = "document" # set to collection if you've run deduplication
map_system_prompt = "graphrag_map_system"
reduce_system_prompt = "graphrag_reduce_system"
generation_config = { model = "lm_studio/llama-3.2-3b-instruct" }


[orchestration]
provider = "simple"
61 changes: 61 additions & 0 deletions py/core/configs/local_llm_ollama.toml
@@ -0,0 +1,61 @@
[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "ollama/llama3.1"

[completion]
provider = "litellm"
concurrent_request_limit = 1

[completion.generation_config]
model = "ollama/llama3.1"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[embedding]
provider = "ollama"
base_model = "mxbai-embed-large"
base_dimension = 1_024
batch_size = 128
add_title_as_prefix = true
concurrent_request_limit = 2

[database]
provider = "postgres"

[database.kg_creation_settings]
kg_entity_description_prompt = "graphrag_entity_description"
kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
entity_types = [] # if empty, all entities are extracted
relation_types = [] # if empty, all relations are extracted
fragment_merge_count = 4 # number of fragments to merge into a single extraction
max_knowledge_triples = 100
max_description_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for triplet extraction

[database.kg_entity_deduplication_settings]
kg_entity_deduplication_type = "by_name"
kg_entity_deduplication_prompt = "graphrag_entity_deduplication"
max_description_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication

[database.kg_enrichment_settings]
community_reports_prompt = "graphrag_community_reports"
max_summary_input_length = 65536
generation_config = { model = "ollama/llama3.1" } # and other params, model used for node description and graph clustering
leiden_params = {}

[database.kg_search_settings]
entities_level = "document" # set to collection if you've run deduplication
map_system_prompt = "graphrag_map_system"
reduce_system_prompt = "graphrag_reduce_system"
generation_config = { model = "ollama/llama3.1" }


[orchestration]
provider = "simple"
17 changes: 7 additions & 10 deletions py/core/providers/embeddings/litellm.py
@@ -24,15 +24,17 @@ def __init__(
     ) -> None:
         super().__init__(config)
 
+        # Allow LiteLLM to automatically drop parameters that are not supported by the model
+        litellm.drop_params = True
+
         self.litellm_embedding = embedding
         self.litellm_aembedding = aembedding
 
-        provider = config.provider
-        if not provider:
+        if not config.provider:
             raise ValueError(
                 "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
             )
-        if provider != "litellm":
+        if config.provider != "litellm":
             raise ValueError(
                 "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
             )
@@ -42,18 +44,13 @@ def __init__(
             )
 
         self.base_model = config.base_model
-        if "amazon" in self.base_model:
-            logger.warn("Amazon embedding model detected, dropping params")
-            litellm.drop_params = True
         self.base_dimension = config.base_dimension
 
     def _get_embedding_kwargs(self, **kwargs):
-        embedding_kwargs = {
+        return {
             "model": self.base_model,
             "dimensions": self.base_dimension,
-        }
-        embedding_kwargs.update(kwargs)
-        return embedding_kwargs
+        } | kwargs
 
     async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
         texts = task["texts"]
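One note on the rewritten `_get_embedding_kwargs`: the dict union operator (Python 3.9+, PEP 584) gives precedence to the right-hand operand, so caller-supplied kwargs still override the defaults exactly as the old `.update()` path did:

```python
# Right-hand keys win in a dict union, matching the previous .update() semantics.
defaults = {
    "model": "lm_studio/text-embedding-nomic-embed-text-v1.5-embedding",
    "dimensions": 768,
}
merged = defaults | {"dimensions": 512}
assert merged["dimensions"] == 512
assert merged["model"] == defaults["model"]
```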