DLLMs: dynamic parameters system
enricoros committed Dec 23, 2024
1 parent b4a586f commit 6e85171
Showing 28 changed files with 302 additions and 168 deletions.
4 changes: 2 additions & 2 deletions src/apps/chat/components/composer/Composer.tsx
@@ -15,7 +15,6 @@ import TelegramIcon from '@mui/icons-material/Telegram';

import { useChatAutoSuggestAttachmentPrompts, useChatMicTimeoutMsValue } from '../../store-app-chat';

-import type { DOpenAILLMOptions } from '~/modules/llms/vendors/openai/openai.vendor';
import { useAgiAttachmentPrompts } from '~/modules/aifn/agiattachmentprompts/useAgiAttachmentPrompts';
import { useBrowseCapability } from '~/modules/browse/store-module-browsing';

@@ -36,6 +35,7 @@ import { copyToClipboard, supportsClipboardRead } from '~/common/util/clipboardU
import { createTextContentFragment, DMessageAttachmentFragment, DMessageContentFragment, duplicateDMessageFragmentsNoVoid } from '~/common/stores/chat/chat.fragments';
import { estimateTextTokens, glueForMessageTokens, marshallWrapDocFragments } from '~/common/stores/chat/chat.tokens';
import { getConversation, isValidConversation, useChatStore } from '~/common/stores/chat/store-chats';
+import { getModelParameterValueOrThrow } from '~/common/stores/llms/llms.parameters';
import { launchAppCall } from '~/common/app.routes';
import { lineHeightTextareaMd } from '~/common/app.theme';
import { optimaOpenPreferences } from '~/common/layout/optima/useOptima';
@@ -221,7 +221,7 @@ export function Composer(props: {
if (props.chatLLM && tokensComposer > 0)
tokensComposer += glueForMessageTokens(props.chatLLM);
const tokensHistory = _historyTokenCount;
-const tokensResponseMax = (props.chatLLM?.options as DOpenAILLMOptions /* FIXME: BIG ASSUMPTION */)?.llmResponseTokens || 0;
+const tokensResponseMax = getModelParameterValueOrThrow('llmResponseTokens', props.chatLLM?.initialParameters, props.chatLLM?.userParameters, 0) ?? 0;
const tokenLimit = props.chatLLM?.contextTokens || 0;
const tokenChatPricing = props.chatLLM?.pricing?.chat;

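The new call above replaces an unsafe vendor-specific cast with a typed lookup. A minimal sketch of the precedence it applies, using the helper introduced later in this commit (model data hypothetical):

// hypothetical chatLLM fragment: a vendor-discovered value plus a user override
const chatLLM = {
  initialParameters: { llmResponseTokens: 8192 }, // set at model discovery
  userParameters: { llmResponseTokens: 2048 },    // user override, checked first
};
const tokensResponseMax = getModelParameterValueOrThrow(
  'llmResponseTokens', chatLLM.initialParameters, chatLLM.userParameters, 0) ?? 0;
// -> 2048; the trailing '?? 0' handles an explicit null, which per the
// parameter registry means 'do not send max_tokens to the upstream API'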
155 changes: 155 additions & 0 deletions src/common/stores/llms/llms.parameters.ts
@@ -0,0 +1,155 @@
/**
* Parameter Registry and Model Configuration
*
* This module provides a type-safe parameter management system for LLM models.
* It handles parameter definitions, validation, and runtime values while
* maintaining strict type safety throughout the application.
*
* Key concepts:
* - ParameterRegistry: Defines all possible parameters and their constraints
* - ParameterSpec: Model-specific parameter configurations
* - ParameterValues: Runtime parameter values (initial and user overrides)
*
* @module llms
*/


// shared constants
export const FALLBACK_LLM_PARAM_RESPONSE_TOKENS = 4096;
export const FALLBACK_LLM_PARAM_TEMPERATURE = 0.5;
// const FALLBACK_LLM_PARAM_REF_UNKNOWN = 'unknown_id';


/// Registry

export const DModelParameterRegistry = {

/// Common parameters, normally available in all models ///
// Note: we still use pre-v2 names for compatibility and ease of migration

llmRef: {
label: 'Model ID',
type: 'string' as const,
description: 'Upstream model reference',
hidden: true,
} as const,

llmResponseTokens: {
label: 'Maximum Tokens',
type: 'integer' as const,
description: 'Maximum length of generated text',
nullable: {
meaning: 'Explicitly avoid sending max_tokens to upstream API',
} as const,
requiredFallback: FALLBACK_LLM_PARAM_RESPONSE_TOKENS, // if required and not specified/user overridden, use this value
} as const,

llmTemperature: {
label: 'Temperature',
type: 'float' as const,
description: 'Controls randomness in the output',
range: [0.0, 2.0] as const,
requiredFallback: FALLBACK_LLM_PARAM_TEMPERATURE,
} as const,

/// Extended parameters, specific to certain models/vendors

llmTopP: {
label: 'Top P',
type: 'float' as const,
description: 'Nucleus sampling threshold',
range: [0.0, 1.0] as const,
requiredFallback: 1.0,
incompatibleWith: ['temperature'] as const,
} as const,

'vnd.oai.reasoning_quantity': {
label: 'Reasoning Quantity',
type: 'enum' as const,
description: 'Controls reasoning depth (OpenAI specific)',
values: ['low', 'med', 'high'] as const,
requiredFallback: 'med',
} as const,

} as const;


/// Types

export interface DModelParameterSpec<T extends DModelParameterId> {
paramId: T;
required?: boolean;
hidden?: boolean;
upstreamDefault?: DModelParameterValue<T>;
}

export type DModelParameterValues = {
[K in DModelParameterId]?: DModelParameterValue<K>;
}

export type DModelParameterId = keyof typeof DModelParameterRegistry; // 'llmRef' | 'llmResponseTokens' | 'llmTemperature' | 'llmTopP' | 'vnd.oai.reasoning_quantity' | ...
// type _ExtendedParameterId = keyof typeof _ExtendedParameterRegistry;

type DModelParameterValue<T extends DModelParameterId> =
typeof DModelParameterRegistry[T]['type'] extends 'integer' ? number | null :
typeof DModelParameterRegistry[T]['type'] extends 'float' ? number :
typeof DModelParameterRegistry[T]['type'] extends 'string' ? string :
typeof DModelParameterRegistry[T]['type'] extends 'boolean' ? boolean :
typeof DModelParameterRegistry[T]['type'] extends 'enum'
? (typeof DModelParameterRegistry[T] extends { type: 'enum', values: readonly (infer U)[] } ? U : never)
: never;


/// Utility Functions

const _requiredParamId: readonly DModelParameterId[] = ['llmRef', 'llmResponseTokens', 'llmTemperature'] as const;

export function getAllModelParameterValues(initialParameters: undefined | DModelParameterValues, userParameters?: DModelParameterValues): DModelParameterValues {

// fallback values
const fallbackParameters: DModelParameterValues = {};
for (const requiredParamId of _requiredParamId) {
if ('requiredFallback' in DModelParameterRegistry[requiredParamId])
fallbackParameters[requiredParamId] = DModelParameterRegistry[requiredParamId].requiredFallback as DModelParameterValue<typeof requiredParamId>;
}

// accumulate initial and user values
return {
...fallbackParameters,
...initialParameters,
...userParameters,
};
}


export function getModelParameterValueOrThrow<T extends DModelParameterId>(
paramId: T,
initialValues: undefined | DModelParameterValues,
userValues: undefined | DModelParameterValues,
fallbackValue: undefined | DModelParameterValue<T>,
): DModelParameterValue<T> {

// check user values first
if (userValues && paramId in userValues) {
const value = userValues[paramId];
if (value !== undefined) return value;
}

// then check initial values
if (initialValues && paramId in initialValues) {
const value = initialValues[paramId];
if (value !== undefined) return value;
}

// then try provided fallback
if (fallbackValue !== undefined) return fallbackValue;

// finally the global registry fallback
const paramDef = DModelParameterRegistry[paramId];
if ('requiredFallback' in paramDef && paramDef.requiredFallback !== undefined)
return paramDef.requiredFallback as DModelParameterValue<T>;

// if we're here, we couldn't find a value
// [DANGER] VERY DANGEROUS, but shall NEVER happen
throw new Error(`getModelParameterValueOrThrow: missing required parameter '${paramId}'`);
}
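For illustration, a minimal usage sketch of the two helpers above (parameter values hypothetical, not part of the commit):

import { getAllModelParameterValues, getModelParameterValueOrThrow } from '~/common/stores/llms/llms.parameters';

// hypothetical values as a vendor might discover them at model-list time
const initialParameters = { llmRef: 'gpt-4o', llmResponseTokens: 8192, llmTemperature: 0.5 };
// hypothetical user override set from the UI
const userParameters = { llmTemperature: 1.0 };

// merged view, lowest to highest precedence: registry fallbacks < initial < user
const merged = getAllModelParameterValues(initialParameters, userParameters);
// -> { llmRef: 'gpt-4o', llmResponseTokens: 8192, llmTemperature: 1.0 }

// single value: user -> initial -> call-site fallback -> registry requiredFallback -> throw
const temperature = getModelParameterValueOrThrow('llmTemperature', initialParameters, undefined, undefined);
// -> 0.5 (from initialParameters; FALLBACK_LLM_PARAM_TEMPERATURE only if absent everywhere)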
14 changes: 10 additions & 4 deletions src/common/stores/llms/llms.types.ts
@@ -4,6 +4,7 @@

import type { ModelVendorId } from '~/modules/llms/vendors/vendors.registry';

+import type { DModelParameterId, DModelParameterSpec, DModelParameterValues } from './llms.parameters';
import type { DModelPricing } from './llms.pricing';
import type { DModelsServiceId } from './modelsservice.types';

@@ -17,7 +18,7 @@ export type DLLMId = string;
/**
* Large Language Model - description and configuration (data object, stored)
*/
-export interface DLLM<TLLMOptions = Record<string, any>> {
+export interface DLLM {
id: DLLMId;

// editable properties (kept on update, if isEdited)
@@ -26,7 +27,6 @@ export interface DLLM<TLLMOptions = Record<string, any>> {
updated?: number | 0;
description: string;
hidden: boolean; // hidden from UI selectors
-isEdited?: boolean; // user has edited the soft properties

// hard properties (overwritten on update)
contextTokens: number | null; // null: must assume it's unknown
@@ -36,12 +36,18 @@ export interface DLLM<TLLMOptions = Record<string, any>> {
benchmark?: { cbaElo?: number, cbaMmlu?: number }; // benchmark values
pricing?: DModelPricing;

+// parameters system
+parameterSpecs: DModelParameterSpec<DModelParameterId>[];
+initialParameters: DModelParameterValues;

// references
sId: DModelsServiceId;
vId: ModelVendorId;

-// llm-specific
-options: { llmRef: string } & Partial<TLLMOptions>;
+// user edited properties - if not undefined/missing, they override the others
+userLabel?: string;
+userHidden?: boolean;
+userParameters?: DModelParameterValues; // user has set these parameters
}


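As an illustration of the new shape, the parameters-related portion of a hypothetical DLLM record (values invented):

import type { DLLM } from '~/common/stores/llms/llms.types';

const example: Pick<DLLM, 'parameterSpecs' | 'initialParameters' | 'userParameters'> = {
  // which parameters this model accepts (vendor-declared, overwritten on update)
  parameterSpecs: [
    { paramId: 'llmRef', required: true, hidden: true },
    { paramId: 'llmResponseTokens', required: true },
    { paramId: 'llmTemperature', required: true },
  ],
  // values discovered at model-list time (overwritten on update)
  initialParameters: { llmRef: 'gpt-4o', llmResponseTokens: 8192, llmTemperature: 0.5 },
  // user edits: kept across updates and merged last, so they win
  userParameters: { llmTemperature: 1.0 },
};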
68 changes: 52 additions & 16 deletions src/common/stores/llms/store-llms.ts
@@ -17,7 +17,7 @@ import { getLlmCostForTokens, portModelPricingV2toV3 } from './llms.pricing';

interface LlmsState {

-llms: DLLM<any>[];
+llms: DLLM[];

sources: DModelsService<any>[];

@@ -32,7 +32,7 @@ interface LlmsActions {
removeLLM: (id: DLLMId) => void;
rerankLLMsByServices: (serviceIdOrder: DModelsServiceId[]) => void;
updateLLM: (id: DLLMId, partial: Partial<DLLM>) => void;
-updateLLMOptions: <TLLMOptions>(id: DLLMId, partialOptions: Partial<TLLMOptions>) => void;
+updateLLMUserParameters: (id: DLLMId, partial: Partial<DLLM['userParameters']>) => void;

addService: (service: DModelsService) => void;
removeService: (id: DModelsServiceId) => void;
@@ -75,13 +75,9 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
const existing = state.llms.find(m => m.id === llm.id);
return !existing ? llm : {
...llm,
-label: existing.label, // keep label
-hidden: existing.hidden, // keep hidden - FIXME: this must go, as we don't know if the underlying changed or the user changed it
-options: {
-// keep custom configurations, but overwrite as the new could have massively improved params
-...existing.options,
-...llm.options,
-},
+...(existing.userLabel !== undefined ? { userLabel: existing.userLabel } : {}),
+...(existing.userHidden !== undefined ? { userHidden: existing.userHidden } : {}),
+...(existing.userParameters !== undefined ? { userParameters: existing.userParameters } : {}),
};
});
}
@@ -136,11 +132,11 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
),
})),

-updateLLMOptions: <TLLMOptions>(id: DLLMId, partialOptions: Partial<TLLMOptions>) =>
+updateLLMUserParameters: (id: DLLMId, partialUserParameters: Partial<DLLM['userParameters']>) =>
set(state => ({
llms: state.llms.map((llm: DLLM): DLLM =>
llm.id === id
-? { ...llm, options: { ...llm.options, ...partialOptions } }
+? { ...llm, userParameters: { ...llm.userParameters, ...partialUserParameters } }
: llm,
),
})),
@@ -202,8 +198,9 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
* 1: adds maxOutputTokens (default to half of contextTokens)
* 2: large changes on all LLMs, and reset chat/fast/func LLMs
* 3: big-AGI v2
+* 4: migrate .options to .initialParameters/.userParameters
*/
-version: 3,
+version: 4,
migrate: (_state: any, fromVersion: number): LlmsState => {

if (!_state) return _state;
Expand All @@ -227,8 +224,22 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
}

// 2 -> 3: big-AGI v2: update all models for pricing info
-if (fromVersion < 3)
-state.llms.forEach(portModelPricingV2toV3);
+if (fromVersion < 3) {
+try {
+state.llms.forEach(portModelPricingV2toV3);
+} catch (error) {
+// ... if there's any error, ignore - shall be okay
+}
+}
+
+// 3 -> 4: migrate .options to .initialParameters/.userParameters
+if (fromVersion < 4) {
+try {
+state.llms.forEach(_port_V3Options_to_V4Parameters_inline);
+} catch (error) {
+// ... if there's any error, ignore - shall be okay
+}
+}

return state;
},
@@ -269,8 +280,8 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
));


-export function findLLMOrThrow<TLLMOptions>(llmId: DLLMId): DLLM<TLLMOptions> {
-const llm: DLLM<TLLMOptions> | undefined = llmsStoreState().llms.find(llm => llm.id === llmId);
+export function findLLMOrThrow(llmId: DLLMId): DLLM {
+const llm: DLLM | undefined = llmsStoreState().llms.find(llm => llm.id === llmId);
if (!llm)
throw new Error(`Large Language Model ${llmId} not found`);
return llm;
@@ -442,3 +453,28 @@ function _selectFastLlmID(vendors: GroupedVendorLLMs) {
}
return null;
}


+function _port_V3Options_to_V4Parameters_inline(llm: DLLM): void {
+
+// skip if already migrated
+if ('initialParameters' in (llm as object)) return;
+
+// initialize initialParameters and userParameters if they don't exist
+if (!llm.initialParameters) llm.initialParameters = {};
+if (!llm.userParameters) llm.userParameters = {};
+
+// migrate options to initialParameters/userParameters
+type DLLMV3_Options = DLLM & { options?: { llmRef: string, llmTemperature?: number, llmResponseTokens?: number } & Record<string, any> };
+const llmV3 = llm as DLLMV3_Options;
+if ('options' in llmV3 && typeof llmV3.options === 'object') {
+if ('llmRef' in llmV3.options)
+llm.initialParameters.llmRef = llmV3.options.llmRef;
+if ('llmTemperature' in llmV3.options && typeof llmV3.options.llmTemperature === 'number')
+llm.initialParameters.llmTemperature = Math.max(0, Math.min(1, llmV3.options.llmTemperature));
+if ('llmResponseTokens' in llmV3.options && typeof llmV3.options.llmResponseTokens === 'number')
+llm.initialParameters.llmResponseTokens = llmV3.options.llmResponseTokens;
+delete llmV3.options;
+}
+
+}
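For illustration, what the v3 → v4 migration above does to a persisted record (hypothetical data; the function is module-private, shown here inline only):

// a v3-era record carried everything in .options
const record: any = { id: 'my-model', options: { llmRef: 'gpt-4o', llmTemperature: 0.7, llmResponseTokens: 4096 } };
_port_V3Options_to_V4Parameters_inline(record);
// record is now:
// { id: 'my-model',
//   initialParameters: { llmRef: 'gpt-4o', llmTemperature: 0.7, llmResponseTokens: 4096 },
//   userParameters: {} }
// .options has been deleted; llmTemperature is clamped into [0, 1] during the port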
4 changes: 3 additions & 1 deletion src/common/tokens/tokens.text.ts
@@ -1,6 +1,7 @@
import type { Tiktoken, TiktokenEncoding, TiktokenModel } from 'tiktoken';

import type { DLLM } from '~/common/stores/llms/llms.types';
+import { getAllModelParameterValues } from '~/common/stores/llms/llms.parameters';


// Do not set this to true in production, it's very verbose
@@ -63,7 +64,8 @@ export function textTokensForLLM(text: string, llm: DLLM, debugFrom: string): nu
}

// Validate input
-const openaiModel = llm?.options?.llmRef;
+const llmParameters = getAllModelParameterValues(llm.initialParameters, llm.userParameters);
+const openaiModel = llmParameters.llmRef;
if (!openaiModel) {
console.warn(`textTokensForLLM: LLM ${llm?.id} has no LLM reference id`);
return null;