Merge pull request #39 from jehna/add/8b-model
Add 8b model
  • Loading branch information
jehna authored Aug 15, 2024
2 parents 30534a6 + b3a72b3 commit ebe827a
Showing 5 changed files with 31 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/commands/local.ts
@@ -1,6 +1,6 @@
import { cli } from "../cli.js";
import { llama } from "../plugins/local-llm-rename/llama.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
import { unminify } from "../unminify.js";
import prettier from "../plugins/prettier.js";
import babel from "../plugins/babel/babel.js";
@@ -26,7 +26,7 @@ export const local = cli()
}

const prompt = await llama({
-modelPath: getEnsuredModelPath(opts.model),
+model: opts.model,
disableGPU: opts.disableGPU,
seed: opts.seed ? parseInt(opts.seed) : undefined
});
24 changes: 20 additions & 4 deletions src/local-models.ts
@@ -8,20 +8,36 @@ import { showProgress } from "./progress.js";
import { err } from "./cli-error.js";
import { homedir } from "os";
import { join } from "path";
+import { ChatWrapper, Llama3_1ChatWrapper } from "node-llama-cpp";

const MODEL_DIRECTORY = join(homedir(), ".humanifyjs", "models");

-export const MODELS: { [modelName: string]: URL } = {
-"2gb": url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+type ModelDefinition = { url: URL; wrapper?: ChatWrapper };

+export const MODELS: { [modelName: string]: ModelDefinition } = {
+"2gb": {
+url: url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+},
+"8b": {
+url: url`https://huggingface.co/lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true`,
+wrapper: new Llama3_1ChatWrapper()
+}
};

async function ensureModelDirectory() {
await fs.mkdir(MODEL_DIRECTORY, { recursive: true });
}

+export function getModelWrapper(model: string) {
+if (!(model in MODELS)) {
+err(`Model ${model} not found`);
+}
+return MODELS[model].wrapper;
+}

export async function downloadModel(model: string) {
await ensureModelDirectory();
-const url = MODELS[model];
+const url = MODELS[model].url;
if (url === undefined) {
err(`Model ${model} not found`);
}
@@ -54,7 +70,7 @@ export function getModelPath(model: string) {
if (!(model in MODELS)) {
err(`Model ${model} not found`);
}
-const filename = basename(MODELS[model].pathname);
+const filename = basename(MODELS[model].url.pathname);
return `${MODEL_DIRECTORY}/${filename}`;
}

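For orientation, a short sketch of how the two helpers above fit together for the new model. The function names and the "8b" key come from this diff; the call site and the import path are illustrative, assuming a caller inside src/:

import { getModelPath, getModelWrapper } from "./local-models.js";

// Resolve the downloaded GGUF file on disk and the optional chat wrapper.
// For the new "8b" entry the wrapper is a Llama3_1ChatWrapper instance;
// for "2gb" it stays undefined, so node-llama-cpp's default wrapper
// selection applies.
const modelPath = getModelPath("8b");
const chatWrapper = getModelWrapper("8b");
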
8 changes: 2 additions & 6 deletions src/plugins/local-llm-rename/define-filename.llmtest.ts
@@ -1,13 +1,9 @@
import test from "node:test";
-import { llama } from "./llama.js";
import { assertMatches } from "../../test-utils.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../../local-models.js";
import { defineFilename } from "./define-filename.js";
+import { testPrompt } from "../../test/test-prompt.js";

-const prompt = await llama({
-seed: 1,
-modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
-});
+const prompt = await testPrompt();

test("Defines a good name for a file with a function", async () => {
const result = await defineFilename(prompt, "const a = b => b + 1;");
8 changes: 5 additions & 3 deletions src/plugins/local-llm-rename/llama.ts
@@ -1,5 +1,6 @@
import { getLlama, LlamaChatSession, LlamaGrammar } from "node-llama-cpp";
import { Gbnf } from "./gbnf.js";
+import { getModelPath, getModelWrapper } from "../../local-models.js";

export type Prompt = (
systemPrompt: string,
@@ -11,12 +12,12 @@ const IS_CI = process.env["CI"] === "true";

export async function llama(opts: {
seed?: number;
-modelPath: string;
+model: string;
disableGPU?: boolean;
}): Promise<Prompt> {
const llama = await getLlama();
const model = await llama.loadModel({
-modelPath: opts?.modelPath,
+modelPath: getModelPath(opts?.model),
gpuLayers: (opts?.disableGPU ?? IS_CI) ? 0 : undefined
});

@@ -26,7 +27,8 @@ export async function llama(opts: {
const session = new LlamaChatSession({
contextSequence: context.getSequence(),
autoDisposeSequence: true,
-systemPrompt
+systemPrompt,
+chatWrapper: getModelWrapper(opts.model)
});
const response = await session.promptWithMeta(userPrompt, {
temperature: 0.8,
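
For reference, a minimal usage sketch of the refactored llama() entry point: callers now pass a model name instead of a resolved file path, and path lookup plus chat-wrapper selection happen inside the function. The option names come from this diff; the import path and the concrete values are illustrative:

import { llama } from "../plugins/local-llm-rename/llama.js";

// Previously callers resolved the path themselves via getEnsuredModelPath;
// now the model is referenced by its key in the MODELS registry.
const prompt = await llama({
  model: "8b",       // key in MODELS ("2gb" or "8b")
  seed: 1,           // optional, for reproducible sampling
  disableGPU: false  // optional; when unset, CI runs force CPU-only inference
});
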
4 changes: 2 additions & 2 deletions src/test/test-prompt.ts
@@ -1,8 +1,8 @@
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
import { llama } from "../plugins/local-llm-rename/llama.js";

export const testPrompt = async () =>
await llama({
seed: 1,
-modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
+model: process.env["MODEL"] ?? DEFAULT_MODEL
});
