Fix/encoding model config (#1527)
* fix: include encoding_model option when initializing LLMParameters

* chore: add semver patch description

* Fix encoding model parsing

* Fix unit tests

---------

Co-authored-by: Nico Reinartz <[email protected]>
AlonsoGuevara and nreinartz authored Dec 17, 2024
1 parent 329b83c commit f7cd155
Showing 6 changed files with 38 additions and 11 deletions.
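In short: the commit threads `encoding_model` through every configuration section instead of reading it once at the top level and then discarding it. Below is a minimal sketch of the behavior the fix restores; the dict keys follow GraphRAG's config input schema, and the API key and model names are placeholders:

```python
# Sketch only: per-section encoding_model values are now respected
# instead of being silently dropped during hydration.
from graphrag.config.create_graphrag_config import create_graphrag_config

values = {
    "encoding_model": "cl100k_base",  # top-level default
    "llm": {
        "api_key": "<placeholder>",
        "model": "gpt-4o",
        "encoding_model": "o200k_base",  # per-LLM override, honored after this fix
    },
}

config = create_graphrag_config(values, root_dir=".")
assert config.llm.encoding_model == "o200k_base"  # was dropped before the fix
assert config.encoding_model == "cl100k_base"
```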
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241212081432148181.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Respect encoding_model option"
+}
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241216223430364521.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model config parsing"
+}
29 changes: 24 additions & 5 deletions graphrag/config/create_graphrag_config.py
@@ -85,6 +85,7 @@ def hydrate_llm_params(
     deployment_name = (
         reader.str(Fragment.deployment_name) or base.deployment_name
     )
+    encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

     if api_key is None and not _is_azure(llm_type):
         raise ApiKeyMissingError
@@ -106,6 +107,7 @@ def hydrate_llm_params(
     organization=reader.str("organization") or base.organization,
     proxy=reader.str("proxy") or base.proxy,
     model=reader.str("model") or base.model,
+    encoding_model=encoding_model,
     max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
     temperature=reader.float(Fragment.temperature) or base.temperature,
     top_p=reader.float(Fragment.top_p) or base.top_p,
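The two hunks above address the first bullet of the commit message: `hydrate_llm_params` resolved every other override but never `encoding_model`, so `LLMParameters` was always built without it. A hedged sketch of the field now round-tripping (assuming the model's usual import path and that its fields are optional with defaults):

```python
# Sketch, not pipeline code: LLMParameters is a pydantic model, so an
# explicitly passed encoding_model should simply be stored on the instance.
from graphrag.config.models.llm_parameters import LLMParameters

params = LLMParameters(model="gpt-4o", encoding_model="o200k_base")
assert params.encoding_model == "o200k_base"
```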
@@ -155,6 +157,7 @@ def hydrate_embeddings_params(
     api_proxy = reader.str("proxy") or base.proxy
     audience = reader.str(Fragment.audience) or base.audience
     deployment_name = reader.str(Fragment.deployment_name)
+    encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

     if api_key is None and not _is_azure(api_type):
         raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def hydrate_embeddings_params(
     organization=api_organization,
     proxy=api_proxy,
     model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+    encoding_model=encoding_model,
     request_timeout=reader.float(Fragment.request_timeout)
     or defs.LLM_REQUEST_TIMEOUT,
     audience=audience,
@@ -217,6 +221,9 @@ def hydrate_parallelization_params(
     fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
     fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
     fallback_oai_proxy = reader.str(Fragment.api_proxy)
+    global_encoding_model = (
+        reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+    )

     with reader.envvar_prefix(Section.llm):
         with reader.use(values.get("llm")):
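This hunk introduces `global_encoding_model` as the shared fallback, so every section resolves its tokenizer through the same three-step `or`-chain: the section's own value, then the top-level value, then the library default. A standalone sketch of that idiom (the default mirrors `defs.ENCODING_MODEL`, assumed here to be tiktoken's `cl100k_base`):

```python
DEFAULT_ENCODING_MODEL = "cl100k_base"  # stand-in for defs.ENCODING_MODEL

def resolve_encoding_model(section_value: str | None, global_value: str | None) -> str:
    """Return the first truthy value, falling back to the library default."""
    return section_value or global_value or DEFAULT_ENCODING_MODEL

assert resolve_encoding_model(None, None) == "cl100k_base"
assert resolve_encoding_model(None, "o200k_base") == "o200k_base"
assert resolve_encoding_model("gpt2", "o200k_base") == "gpt2"
```

One property of the `or`-idiom worth noting: an empty string is falsy, so it falls through to the next value, which is the desired behavior for unset config fields.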
@@ -231,6 +238,9 @@ def hydrate_parallelization_params(
     api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
     audience = reader.str(Fragment.audience)
     deployment_name = reader.str(Fragment.deployment_name)
+    encoding_model = (
+        reader.str(Fragment.encoding_model) or global_encoding_model
+    )

     if api_key is None and not _is_azure(llm_type):
         raise ApiKeyMissingError
@@ -252,6 +262,7 @@ def hydrate_parallelization_params(
     proxy=api_proxy,
     type=llm_type,
     model=reader.str(Fragment.model) or defs.LLM_MODEL,
+    encoding_model=encoding_model,
     max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
     temperature=reader.float(Fragment.temperature)
     or defs.LLM_TEMPERATURE,
@@ -396,12 +407,15 @@ def hydrate_parallelization_params(
         group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
         if group_by_columns is None:
             group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )

         chunks_model = ChunkingConfig(
             size=reader.int("size") or defs.CHUNK_SIZE,
             overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
             group_by_columns=group_by_columns,
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )
     with (
         reader.envvar_prefix(Section.snapshot),
@@ -428,6 +442,9 @@ def hydrate_parallelization_params(
         if max_gleanings is not None
         else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
     )
+    encoding_model = (
+        reader.str(Fragment.encoding_model) or global_encoding_model
+    )

     entity_extraction_model = EntityExtractionConfig(
         llm=hydrate_llm_params(entity_extraction_config, llm_model),
@@ -440,7 +457,7 @@ def hydrate_parallelization_params(
         max_gleanings=max_gleanings,
         prompt=reader.str("prompt", Fragment.prompt_file),
         strategy=entity_extraction_config.get("strategy"),
-        encoding_model=reader.str(Fragment.encoding_model),
+        encoding_model=encoding_model,
     )

     claim_extraction_config = values.get("claim_extraction") or {}
@@ -452,6 +469,9 @@ def hydrate_parallelization_params(
     max_gleanings = (
         max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
     )
+    encoding_model = (
+        reader.str(Fragment.encoding_model) or global_encoding_model
+    )
     claim_extraction_model = ClaimExtractionConfig(
         enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
         llm=hydrate_llm_params(claim_extraction_config, llm_model),
@@ -462,7 +482,7 @@ def hydrate_parallelization_params(
         description=reader.str("description") or defs.CLAIM_DESCRIPTION,
         prompt=reader.str("prompt", Fragment.prompt_file),
         max_gleanings=max_gleanings,
-        encoding_model=reader.str(Fragment.encoding_model),
+        encoding_model=encoding_model,
     )

     community_report_config = values.get("community_reports") or {}
@@ -603,7 +623,6 @@ def hydrate_parallelization_params(
             or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
         )

-    encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
     skip_workflows = reader.list("skip_workflows") or []

     return GraphRagConfig(
@@ -626,7 +645,7 @@ def hydrate_parallelization_params(
     summarize_descriptions=summarize_descriptions_model,
     umap=umap_model,
     cluster_graph=cluster_graph_model,
-    encoding_model=encoding_model,
+    encoding_model=global_encoding_model,
     skip_workflows=skip_workflows,
     local_search=local_search_model,
     global_search=global_search_model,
4 changes: 2 additions & 2 deletions graphrag/config/models/chunking_config.py
@@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, encoding_model: str) -> dict:
+    def resolved_strategy(self, encoding_model: str | None) -> dict:
         """Get the resolved chunking strategy."""
         from graphrag.index.operations.chunk_text import ChunkStrategyType

@@ -36,5 +36,5 @@ def resolved_strategy(self, encoding_model: str) -> dict:
         "chunk_size": self.size,
         "chunk_overlap": self.overlap,
         "group_by_columns": self.group_by_columns,
-        "encoding_name": self.encoding_model or encoding_model,
+        "encoding_name": encoding_model or self.encoding_model,
     }
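Note the flipped precedence in the last line: the caller now passes an `encoding_model` that was already resolved through the fallback chain above, so the argument wins over the raw per-model field. A usage sketch under the new signature (assuming the field defaults let a bare `ChunkingConfig()` be constructed):

```python
# Hedged usage sketch of the new resolved_strategy signature.
from graphrag.config.models.chunking_config import ChunkingConfig

chunks = ChunkingConfig()  # encoding_model left as None
strategy = chunks.resolved_strategy("o200k_base")  # pre-resolved value from the config root
assert strategy["encoding_name"] == "o200k_base"
```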
4 changes: 2 additions & 2 deletions graphrag/config/models/claim_extraction_config.py
@@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved claim extraction strategy."""
         from graphrag.index.operations.extract_covariates import (
             ExtractClaimsStrategyType,
@@ -52,5 +52,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
         else None,
         "claim_description": self.description,
         "max_gleanings": self.max_gleanings,
-        "encoding_name": self.encoding_model or encoding_model,
+        "encoding_name": encoding_model or self.encoding_model,
     }
4 changes: 2 additions & 2 deletions graphrag/config/models/entity_extraction_config.py
@@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved entity extraction strategy."""
         from graphrag.index.operations.extract_entities import (
             ExtractEntityStrategyType,
@@ -49,6 +49,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
         else None,
         "max_gleanings": self.max_gleanings,
         # It's prechunked in create_base_text_units
-        "encoding_name": self.encoding_model or encoding_model,
+        "encoding_name": encoding_model or self.encoding_model,
         "prechunked": True,
     }
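For context, the resolved `encoding_name` ultimately selects a tiktoken tokenizer, which is what makes a mis-parsed value break token-budgeted chunking and gleaning. A minimal illustration, assuming only that tiktoken is installed:

```python
import tiktoken

# "cl100k_base" is a standard tiktoken encoding name, the kind of value
# that encoding_model/encoding_name carries through the config.
enc = tiktoken.get_encoding("cl100k_base")
print(len(enc.encode("GraphRAG chunks text by token count.")))
```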
