From 98bef7b7cffa811a67945d8c8f4659862c15026c Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 17 Sep 2024 08:34:58 +0700 Subject: [PATCH] test: add model parameter validation rules and persistence tests (#3618) * test: add model parameter validation rules and persistence tests * chore: fix CI cov step * fix: invalid model settings should fallback to origin value * test: support fallback integer settings --- .../src/node/index.ts | 32 +- .../tensorrt-llm-extension/src/node/index.ts | 4 +- web/containers/Providers/EventHandler.tsx | 4 +- web/containers/SliderRightPanel/index.tsx | 24 +- web/hooks/useSendChatMessage.ts | 9 +- web/hooks/useUpdateModelParameters.test.ts | 314 ++++++++++++++++++ web/hooks/useUpdateModelParameters.ts | 16 +- .../LocalServerRightPanel/index.tsx | 13 +- web/screens/Thread/ThreadRightPanel/index.tsx | 26 +- web/utils/modelParam.test.ts | 183 ++++++++++ web/utils/modelParam.ts | 106 +++++- 11 files changed, 679 insertions(+), 52 deletions(-) create mode 100644 web/hooks/useUpdateModelParameters.test.ts create mode 100644 web/utils/modelParam.test.ts diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index edc2d013de..3a969ad5e7 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise { if (!settings?.ngl) { settings.ngl = 100 } - log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) + log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`) return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { method: 'POST', headers: { @@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise { }) .then((res) => { log( - `[CORTEX]::Debug: Load model success with response ${JSON.stringify( + `[CORTEX]:: Load model success with response ${JSON.stringify( res )}` ) @@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise { async function validateModelStatus(modelId: string): Promise { // Send a GET request to the validation URL. // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. - log(`[CORTEX]::Debug: Validating model ${modelId}`) + log(`[CORTEX]:: Validating model ${modelId}`) return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { method: 'POST', body: JSON.stringify({ @@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise { retryDelay: 300, }).then(async (res: Response) => { log( - `[CORTEX]::Debug: Validate model state with response ${JSON.stringify( + `[CORTEX]:: Validate model state with response ${JSON.stringify( res.status )}` ) @@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise { // Otherwise, return an object with an error message. 
if (body.model_loaded) { log( - `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( + `[CORTEX]:: Validate model state success with response ${JSON.stringify( body )}` ) @@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise { } const errorBody = await res.text() log( - `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( + `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( res.statusText )}` ) @@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise { async function killSubprocess(): Promise { const controller = new AbortController() setTimeout(() => controller.abort(), 5000) - log(`[CORTEX]::Debug: Request to kill cortex`) + log(`[CORTEX]:: Request to kill cortex`) const killRequest = () => { return fetch(NITRO_HTTP_KILL_URL, { @@ -321,17 +321,17 @@ async function killSubprocess(): Promise { .then(() => tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) ) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch((err) => { log( - `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` + `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` ) throw 'PORT_NOT_AVAILABLE' }) } if (subprocess?.pid && process.platform !== 'darwin') { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -341,7 +341,7 @@ async function killSubprocess(): Promise { } else { tcpPortUsed .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .then(() => resolve()) .catch(() => { log( @@ -362,7 +362,7 @@ async function killSubprocess(): Promise { * @returns A promise that resolves when the Nitro subprocess is started. 
*/ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { - log(`[CORTEX]::Debug: Spawning cortex subprocess...`) + log(`[CORTEX]:: Spawning cortex subprocess...`) return new Promise(async (resolve, reject) => { let executableOptions = executableNitroFile( @@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { const args: string[] = ['1', LOCAL_HOST, PORT.toString()] // Execute the binary log( - `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` ) log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`) @@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { // Handle subprocess output subprocess.stdout.on('data', (data: any) => { - log(`[CORTEX]::Debug: ${data}`) + log(`[CORTEX]:: ${data}`) }) subprocess.stderr.on('data', (data: any) => { @@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { }) subprocess.on('close', (code: any) => { - log(`[CORTEX]::Debug: cortex exited with code: ${code}`) + log(`[CORTEX]:: cortex exited with code: ${code}`) subprocess = undefined reject(`child process exited with code ${code}`) }) @@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { tcpPortUsed .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) .then(() => { - log(`[CORTEX]::Debug: cortex is ready`) + log(`[CORTEX]:: cortex is ready`) resolve() }) }) diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts index c8bc48459e..77003389fd 100644 --- a/extensions/tensorrt-llm-extension/src/node/index.ts +++ b/extensions/tensorrt-llm-extension/src/node/index.ts @@ -97,7 +97,7 @@ function unloadModel(): Promise { } if (subprocess?.pid) { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -107,7 +107,7 @@ function unloadModel(): Promise { return tcpPortUsed .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) .then(() => resolve()) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch(() => { killRequest() }) diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index e4c96aeb70..4809ce83eb 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -20,7 +20,7 @@ import { ulid } from 'ulidx' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' -import { toRuntimeParams } from '@/utils/modelParam' +import { extractInferenceParams } from '@/utils/modelParam' import { extensionManager } from '@/extension' import { @@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) { }, ] - const runtimeParams = toRuntimeParams(activeModelParamsRef.current) + const runtimeParams = extractInferenceParams(activeModelParamsRef.current) const messageRequest: MessageRequest = { id: msgId, diff --git a/web/containers/SliderRightPanel/index.tsx b/web/containers/SliderRightPanel/index.tsx index df415ffb59..c00d9f0022 100644 --- a/web/containers/SliderRightPanel/index.tsx +++ b/web/containers/SliderRightPanel/index.tsx @@ -87,26 +87,28 @@ const 
SliderRightPanel = ({ onValueChanged?.(Number(min)) setVal(min.toString()) setShowTooltip({ max: false, min: true }) + } else { + setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5 } }} onChange={(e) => { + // TODO: How to support negative number input? + // Passthru since it validates again onBlur + if (/^\d*\.?\d*$/.test(e.target.value)) { + setVal(e.target.value) + } + // Should not accept invalid value or NaN // E.g. anything changes that trigger onValueChanged // Which is incorrect - if (Number(e.target.value) > Number(max)) { - setVal(max.toString()) - } else if ( + if ( + Number(e.target.value) > Number(max) || Number(e.target.value) < Number(min) || - !e.target.value.length + Number.isNaN(Number(e.target.value)) ) { - setVal(min.toString()) - } else if (Number.isNaN(Number(e.target.value))) return - - onValueChanged?.(Number(e.target.value)) - // TODO: How to support negative number input? - if (/^\d*\.?\d*$/.test(e.target.value)) { - setVal(e.target.value) + return } + onValueChanged?.(Number(e.target.value)) }} /> } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 8c6013505b..1dbd5b45e8 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -23,7 +23,10 @@ import { import { Stack } from '@/utils/Stack' import { compressImage, getBase64 } from '@/utils/base64' import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' @@ -189,8 +192,8 @@ export default function useSendChatMessage() { if (engineParamsUpdate) setReloadModel(true) - const runtimeParams = toRuntimeParams(activeModelParams) - const settingParams = toSettingParams(activeModelParams) + const runtimeParams = extractInferenceParams(activeModelParams) + const settingParams = extractModelLoadParams(activeModelParams) const prompt = message.trim() diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts new file mode 100644 index 0000000000..bc60aa631c --- /dev/null +++ b/web/hooks/useUpdateModelParameters.test.ts @@ -0,0 +1,314 @@ +import { renderHook, act } from '@testing-library/react' +// Mock dependencies +jest.mock('ulidx') +jest.mock('@/extension') + +import useUpdateModelParameters from './useUpdateModelParameters' +import { extensionManager } from '@/extension' + +// Mock data +let model: any = { + id: 'model-1', + engine: 'nitro', +} + +let extension: any = { + saveThread: jest.fn(), +} + +const mockThread: any = { + id: 'thread-1', + assistants: [ + { + model: { + parameters: {}, + settings: {}, + }, + }, + ], + object: 'thread', + title: 'New Thread', + created: 0, + updated: 0, +} + +describe('useUpdateModelParameters', () => { + beforeAll(() => { + jest.clearAllMocks() + jest.mock('./useRecommendedModel', () => ({ + useRecommendedModel: () => ({ + recommendedModel: model, + setRecommendedModel: jest.fn(), + downloadedModels: [], + }), + })) + }) + + it('should update model parameters and save thread when params are valid', async () => { + const mockValidParameters: any = { + params: { + // Inference + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, + + // Load model + ctx_len: 1024, + ngl: 12, + embedding: true, + 
n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter(mockThread, mockValidParameters) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, + }, + settings: { + ctx_len: 1024, + ngl: 12, + embedding: true, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should not update invalid model parameters', async () => { + const mockInvalidParameters: any = { + params: { + // Inference + stop: [1, ''], + temperature: '0.5', + token_limit: '1000', + top_k: '0.7', + top_p: '0.1', + stream: 'true', + max_tokens: '1000', + frequency_penalty: '0.3', + presence_penalty: '0.2', + + // Load model + ctx_len: '1024', + ngl: '12', + embedding: 'true', + n_parallel: '2', + cpu_threads: '4', + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockInvalidParameters + ) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + max_tokens: 1000, + token_limit: 1000, + }, + settings: { + cpu_threads: 4, + ctx_len: 1024, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + n_parallel: 2, + ngl: 12, + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should update valid model parameters only', async () => { + const mockInvalidParameters: any = { + params: { + // Inference + stop: [''], + temperature: -0.5, + token_limit: 100.2, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 1.2, + presence_penalty: 0.2, + + // Load model + ctx_len: 1024, + ngl: 0, + embedding: 'true', + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockInvalidParameters + ) + }) + + // Check if the model parameters are 
valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: [''], + top_k: 0.7, + top_p: 0.1, + stream: true, + token_limit: 100, + max_tokens: 1000, + presence_penalty: 0.2, + }, + settings: { + ctx_len: 1024, + ngl: 0, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should handle missing modelId and engine gracefully', async () => { + const mockParametersWithoutModelIdAndEngine: any = { + params: { + stop: ['', ''], + temperature: 0.5, + }, + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockParametersWithoutModelIdAndEngine + ) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + }, + settings: {}, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) +}) diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 79d8774567..46bf07cd50 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -12,7 +12,10 @@ import { import { useAtom, useAtomValue, useSetAtom } from 'jotai' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import useRecommendedModel from './useRecommendedModel' @@ -47,12 +50,17 @@ export default function useUpdateModelParameters() { const toUpdateSettings = processStopWords(settings.params ?? {}) const updatedModelParams = settings.modelId ? 
toUpdateSettings - : { ...activeModelParams, ...toUpdateSettings } + : { + ...selectedModel?.parameters, + ...selectedModel?.settings, + ...activeModelParams, + ...toUpdateSettings, + } // update the state setThreadModelParams(thread.id, updatedModelParams) - const runtimeParams = toRuntimeParams(updatedModelParams) - const settingParams = toSettingParams(updatedModelParams) + const runtimeParams = extractInferenceParams(updatedModelParams) + const settingParams = extractModelLoadParams(updatedModelParams) const assistants = thread.assistants.map( (assistant: ThreadAssistantInfo) => { diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx index 309709c268..13e3cad578 100644 --- a/web/screens/LocalServer/LocalServerRightPanel/index.tsx +++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx @@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel' import { getConfigurationsData } from '@/utils/componentSettings' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' @@ -27,16 +30,18 @@ const LocalServerRightPanel = () => { const selectedModel = useAtomValue(selectedModelAtom) const [currentModelSettingParams, setCurrentModelSettingParams] = useState( - toSettingParams(selectedModel?.settings) + extractModelLoadParams(selectedModel?.settings) ) useEffect(() => { if (selectedModel) { - setCurrentModelSettingParams(toSettingParams(selectedModel?.settings)) + setCurrentModelSettingParams( + extractModelLoadParams(selectedModel?.settings) + ) } }, [selectedModel]) - const modelRuntimeParams = toRuntimeParams(selectedModel?.settings) + const modelRuntimeParams = extractInferenceParams(selectedModel?.settings) const componentDataRuntimeSetting = getConfigurationsData( modelRuntimeParams, diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 9e7cdf7d86..e7d0a27b9b 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import { getConfigurationsData } from '@/utils/componentSettings' import { localEngines } from '@/utils/modelEngine' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import PromptTemplateSetting from './PromptTemplateSetting' import Tools from './Tools' @@ -68,14 +71,26 @@ const ThreadRightPanel = () => { const settings = useMemo(() => { // runtime setting - const modelRuntimeParams = toRuntimeParams(activeModelParams) + const modelRuntimeParams = extractInferenceParams( + { + ...selectedModel?.parameters, + ...activeModelParams, + }, + selectedModel?.parameters + ) const componentDataRuntimeSetting = getConfigurationsData( modelRuntimeParams, selectedModel ).filter((x) => x.key !== 'prompt_template') // engine setting - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams( + { + ...selectedModel?.settings, + ...activeModelParams, + }, + selectedModel?.settings + ) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel @@ -126,7 +141,10 @@ const 
ThreadRightPanel = () => { }, [activeModelParams, selectedModel]) const promptTemplateSettings = useMemo(() => { - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams({ + ...selectedModel?.settings, + ...activeModelParams, + }) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts new file mode 100644 index 0000000000..f1b8589555 --- /dev/null +++ b/web/utils/modelParam.test.ts @@ -0,0 +1,183 @@ +// web/utils/modelParam.test.ts +import { normalizeValue, validationRules } from './modelParam' + +describe('validationRules', () => { + it('should validate temperature correctly', () => { + expect(validationRules.temperature(0.5)).toBe(true) + expect(validationRules.temperature(2)).toBe(true) + expect(validationRules.temperature(0)).toBe(true) + expect(validationRules.temperature(-0.1)).toBe(false) + expect(validationRules.temperature(2.3)).toBe(false) + expect(validationRules.temperature('0.5')).toBe(false) + }) + + it('should validate token_limit correctly', () => { + expect(validationRules.token_limit(100)).toBe(true) + expect(validationRules.token_limit(1)).toBe(true) + expect(validationRules.token_limit(0)).toBe(true) + expect(validationRules.token_limit(-1)).toBe(false) + expect(validationRules.token_limit('100')).toBe(false) + }) + + it('should validate top_k correctly', () => { + expect(validationRules.top_k(0.5)).toBe(true) + expect(validationRules.top_k(1)).toBe(true) + expect(validationRules.top_k(0)).toBe(true) + expect(validationRules.top_k(-0.1)).toBe(false) + expect(validationRules.top_k(1.1)).toBe(false) + expect(validationRules.top_k('0.5')).toBe(false) + }) + + it('should validate top_p correctly', () => { + expect(validationRules.top_p(0.5)).toBe(true) + expect(validationRules.top_p(1)).toBe(true) + expect(validationRules.top_p(0)).toBe(true) + expect(validationRules.top_p(-0.1)).toBe(false) + expect(validationRules.top_p(1.1)).toBe(false) + expect(validationRules.top_p('0.5')).toBe(false) + }) + + it('should validate stream correctly', () => { + expect(validationRules.stream(true)).toBe(true) + expect(validationRules.stream(false)).toBe(true) + expect(validationRules.stream('true')).toBe(false) + expect(validationRules.stream(1)).toBe(false) + }) + + it('should validate max_tokens correctly', () => { + expect(validationRules.max_tokens(100)).toBe(true) + expect(validationRules.max_tokens(1)).toBe(true) + expect(validationRules.max_tokens(0)).toBe(true) + expect(validationRules.max_tokens(-1)).toBe(false) + expect(validationRules.max_tokens('100')).toBe(false) + }) + + it('should validate stop correctly', () => { + expect(validationRules.stop(['word1', 'word2'])).toBe(true) + expect(validationRules.stop([])).toBe(true) + expect(validationRules.stop(['word1', 2])).toBe(false) + expect(validationRules.stop('word1')).toBe(false) + }) + + it('should validate frequency_penalty correctly', () => { + expect(validationRules.frequency_penalty(0.5)).toBe(true) + expect(validationRules.frequency_penalty(1)).toBe(true) + expect(validationRules.frequency_penalty(0)).toBe(true) + expect(validationRules.frequency_penalty(-0.1)).toBe(false) + expect(validationRules.frequency_penalty(1.1)).toBe(false) + expect(validationRules.frequency_penalty('0.5')).toBe(false) + }) + + it('should validate presence_penalty correctly', () => { + expect(validationRules.presence_penalty(0.5)).toBe(true) + 
expect(validationRules.presence_penalty(1)).toBe(true) + expect(validationRules.presence_penalty(0)).toBe(true) + expect(validationRules.presence_penalty(-0.1)).toBe(false) + expect(validationRules.presence_penalty(1.1)).toBe(false) + expect(validationRules.presence_penalty('0.5')).toBe(false) + }) + + it('should validate ctx_len correctly', () => { + expect(validationRules.ctx_len(1024)).toBe(true) + expect(validationRules.ctx_len(1)).toBe(true) + expect(validationRules.ctx_len(0)).toBe(true) + expect(validationRules.ctx_len(-1)).toBe(false) + expect(validationRules.ctx_len('1024')).toBe(false) + }) + + it('should validate ngl correctly', () => { + expect(validationRules.ngl(12)).toBe(true) + expect(validationRules.ngl(1)).toBe(true) + expect(validationRules.ngl(0)).toBe(true) + expect(validationRules.ngl(-1)).toBe(false) + expect(validationRules.ngl('12')).toBe(false) + }) + + it('should validate embedding correctly', () => { + expect(validationRules.embedding(true)).toBe(true) + expect(validationRules.embedding(false)).toBe(true) + expect(validationRules.embedding('true')).toBe(false) + expect(validationRules.embedding(1)).toBe(false) + }) + + it('should validate n_parallel correctly', () => { + expect(validationRules.n_parallel(2)).toBe(true) + expect(validationRules.n_parallel(1)).toBe(true) + expect(validationRules.n_parallel(0)).toBe(true) + expect(validationRules.n_parallel(-1)).toBe(false) + expect(validationRules.n_parallel('2')).toBe(false) + }) + + it('should validate cpu_threads correctly', () => { + expect(validationRules.cpu_threads(4)).toBe(true) + expect(validationRules.cpu_threads(1)).toBe(true) + expect(validationRules.cpu_threads(0)).toBe(true) + expect(validationRules.cpu_threads(-1)).toBe(false) + expect(validationRules.cpu_threads('4')).toBe(false) + }) + + it('should validate prompt_template correctly', () => { + expect(validationRules.prompt_template('template')).toBe(true) + expect(validationRules.prompt_template('')).toBe(true) + expect(validationRules.prompt_template(123)).toBe(false) + }) + + it('should validate llama_model_path correctly', () => { + expect(validationRules.llama_model_path('path')).toBe(true) + expect(validationRules.llama_model_path('')).toBe(true) + expect(validationRules.llama_model_path(123)).toBe(false) + }) + + it('should validate mmproj correctly', () => { + expect(validationRules.mmproj('mmproj')).toBe(true) + expect(validationRules.mmproj('')).toBe(true) + expect(validationRules.mmproj(123)).toBe(false) + }) + + it('should validate vision_model correctly', () => { + expect(validationRules.vision_model(true)).toBe(true) + expect(validationRules.vision_model(false)).toBe(true) + expect(validationRules.vision_model('true')).toBe(false) + expect(validationRules.vision_model(1)).toBe(false) + }) + + it('should validate text_model correctly', () => { + expect(validationRules.text_model(true)).toBe(true) + expect(validationRules.text_model(false)).toBe(true) + expect(validationRules.text_model('true')).toBe(false) + expect(validationRules.text_model(1)).toBe(false) + }) +}) + +describe('normalizeValue', () => { + it('should normalize ctx_len correctly', () => { + expect(normalizeValue('ctx_len', 100.5)).toBe(100) + expect(normalizeValue('ctx_len', '2')).toBe(2) + expect(normalizeValue('ctx_len', 100)).toBe(100) + }) + it('should normalize token_limit correctly', () => { + expect(normalizeValue('token_limit', 100.5)).toBe(100) + expect(normalizeValue('token_limit', '1')).toBe(1) + expect(normalizeValue('token_limit', 0)).toBe(0) + }) + 
it('should normalize max_tokens correctly', () => { + expect(normalizeValue('max_tokens', 100.5)).toBe(100) + expect(normalizeValue('max_tokens', '1')).toBe(1) + expect(normalizeValue('max_tokens', 0)).toBe(0) + }) + it('should normalize ngl correctly', () => { + expect(normalizeValue('ngl', 12.5)).toBe(12) + expect(normalizeValue('ngl', '2')).toBe(2) + expect(normalizeValue('ngl', 0)).toBe(0) + }) + it('should normalize n_parallel correctly', () => { + expect(normalizeValue('n_parallel', 2.5)).toBe(2) + expect(normalizeValue('n_parallel', '2')).toBe(2) + expect(normalizeValue('n_parallel', 0)).toBe(0) + }) + it('should normalize cpu_threads correctly', () => { + expect(normalizeValue('cpu_threads', 4.5)).toBe(4) + expect(normalizeValue('cpu_threads', '4')).toBe(4) + expect(normalizeValue('cpu_threads', 0)).toBe(0) + }) +}) diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts index a6d144c3ee..dda9cf7611 100644 --- a/web/utils/modelParam.ts +++ b/web/utils/modelParam.ts @@ -1,9 +1,69 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/naming-convention */ import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core' import { ModelParams } from '@/helpers/atoms/Thread.atom' -export const toRuntimeParams = ( - modelParams?: ModelParams +/** + * Validation rules for model parameters + */ +export const validationRules: { [key: string]: (value: any) => boolean } = { + temperature: (value: any) => + typeof value === 'number' && value >= 0 && value <= 2, + token_limit: (value: any) => Number.isInteger(value) && value >= 0, + top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + stream: (value: any) => typeof value === 'boolean', + max_tokens: (value: any) => Number.isInteger(value) && value >= 0, + stop: (value: any) => + Array.isArray(value) && value.every((v) => typeof v === 'string'), + frequency_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + presence_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + + ctx_len: (value: any) => Number.isInteger(value) && value >= 0, + ngl: (value: any) => Number.isInteger(value) && value >= 0, + embedding: (value: any) => typeof value === 'boolean', + n_parallel: (value: any) => Number.isInteger(value) && value >= 0, + cpu_threads: (value: any) => Number.isInteger(value) && value >= 0, + prompt_template: (value: any) => typeof value === 'string', + llama_model_path: (value: any) => typeof value === 'string', + mmproj: (value: any) => typeof value === 'string', + vision_model: (value: any) => typeof value === 'boolean', + text_model: (value: any) => typeof value === 'boolean', +} + +/** + * There are some parameters that need to be normalized before being sent to the server + * E.g. 
ctx_len should be an integer, but it can be a float from the input field + * @param key + * @param value + * @returns + */ +export const normalizeValue = (key: string, value: any) => { + if ( + key === 'token_limit' || + key === 'max_tokens' || + key === 'ctx_len' || + key === 'ngl' || + key === 'n_parallel' || + key === 'cpu_threads' + ) { + // Convert to integer + return Math.floor(Number(value)) + } + return value +} + +/** + * Extract inference parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractInferenceParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelRuntimeParams => { if (!modelParams) return {} const defaultModelParams: ModelRuntimeParams = { @@ -22,15 +82,35 @@ export const toRuntimeParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultModelParams) { - Object.assign(runtimeParams, { ...runtimeParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: normalizeValue(key, value), + }) + } } } return runtimeParams } -export const toSettingParams = ( - modelParams?: ModelParams +/** + * Extract model load parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractModelLoadParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelSettingParams => { if (!modelParams) return {} const defaultSettingParams: ModelSettingParams = { @@ -49,7 +129,21 @@ export const toSettingParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultSettingParams) { - Object.assign(settingParams, { ...settingParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(modelParams, { + ...modelParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(settingParams, { + ...settingParams, + [key]: normalizeValue(key, value), + }) + } } }
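
For illustration, a minimal sketch of how the fallback behavior introduced above is expected to work, assuming extractInferenceParams is imported from web/utils/modelParam as in the web code changed by this patch; the edited/origin parameter objects and their values below are hypothetical, not taken from the codebase.

import { extractInferenceParams } from '@/utils/modelParam'

// Parameters as edited in the UI: temperature is outside the allowed 0-2 range,
// and max_tokens arrives as a float from the input field. (Hypothetical values.)
const edited = { temperature: 5, top_p: 0.1, max_tokens: 1000.7 }

// Previously persisted parameters, passed as the fallback source. (Hypothetical values.)
const origin = { temperature: 0.7, top_p: 0.9, max_tokens: 2048 }

const params = extractInferenceParams(edited, origin)
// Expected (roughly): temperature is invalid and falls back to the origin value (0.7),
// top_p passes validation and keeps the edited value (0.1),
// max_tokens is normalized to an integer (1000) before validation and kept.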