diff --git a/src/commands/local.ts b/src/commands/local.ts
index c313f7c..8ad2296 100644
--- a/src/commands/local.ts
+++ b/src/commands/local.ts
@@ -1,6 +1,6 @@
 import { cli } from "../cli.js";
 import { llama } from "../plugins/local-llm-rename/llama.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
 import { unminify } from "../unminify.js";
 import prettier from "../plugins/prettier.js";
 import babel from "../plugins/babel/babel.js";
@@ -26,7 +26,7 @@ export const local = cli()
     }
 
     const prompt = await llama({
-      modelPath: getEnsuredModelPath(opts.model),
+      model: opts.model,
       disableGPU: opts.disableGPU,
       seed: opts.seed ? parseInt(opts.seed) : undefined
     });
diff --git a/src/local-models.ts b/src/local-models.ts
index 2ba5d51..e917c5a 100644
--- a/src/local-models.ts
+++ b/src/local-models.ts
@@ -8,20 +8,36 @@ import { showProgress } from "./progress.js";
 import { err } from "./cli-error.js";
 import { homedir } from "os";
 import { join } from "path";
+import { ChatWrapper, Llama3_1ChatWrapper } from "node-llama-cpp";
 
 const MODEL_DIRECTORY = join(homedir(), ".humanifyjs", "models");
 
-export const MODELS: { [modelName: string]: URL } = {
-  "2gb": url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+type ModelDefinition = { url: URL; wrapper?: ChatWrapper };
+
+export const MODELS: { [modelName: string]: ModelDefinition } = {
+  "2gb": {
+    url: url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+  },
+  "8b": {
+    url: url`https://huggingface.co/lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true`,
+    wrapper: new Llama3_1ChatWrapper()
+  }
 };
 
 async function ensureModelDirectory() {
   await fs.mkdir(MODEL_DIRECTORY, { recursive: true });
 }
 
+export function getModelWrapper(model: string) {
+  if (!(model in MODELS)) {
+    err(`Model ${model} not found`);
+  }
+  return MODELS[model].wrapper;
+}
+
 export async function downloadModel(model: string) {
   await ensureModelDirectory();
-  const url = MODELS[model];
+  const url = MODELS[model].url;
   if (url === undefined) {
     err(`Model ${model} not found`);
   }
@@ -54,7 +70,7 @@ export function getModelPath(model: string) {
   if (!(model in MODELS)) {
     err(`Model ${model} not found`);
   }
-  const filename = basename(MODELS[model].pathname);
+  const filename = basename(MODELS[model].url.pathname);
   return `${MODEL_DIRECTORY}/${filename}`;
 }
 
diff --git a/src/plugins/local-llm-rename/define-filename.llmtest.ts b/src/plugins/local-llm-rename/define-filename.llmtest.ts
index c908c55..c0671eb 100644
--- a/src/plugins/local-llm-rename/define-filename.llmtest.ts
+++ b/src/plugins/local-llm-rename/define-filename.llmtest.ts
@@ -1,13 +1,9 @@
 import test from "node:test";
-import { llama } from "./llama.js";
 import { assertMatches } from "../../test-utils.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../../local-models.js";
 import { defineFilename } from "./define-filename.js";
+import { testPrompt } from "../../test/test-prompt.js";
 
-const prompt = await llama({
-  seed: 1,
-  modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
-});
+const prompt = await testPrompt();
 
 test("Defines a good name for a file with a function", async () => {
   const result = await defineFilename(prompt, "const a = b => b + 1;");
diff --git a/src/plugins/local-llm-rename/llama.ts b/src/plugins/local-llm-rename/llama.ts
index 781a8e9..48cc493 100644
--- a/src/plugins/local-llm-rename/llama.ts
+++ b/src/plugins/local-llm-rename/llama.ts
@@ -1,5 +1,6 @@
 import { getLlama, LlamaChatSession, LlamaGrammar } from "node-llama-cpp";
 import { Gbnf } from "./gbnf.js";
+import { getModelPath, getModelWrapper } from "../../local-models.js";
 
 export type Prompt = (
   systemPrompt: string,
@@ -11,12 +12,12 @@ const IS_CI = process.env["CI"] === "true";
 
 export async function llama(opts: {
   seed?: number;
-  modelPath: string;
+  model: string;
   disableGPU?: boolean;
 }): Promise<Prompt> {
   const llama = await getLlama();
   const model = await llama.loadModel({
-    modelPath: opts?.modelPath,
+    modelPath: getModelPath(opts?.model),
     gpuLayers: (opts?.disableGPU ?? IS_CI) ? 0 : undefined
   });
 
@@ -26,7 +27,8 @@ export async function llama(opts: {
     const session = new LlamaChatSession({
       contextSequence: context.getSequence(),
       autoDisposeSequence: true,
-      systemPrompt
+      systemPrompt,
+      chatWrapper: getModelWrapper(opts.model)
     });
     const response = await session.promptWithMeta(userPrompt, {
       temperature: 0.8,
diff --git a/src/test/test-prompt.ts b/src/test/test-prompt.ts
index 4d167b2..14b33fe 100644
--- a/src/test/test-prompt.ts
+++ b/src/test/test-prompt.ts
@@ -1,8 +1,8 @@
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
 import { llama } from "../plugins/local-llm-rename/llama.js";
 
 export const testPrompt = async () =>
   await llama({
     seed: 1,
-    modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
+    model: process.env["MODEL"] ?? DEFAULT_MODEL
   });
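
After this change a caller selects a model by its name in the MODELS registry instead of passing a resolved file path: llama() resolves the GGUF location with getModelPath() and picks an optional model-specific chat wrapper with getModelWrapper(). A minimal usage sketch (not part of the diff), assuming the chosen model has already been downloaded into ~/.humanifyjs/models, for example via downloadModel(); the import paths are illustrative:

import { llama } from "./src/plugins/local-llm-rename/llama.js";

// "8b" is the registry entry added above; it maps to the Meta-Llama-3.1-8B-Instruct
// GGUF and to Llama3_1ChatWrapper. Any other MODELS key (e.g. "2gb") works the same way.
const prompt = await llama({
  model: "8b",
  disableGPU: false, // optional; gpuLayers is forced to 0 when true or when CI is set
  seed: 1            // optional; passed through for reproducible sampling
});

The returned prompt function is then used exactly as before, e.g. defineFilename(prompt, code) in the test above.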