
Fix/encoding model config #1527

Merged (5 commits, Dec 17, 2024)

.semversioner/next-release/patch-20241212081432148181.json (4 additions, 0 deletions)
@@ -0,0 +1,4 @@
+{
+"type": "patch",
+"description": "Respect encoding_model option"
+}

.semversioner/next-release/patch-20241216223430364521.json (4 additions, 0 deletions)
@@ -0,0 +1,4 @@
+{
+"type": "patch",
+"description": "Fix encoding model config parsing"
+}

graphrag/config/create_graphrag_config.py (24 additions, 5 deletions)
@@ -85,6 +85,7 @@ def hydrate_llm_params(
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
+encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@@ -106,6 +107,7 @@ def hydrate_llm_params(
organization=reader.str("organization") or base.organization,
proxy=reader.str("proxy") or base.proxy,
model=reader.str("model") or base.model,
+encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
@@ -155,6 +157,7 @@ def hydrate_embeddings_params(
api_proxy = reader.str("proxy") or base.proxy
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)
+encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

if api_key is None and not _is_azure(api_type):
raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def hydrate_embeddings_params(
organization=api_organization,
proxy=api_proxy,
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+encoding_model=encoding_model,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
audience=audience,
@@ -217,6 +221,9 @@ def hydrate_parallelization_params(
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
fallback_oai_proxy = reader.str(Fragment.api_proxy)
+global_encoding_model = (
+reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+)

with reader.envvar_prefix(Section.llm):
with reader.use(values.get("llm")):
@@ -231,6 +238,9 @@ def hydrate_parallelization_params(
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)
+encoding_model = (
+reader.str(Fragment.encoding_model) or global_encoding_model
+)

if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@@ -252,6 +262,7 @@ def hydrate_parallelization_params(
proxy=api_proxy,
type=llm_type,
model=reader.str(Fragment.model) or defs.LLM_MODEL,
+encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
@@ -396,12 +407,15 @@ def hydrate_parallelization_params(
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
if group_by_columns is None:
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
+encoding_model = (
+reader.str(Fragment.encoding_model) or global_encoding_model
+)

chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
-encoding_model=reader.str(Fragment.encoding_model),
+encoding_model=encoding_model,
)
with (
reader.envvar_prefix(Section.snapshot),
@@ -428,6 +442,9 @@
if max_gleanings is not None
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
)
+encoding_model = (
+reader.str(Fragment.encoding_model) or global_encoding_model
+)

entity_extraction_model = EntityExtractionConfig(
llm=hydrate_llm_params(entity_extraction_config, llm_model),
@@ -440,7 +457,7 @@ def hydrate_parallelization_params(
max_gleanings=max_gleanings,
prompt=reader.str("prompt", Fragment.prompt_file),
strategy=entity_extraction_config.get("strategy"),
-encoding_model=reader.str(Fragment.encoding_model),
+encoding_model=encoding_model,
)

claim_extraction_config = values.get("claim_extraction") or {}
@@ -452,6 +469,9 @@ def hydrate_parallelization_params(
max_gleanings = (
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
)
+encoding_model = (
+reader.str(Fragment.encoding_model) or global_encoding_model
+)
claim_extraction_model = ClaimExtractionConfig(
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
llm=hydrate_llm_params(claim_extraction_config, llm_model),
@@ -462,7 +482,7 @@ def hydrate_parallelization_params(
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
prompt=reader.str("prompt", Fragment.prompt_file),
max_gleanings=max_gleanings,
-encoding_model=reader.str(Fragment.encoding_model),
+encoding_model=encoding_model,
)

community_report_config = values.get("community_reports") or {}
@@ -603,7 +623,6 @@ def hydrate_parallelization_params(
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)

-encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
skip_workflows = reader.list("skip_workflows") or []

return GraphRagConfig(
@@ -626,7 +645,7 @@ def hydrate_parallelization_params(
summarize_descriptions=summarize_descriptions_model,
umap=umap_model,
cluster_graph=cluster_graph_model,
-encoding_model=encoding_model,
+encoding_model=global_encoding_model,
skip_workflows=skip_workflows,
local_search=local_search_model,
global_search=global_search_model,
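
Taken together, the changes above give encoding_model a uniform three-level fallback in create_graphrag_config: a value set on an individual section (an llm block, chunks, entity or claim extraction) wins, then the root-level encoding_model, then the library default. A minimal sketch of that precedence, with simplified names assumed for illustration (the real code reads values through EnvironmentReader fragments, not plain arguments):

    # Sketch only: DEFAULT_ENCODING_MODEL stands in for defs.ENCODING_MODEL.
    DEFAULT_ENCODING_MODEL = "cl100k_base"

    def resolve_encoding_model(
        section_value: str | None,
        global_value: str | None,
    ) -> str:
        """Per-section override wins, then the root-level value, then the default."""
        return section_value or global_value or DEFAULT_ENCODING_MODEL

    # The per-section value takes precedence when present:
    assert resolve_encoding_model("o200k_base", "cl100k_base") == "o200k_base"
    # Otherwise fall back to the root value, mirroring `or global_encoding_model`:
    assert resolve_encoding_model(None, "cl100k_base") == "cl100k_base"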

graphrag/config/models/chunking_config.py (2 additions, 2 deletions)
@@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
default=None, description="The encoding model to use."
)

-def resolved_strategy(self, encoding_model: str) -> dict:
+def resolved_strategy(self, encoding_model: str | None) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.operations.chunk_text import ChunkStrategyType

@@ -36,5 +36,5 @@ def resolved_strategy(self, encoding_model: str) -> dict:
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
}
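
The flipped `or`, together with the signature widening to str | None, is the behavioral fix in these model classes: an encoding_model argument passed to resolved_strategy (already resolved against the root-level setting by create_graphrag_config above) now overrides the model's raw field, and a caller may pass None to defer entirely to the configured value. A small illustrative sketch of the difference, with made-up helper names:

    # Illustrative only: how the operand order of `or` changes the result.
    def old_encoding_name(field: str | None, resolved: str | None) -> str | None:
        return field or resolved  # before: the raw field won

    def new_encoding_name(field: str | None, resolved: str | None) -> str | None:
        return resolved or field  # after: the resolved argument wins

    assert old_encoding_name("p50k_base", "cl100k_base") == "p50k_base"
    assert new_encoding_name("p50k_base", "cl100k_base") == "cl100k_base"
    # With no resolved value supplied, both fall back to the field:
    assert new_encoding_name("p50k_base", None) == "p50k_base"

The same two-line change is applied to the claim and entity extraction configs below.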

graphrag/config/models/claim_extraction_config.py (2 additions, 2 deletions)
@@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
default=None, description="The encoding model to use."
)

-def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.operations.extract_covariates import (
ExtractClaimsStrategyType,
@@ -52,5 +52,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
}

graphrag/config/models/entity_extraction_config.py (2 additions, 2 deletions)
@@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
default=None, description="The encoding model to use."
)

-def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
"""Get the resolved entity extraction strategy."""
from graphrag.index.operations.extract_entities import (
ExtractEntityStrategyType,
@@ -49,6 +49,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
else None,
"max_gleanings": self.max_gleanings,
# It's prechunked in create_base_text_units
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
"prechunked": True,
}
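
With these changes, an encoding_model set in configuration is respected end to end. A hypothetical usage sketch follows; the dict keys mirror a settings.yaml, and the section names ("chunks", "entity_extraction") are assumptions for illustration, not taken from this diff:

    from graphrag.config.create_graphrag_config import create_graphrag_config

    # Hypothetical input, as would be parsed from settings.yaml.
    values = {
        "encoding_model": "cl100k_base",             # root-level setting
        "llm": {"api_key": "sk-placeholder", "model": "gpt-4-turbo-preview"},
        "chunks": {"encoding_model": "o200k_base"},  # per-section override
    }

    config = create_graphrag_config(values, root_dir=".")

    # Expected after this PR: the override sticks; other sections inherit the root value.
    assert config.chunks.encoding_model == "o200k_base"
    assert config.entity_extraction.encoding_model == "cl100k_base"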