Skip to content

Commit

Permalink
fix main example
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Dec 1, 2024
1 parent c2d1f83 commit 5e1df82
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 162 deletions.
10 changes: 5 additions & 5 deletions examples/main/src/components/ChatScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export default function ChatScreen() {
isGenerating,
createCompletion,
navigateTo,
currModel,
loadedModel,
getWllamaInstance,
stopCompletion,
} = useWllama();
Expand Down Expand Up @@ -64,8 +64,8 @@ export default function ChatScreen() {
}

// generate response
if (!currModel) {
throw new Error('currModel is null');
if (!loadedModel) {
throw new Error('loadedModel is null');
}
const formattedChat = await formatChat(getWllamaInstance(), [
...currHistory,
Expand Down Expand Up @@ -118,7 +118,7 @@ export default function ChatScreen() {
</div>
)}

{currModel && (
{loadedModel && (
<textarea
className="textarea textarea-bordered w-full"
placeholder="Your message..."
Expand All @@ -134,7 +134,7 @@ export default function ChatScreen() {
/>
)}

{!currModel && <WarnNoModel />}
{!loadedModel && <WarnNoModel />}

<small className="text-center mx-auto opacity-70 pt-2">
wllama may generate inaccurate information. Use with your own risk.
Expand Down
106 changes: 85 additions & 21 deletions examples/main/src/components/ModelScreen.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ManageModel, ModelState, Screen } from '../utils/types';
import { ModelState, Screen } from '../utils/types';
import { useWllama } from '../utils/wllama.context';
import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import {
Expand All @@ -8,23 +8,24 @@ import {
faCheck,
} from '@fortawesome/free-solid-svg-icons';
import { DEFAULT_INFERENCE_PARAMS, MAX_GGUF_SIZE } from '../config';
import { toHumanReadableSize } from '../utils/utils';
import { useState } from 'react';
import { toHumanReadableSize, useDebounce } from '../utils/utils';
import { useEffect, useState } from 'react';
import ScreenWrapper from './ScreenWrapper';
import { DisplayedModel } from '../utils/displayed-model';

export default function ModelScreen() {
const [showAddCustom, setShowAddCustom] = useState(false);
const {
models,
removeModel,
removeCachedModel,
isLoadingModel,
isDownloading,
currModel,
loadedModel,
currParams,
setParams,
} = useWllama();

const blockModelBtn = !!(currModel || isDownloading || isLoadingModel);
const blockModelBtn = !!(loadedModel || isDownloading || isLoadingModel);

const onChange = (key: keyof typeof currParams) => (e: any) => {
setParams({ ...currParams, [key]: parseFloat(e.target.value || -1) });
Expand Down Expand Up @@ -101,7 +102,7 @@ export default function ModelScreen() {
)
) {
for (const m of models) {
await removeModel(m);
await removeCachedModel(m);
}
}
}}
Expand All @@ -123,7 +124,7 @@ export default function ModelScreen() {
</h1>

{models
.filter((m) => m.userAdded)
.filter((m) => m.isUserAdded)
.map((m) => (
<ModelCard key={m.url} model={m} blockModelBtn={blockModelBtn} />
))}
Expand All @@ -133,7 +134,7 @@ export default function ModelScreen() {
<h1 className="text-2xl mt-6 mb-4">Recommended models</h1>

{models
.filter((m) => !m.userAdded)
.filter((m) => !m.isUserAdded)
.map((m) => (
<ModelCard key={m.url} model={m} blockModelBtn={blockModelBtn} />
))}
Expand All @@ -150,12 +151,58 @@ export default function ModelScreen() {

function AddCustomModelDialog({ onClose }: { onClose(): void }) {
const { isLoadingModel, addCustomModel } = useWllama();
const [url, setUrl] = useState<string>('');
const [hfRepo, setHfRepo] = useState<string>('');
const [hfFile, setHfFile] = useState<string>('');
const [hfFiles, setHfFiles] = useState<string[]>([]);
const [abortSignal, setAbortSignal] = useState<AbortController>(
new AbortController()
);
const [err, setErr] = useState<string>();

useDebounce(
async () => {
if (hfRepo.length < 2) {
setHfFiles([]);
return;
}
try {
const res = await fetch(`https://huggingface.co/api/models/${hfRepo}`, {
signal: abortSignal.signal,
});
const data: { siblings?: { rfilename: string }[] } = await res.json();
if (data.siblings) {
setHfFiles(
data.siblings
.map((s) => s.rfilename)
.filter((f) => f.endsWith('.gguf'))
);
setErr('');
} else {
setErr('no model found or it is private');
setHfFiles([]);
}
} catch (e) {
if ((e as Error).name !== 'AbortError') {
setErr((e as any)?.message ?? 'unknown error');
setHfFiles([]);
}
}
},
[hfRepo],
500
);

useEffect(() => {
if (hfFiles.length === 0) {
setHfFile('');
}
}, [hfFiles]);

const onSubmit = async () => {
try {
await addCustomModel(url);
await addCustomModel(
`https://huggingface.co/${hfRepo}/resolve/main/${hfFile}`
);
onClose();
} catch (e) {
setErr((e as any)?.message ?? 'unknown error');
Expand All @@ -180,21 +227,36 @@ function AddCustomModelDialog({ onClose }: { onClose(): void }) {
</div>
<div className="mt-4">
<label className="input input-bordered flex items-center gap-2 mb-2">
URL
HF repo
<input
type="text"
className="grow"
placeholder="https://example.com/your_model-00001-of-00XXX.gguf"
value={url}
onChange={(e) => setUrl(e.target.value)}
placeholder="{username}/{repo}"
value={hfRepo}
onChange={(e) => {
abortSignal.abort();
setHfRepo(e.target.value);
setAbortSignal(new AbortController());
}}
/>
</label>
<select
className="select select-bordered w-full"
onChange={(e) => setHfFile(e.target.value)}
>
<option value="">Select a model file</option>
{hfFiles.map((f) => (
<option key={f} value={f}>
{f}
</option>
))}
</select>
</div>
{err && <div className="mt-4 text-error">Error: {err}</div>}
<div className="modal-action">
<button
className="btn btn-primary"
disabled={isLoadingModel || url.length < 5}
disabled={isLoadingModel || hfRepo.length < 2 || hfFile.length < 5}
onClick={onSubmit}
>
{isLoadingModel && (
Expand All @@ -215,12 +277,12 @@ function ModelCard({
model,
blockModelBtn,
}: {
model: ManageModel;
model: DisplayedModel;
blockModelBtn: boolean;
}) {
const {
downloadModel,
removeModel,
removeCachedModel,
loadModel,
unloadModel,
removeCustomModel,
Expand All @@ -237,9 +299,11 @@ function ModelCard({
>
<div className="card-body p-4 flex flex-row">
<div className="grow">
<b>{m.name}</b>
<b>{m.hfPath.replace(/-\d{5}-of-\d{5}/, '-(shards)')}</b>
<br />
<small>
HF repo: {m.hfModel}
<br />
Size: {toHumanReadableSize(m.size)}
{m.size > MAX_GGUF_SIZE && (
<div
Expand Down Expand Up @@ -312,7 +376,7 @@ function ModelCard({
if (
confirm('Are you sure to remove this model from cache?')
) {
removeModel(m);
removeCachedModel(m);
}
}}
disabled={blockModelBtn}
Expand All @@ -337,7 +401,7 @@ function ModelCard({
</button>
</>
)}
{m.state === ModelState.NOT_DOWNLOADED && m.userAdded && (
{m.state === ModelState.NOT_DOWNLOADED && m.isUserAdded && (
<button
className="btn btn-outline btn-error btn-sm mr-2"
onClick={() => {
Expand Down
8 changes: 5 additions & 3 deletions examples/main/src/components/Sidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
import { WLLAMA_VERSION } from '../config';

export default function Sidebar({ children }: { children: any }) {
const { currentConvId, navigateTo, currScreen, currModel } = useWllama();
const { currentConvId, navigateTo, currScreen, loadedModel } = useWllama();
const { conversations, getConversationById, deleteConversation } =
useMessages();

Expand Down Expand Up @@ -75,8 +75,10 @@ export default function Sidebar({ children }: { children: any }) {
<div className="w-80 px-4 pt-0 pb-8">
<div className="divider my-2"></div>

{currModel && (
<div className="text-sm px-4 pb-2">Model: {currModel.name}</div>
{loadedModel && (
<div className="text-sm px-4 pb-2">
Model: {loadedModel.hfModel}
</div>
)}

<ul className="menu gap-1">
Expand Down
9 changes: 2 additions & 7 deletions examples/main/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
import wllamaPackageJson from '@wllama/wllama/package.json';
import { InferenceParams, Model } from './utils/types';
import { InferenceParams } from './utils/types';

export const WLLAMA_VERSION = wllamaPackageJson.version;

Expand All @@ -13,12 +13,7 @@ export const WLLAMA_CONFIG_PATHS = {

export const MAX_GGUF_SIZE = 2 * 1024 * 1024 * 1024; // 2GB

export const LIST_MODELS: Model[] = [
// FIXME: chat template for tinyllama is broken
// {
// url: 'https://huggingface.co/ngxson/wllama-split-models/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M-00001-of-00003.gguf',
// size: 668788416,
// },
export const LIST_MODELS = [
{
url: 'https://huggingface.co/ngxson/SmolLM2-360M-Instruct-Q8_0-GGUF/resolve/main/smollm2-360m-instruct-q8_0.gguf',
size: 386404992,
Expand Down
12 changes: 4 additions & 8 deletions examples/main/src/utils/custom-models.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { MAX_GGUF_SIZE } from '../config';
import { Model } from './types';
import { DisplayedModel } from './displayed-model';
import { WllamaStorage } from './utils';

const ggufMagicNumber = new Uint8Array([0x47, 0x47, 0x55, 0x46]);

export async function verifyCustomModel(url: string): Promise<Model> {
export async function verifyCustomModel(url: string): Promise<DisplayedModel> {
const _url = url.replace(/\?.*/, '');

const response = await fetch(_url, {
Expand All @@ -24,11 +24,7 @@ export async function verifyCustomModel(url: string): Promise<Model> {
throw new Error(`Fetch error with status code = ${response.status}`);
}

return {
url: _url,
size: await getModelSize(_url),
userAdded: true,
};
return new DisplayedModel(_url, await getModelSize(_url), true, undefined);
}

const checkBuffer = (buffer: Uint8Array, header: Uint8Array) => {
Expand Down Expand Up @@ -92,7 +88,7 @@ const sumArr = (arr: number[]) => arr.reduce((sum, num) => sum + num, 0);
// for debugging only
// @ts-ignore
window._exportModelList = function () {
const list: Model[] = WllamaStorage.load('custom_models', []);
const list: any[] = WllamaStorage.load('custom_models', []);
const listExported = list.map((m) => {
delete m.userAdded;
return m;
Expand Down
Loading

0 comments on commit 5e1df82

Please sign in to comment.