Fix/encoding model config #1527

Merged: 5 commits, Dec 17, 2024
Changes from 1 commit
fix: include encoding_model option when initializing LLMParameters
nreinartz committed Dec 12, 2024
commit 35582aa955c127ce7e4376484e57570a7dbace68
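Every hunk below applies the same precedence rule: a section-level encoding_model wins, otherwise the value is inherited from the base parameters or the global default. A minimal sketch of that fallback pattern, for illustration only (the function name is hypothetical; the diff itself uses inline `reader.str(...) or ...` expressions):

def resolve_encoding_model(section_value: str | None, base_value: str) -> str:
    # Mirrors `reader.str(Fragment.encoding_model) or base.encoding_model`:
    # the section-level value wins; otherwise inherit the base/global default.
    # Note: any falsy value (e.g. "") also falls back, a property of `or`.
    return section_value or base_value

assert resolve_encoding_model(None, "cl100k_base") == "cl100k_base"        # falls back
assert resolve_encoding_model("o200k_base", "cl100k_base") == "o200k_base" # overrides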
graphrag/config/create_graphrag_config.py: 14 changes (10 additions, 4 deletions)
@@ -85,6 +85,7 @@ def hydrate_llm_params(
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
+encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@@ -106,6 +107,7 @@ def hydrate_llm_params(
organization=reader.str("organization") or base.organization,
proxy=reader.str("proxy") or base.proxy,
model=reader.str("model") or base.model,
+encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
@@ -155,6 +157,7 @@ def hydrate_embeddings_params(
api_proxy = reader.str("proxy") or base.proxy
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)
+encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

if api_key is None and not _is_azure(api_type):
raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def hydrate_embeddings_params(
organization=api_organization,
proxy=api_proxy,
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+encoding_model=encoding_model,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
audience=audience,
@@ -217,7 +221,8 @@ def hydrate_parallelization_params(
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
fallback_oai_proxy = reader.str(Fragment.api_proxy)

+global_encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL

with reader.envvar_prefix(Section.llm):
with reader.use(values.get("llm")):
llm_type = reader.str(Fragment.type)
@@ -231,6 +236,7 @@ def hydrate_parallelization_params(
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)
+encoding_model = reader.str(Fragment.encoding_model) or global_encoding_model

if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@@ -252,6 +258,7 @@ def hydrate_parallelization_params(
proxy=api_proxy,
type=llm_type,
model=reader.str(Fragment.model) or defs.LLM_MODEL,
+encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
@@ -603,9 +610,8 @@ def hydrate_parallelization_params(
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)

-encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
skip_workflows = reader.list("skip_workflows") or []

return GraphRagConfig(
root_dir=root_dir,
llm=llm_model,
@@ -626,7 +632,7 @@ def hydrate_parallelization_params(
summarize_descriptions=summarize_descriptions_model,
umap=umap_model,
cluster_graph=cluster_graph_model,
-encoding_model=encoding_model,
+encoding_model=global_encoding_model,
skip_workflows=skip_workflows,
local_search=local_search_model,
global_search=global_search_model,
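For context, a hedged usage sketch of how the fix surfaces when loading configuration. The dict shape follows the values.get("llm") and Fragment.encoding_model reads visible in the diff; the api_key value is a placeholder, and the defaults named in comments are assumptions about defs.ENCODING_MODEL, not something this PR states:

from graphrag.config.create_graphrag_config import create_graphrag_config

config = create_graphrag_config({
    "llm": {
        "api_key": "<your-api-key>",     # placeholder; may also come from env vars
        "model": "gpt-4-turbo-preview",
        "encoding_model": "o200k_base",  # section-level override
    },
    "encoding_model": "cl100k_base",     # global default (assumed defs.ENCODING_MODEL)
})

# With this fix the section-level value reaches LLMParameters;
# previously the encoding_model option was silently dropped here.
print(config.llm.encoding_model)  # expected: "o200k_base"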