diff --git a/core/llm/constants.ts b/core/llm/constants.ts index 19b067d154..0975fc142a 100644 --- a/core/llm/constants.ts +++ b/core/llm/constants.ts @@ -1,5 +1,5 @@ const DEFAULT_MAX_TOKENS = 4096; -const DEFAULT_CONTEXT_LENGTH = 8192; +const DEFAULT_CONTEXT_LENGTH = 32_768; const DEFAULT_TEMPERATURE = 0.5; const DEFAULT_ARGS = { diff --git a/core/llm/index.test.ts b/core/llm/index.test.ts index 7584e21223..1e7031f508 100644 --- a/core/llm/index.test.ts +++ b/core/llm/index.test.ts @@ -1,6 +1,10 @@ import { ChatMessage, LLMOptions } from ".."; +import { allModelProviders } from "@continuedev/llm-info"; +import { LlmInfo } from "@continuedev/llm-info/dist/types"; import { BaseLLM } from "."; +import { DEFAULT_CONTEXT_LENGTH } from "./constants"; +import { LLMClasses } from "./llms"; import { LLMLogger } from "./logger"; class DummyLLM extends BaseLLM { @@ -140,4 +144,31 @@ describe("BaseLLM", () => { describe("*streamChat", () => { // TODO: Implement tests for *streamChat method }); + + describe("default context length", () => { + allModelProviders.map((modelProvider) => { + const LLMClass = LLMClasses.find( + (llm) => llm.providerName === modelProvider.id, + ); + if (!LLMClass) { + throw new Error(`did not find LLM provider for ${modelProvider.id}`); + } + const testContextLength = (llmInfo: LlmInfo) => () => { + const llm = new LLMClass({ model: llmInfo.model }); + if (llmInfo.contextLength) { + expect(llm.contextLength).toEqual(llmInfo.contextLength); + } else { + expect(llm.contextLength).toEqual(DEFAULT_CONTEXT_LENGTH); + } + }; + describe(`${modelProvider.id}`, () => { + modelProvider.models.forEach((llmInfo) => { + test( + `should have correct context length for ${llmInfo.model}`, + testContextLength(llmInfo), + ); + }); + }); + }); + }); }); diff --git a/core/llm/llms/Anthropic.ts b/core/llm/llms/Anthropic.ts index 8d1d3a67f3..ff0f78c060 100644 --- a/core/llm/llms/Anthropic.ts +++ b/core/llm/llms/Anthropic.ts @@ -8,7 +8,6 @@ class Anthropic extends BaseLLM { static providerName = "anthropic"; static defaultOptions: Partial = { model: "claude-3-5-sonnet-latest", - contextLength: 200_000, completionOptions: { model: "claude-3-5-sonnet-latest", maxTokens: 8192, diff --git a/core/llm/llms/Bedrock.ts b/core/llm/llms/Bedrock.ts index ae595f3f33..30ce69e7ab 100644 --- a/core/llm/llms/Bedrock.ts +++ b/core/llm/llms/Bedrock.ts @@ -48,7 +48,6 @@ class Bedrock extends BaseLLM { static defaultOptions: Partial = { region: "us-east-1", model: "anthropic.claude-3-sonnet-20240229-v1:0", - contextLength: 200_000, profile: "bedrock", };