diff --git a/.semversioner/next-release/patch-20241212081432148181.json b/.semversioner/next-release/patch-20241212081432148181.json
new file mode 100644
index 000000000..a50029f04
--- /dev/null
+++ b/.semversioner/next-release/patch-20241212081432148181.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Respect encoding_model option"
+}
diff --git a/.semversioner/next-release/patch-20241216223430364521.json b/.semversioner/next-release/patch-20241216223430364521.json
new file mode 100644
index 000000000..79c734e1d
--- /dev/null
+++ b/.semversioner/next-release/patch-20241216223430364521.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model config parsing"
+}
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index 0e0815515..c1da51126 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -85,6 +85,7 @@ def hydrate_llm_params(
         deployment_name = (
             reader.str(Fragment.deployment_name) or base.deployment_name
         )
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

         if api_key is None and not _is_azure(llm_type):
             raise ApiKeyMissingError
@@ -106,6 +107,7 @@ def hydrate_llm_params(
             organization=reader.str("organization") or base.organization,
             proxy=reader.str("proxy") or base.proxy,
             model=reader.str("model") or base.model,
+            encoding_model=encoding_model,
             max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
             temperature=reader.float(Fragment.temperature) or base.temperature,
             top_p=reader.float(Fragment.top_p) or base.top_p,
@@ -155,6 +157,7 @@ def hydrate_embeddings_params(
         api_proxy = reader.str("proxy") or base.proxy
         audience = reader.str(Fragment.audience) or base.audience
         deployment_name = reader.str(Fragment.deployment_name)
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

         if api_key is None and not _is_azure(api_type):
             raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def hydrate_embeddings_params(
             organization=api_organization,
             proxy=api_proxy,
             model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+            encoding_model=encoding_model,
             request_timeout=reader.float(Fragment.request_timeout)
             or defs.LLM_REQUEST_TIMEOUT,
             audience=audience,
@@ -217,6 +221,9 @@ def hydrate_parallelization_params(
        fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
        fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
        fallback_oai_proxy = reader.str(Fragment.api_proxy)
+       global_encoding_model = (
+           reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+       )

        with reader.envvar_prefix(Section.llm):
            with reader.use(values.get("llm")):
@@ -231,6 +238,9 @@ def hydrate_parallelization_params(
                api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
                audience = reader.str(Fragment.audience)
                deployment_name = reader.str(Fragment.deployment_name)
+               encoding_model = (
+                   reader.str(Fragment.encoding_model) or global_encoding_model
+               )

                if api_key is None and not _is_azure(llm_type):
                    raise ApiKeyMissingError
@@ -252,6 +262,7 @@ def hydrate_parallelization_params(
                    proxy=api_proxy,
                    type=llm_type,
                    model=reader.str(Fragment.model) or defs.LLM_MODEL,
+                   encoding_model=encoding_model,
                    max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
                    temperature=reader.float(Fragment.temperature)
                    or defs.LLM_TEMPERATURE,
@@ -396,12 +407,15 @@ def hydrate_parallelization_params(
        group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
        if group_by_columns is None:
            group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
+       encoding_model = (
+           reader.str(Fragment.encoding_model) or global_encoding_model
+       )

        chunks_model = ChunkingConfig(
            size=reader.int("size") or defs.CHUNK_SIZE,
            overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
            group_by_columns=group_by_columns,
-           encoding_model=reader.str(Fragment.encoding_model),
+           encoding_model=encoding_model,
        )
    with (
        reader.envvar_prefix(Section.snapshot),
@@ -428,6 +442,9 @@ def hydrate_parallelization_params(
            if max_gleanings is not None
            else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
        )
+       encoding_model = (
+           reader.str(Fragment.encoding_model) or global_encoding_model
+       )

        entity_extraction_model = EntityExtractionConfig(
            llm=hydrate_llm_params(entity_extraction_config, llm_model),
@@ -440,7 +457,7 @@ def hydrate_parallelization_params(
            max_gleanings=max_gleanings,
            prompt=reader.str("prompt", Fragment.prompt_file),
            strategy=entity_extraction_config.get("strategy"),
-           encoding_model=reader.str(Fragment.encoding_model),
+           encoding_model=encoding_model,
        )

    claim_extraction_config = values.get("claim_extraction") or {}
@@ -452,6 +469,9 @@ def hydrate_parallelization_params(
        max_gleanings = (
            max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
        )
+       encoding_model = (
+           reader.str(Fragment.encoding_model) or global_encoding_model
+       )
        claim_extraction_model = ClaimExtractionConfig(
            enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
            llm=hydrate_llm_params(claim_extraction_config, llm_model),
@@ -462,7 +482,7 @@ def hydrate_parallelization_params(
            description=reader.str("description") or defs.CLAIM_DESCRIPTION,
            prompt=reader.str("prompt", Fragment.prompt_file),
            max_gleanings=max_gleanings,
-           encoding_model=reader.str(Fragment.encoding_model),
+           encoding_model=encoding_model,
        )

    community_report_config = values.get("community_reports") or {}
@@ -603,7 +623,6 @@ def hydrate_parallelization_params(
            or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
        )

-   encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
    skip_workflows = reader.list("skip_workflows") or []

    return GraphRagConfig(
@@ -626,7 +645,7 @@ def hydrate_parallelization_params(
        summarize_descriptions=summarize_descriptions_model,
        umap=umap_model,
        cluster_graph=cluster_graph_model,
-       encoding_model=encoding_model,
+       encoding_model=global_encoding_model,
        skip_workflows=skip_workflows,
        local_search=local_search_model,
        global_search=global_search_model,
diff --git a/graphrag/config/models/chunking_config.py b/graphrag/config/models/chunking_config.py
index a2b40017f..f2a39bf1b 100644
--- a/graphrag/config/models/chunking_config.py
+++ b/graphrag/config/models/chunking_config.py
@@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, encoding_model: str) -> dict:
+    def resolved_strategy(self, encoding_model: str | None) -> dict:
         """Get the resolved chunking strategy."""
         from graphrag.index.operations.chunk_text import ChunkStrategyType

@@ -36,5 +36,5 @@ def resolved_strategy(self, encoding_model: str) -> dict:
             "chunk_size": self.size,
             "chunk_overlap": self.overlap,
             "group_by_columns": self.group_by_columns,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }
diff --git a/graphrag/config/models/claim_extraction_config.py b/graphrag/config/models/claim_extraction_config.py
index 716f64480..fcecd905d 100644
--- a/graphrag/config/models/claim_extraction_config.py
+++ b/graphrag/config/models/claim_extraction_config.py
@@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved claim extraction strategy."""
         from graphrag.index.operations.extract_covariates import (
             ExtractClaimsStrategyType,
@@ -52,5 +52,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
             else None,
             "claim_description": self.description,
             "max_gleanings": self.max_gleanings,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }
diff --git a/graphrag/config/models/entity_extraction_config.py b/graphrag/config/models/entity_extraction_config.py
index 40f155d0e..9a7b07829 100644
--- a/graphrag/config/models/entity_extraction_config.py
+++ b/graphrag/config/models/entity_extraction_config.py
@@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved entity extraction strategy."""
         from graphrag.index.operations.extract_entities import (
             ExtractEntityStrategyType,
@@ -49,6 +49,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
             else None,
             "max_gleanings": self.max_gleanings,
             # It's prechunked in create_base_text_units
-            "encoding_name": encoding_model or self.encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
             "prechunked": True,
         }
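A minimal sketch (not part of the patch) of the precedence change in resolved_strategy: the encoding model passed in by the caller, which after this change already reflects the hydrated section-level or global encoding_model setting, now takes priority over the field stored on the config model, and the field is only a fallback when the argument is None. The constructor values and encoding names below are hypothetical.

    from graphrag.config.models.chunking_config import ChunkingConfig

    # Hypothetical sizes; "cl100k_base" / "o200k_base" are example tiktoken names.
    cfg = ChunkingConfig(size=1200, overlap=100, group_by_columns=[])

    # No section-level encoding_model is set, so the caller's value is used.
    assert cfg.resolved_strategy("cl100k_base")["encoding_name"] == "cl100k_base"

    cfg = ChunkingConfig(
        size=1200, overlap=100, group_by_columns=[], encoding_model="o200k_base"
    )
    # After this change the caller's value wins over the section-level field...
    assert cfg.resolved_strategy("cl100k_base")["encoding_name"] == "cl100k_base"
    # ...and the field is only consulted when the caller passes None.
    assert cfg.resolved_strategy(None)["encoding_name"] == "o200k_base"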