Add WebLLM example to gallery (#7265)
* Add WebLLM example
* Add to homepage index
1 parent 0ab5e8d · commit 9d4bfb9
Showing 2 changed files with 198 additions and 4 deletions.
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "04c6078a-0398-425c-82f7-d516b01b713d",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import panel as pn\n",
"import param\n",
"\n",
"from panel.custom import JSComponent, ESMEvent\n",
"\n",
"pn.extension(template='material')"
]
},
{
"cell_type": "markdown",
"id": "9aea88fd-3ae9-453f-a66e-b3bdc44c22e3",
"metadata": {},
"source": [
"This example demonstrates how to wrap an external library (specifically [WebLLM](https://github.com/mlc-ai/web-llm)) as a `JSComponent` and interface it with the `ChatInterface`."
]
},
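{
"cell_type": "markdown",
"id": "b7c1d2e3-f4a5-4b6c-8d7e-9f0a1b2c3d4e",
"metadata": {},
"source": [
"Before the full example, here is a minimal sketch of the bidirectional messaging pattern a `JSComponent` provides: Python posts a custom message with `_send_msg`, the ESM `render` function receives it via the `msg:custom` event, and `model.send_msg` on the JS side routes replies back to `_handle_msg` in Python. The `EchoJS` component below is purely illustrative and is not part of the committed gallery example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8d9e0f1-a2b3-4c5d-9e8f-7a6b5c4d3e2f",
"metadata": {},
"outputs": [],
"source": [
"class EchoJS(JSComponent):\n",
"\n",
"    # Last text echoed back from the browser\n",
"    last = param.String(default='')\n",
"\n",
"    _esm = \"\"\"\n",
"    export function render({ model }) {\n",
"      model.on(\"msg:custom\", (event) => {\n",
"        if (event.type === 'echo') {\n",
"          // Reply to Python with the same text\n",
"          model.send_msg({text: event.text})\n",
"        }\n",
"      })\n",
"    }\n",
"    \"\"\"\n",
"\n",
"    def _handle_msg(self, msg):\n",
"        self.last = msg['text']\n",
"\n",
"    def echo(self, text):\n",
"        self._send_msg({'type': 'echo', 'text': text})"
]
},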
{
"cell_type": "code",
"execution_count": null,
"id": "957c9fab-3fa7-48d7-83d0-5532bde6e547",
"metadata": {},
"outputs": [],
"source": [
"MODELS = {\n",
"    'Mistral-7b-Instruct': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',\n",
"    'SmolLM': 'SmolLM-360M-Instruct-q4f16_1-MLC',\n",
"    'Gemma-2b': 'gemma-2-2b-it-q4f16_1-MLC',\n",
"    'Llama-3.1-8b-Instruct': 'Llama-3.1-8B-Instruct-q4f32_1-MLC-1k'\n",
"}\n",
"\n",
"class WebLLM(JSComponent):\n",
"\n",
"    loaded = param.Boolean(default=False, doc=\"\"\"\n",
"        Whether the model is loaded.\"\"\")\n",
"\n",
"    model = param.Selector(default='SmolLM-360M-Instruct-q4f16_1-MLC', objects=MODELS)\n",
"\n",
"    temperature = param.Number(default=1, bounds=(0, 2))\n",
"\n",
"    load_model = param.Event()\n",
"\n",
"    _esm = \"\"\"\n",
"    import * as webllm from \"https://esm.run/@mlc-ai/web-llm\";\n",
"\n",
"    // Cache one engine per model so switching back does not re-download\n",
"    const engines = new Map()\n",
"\n",
"    export async function render({ model }) {\n",
"      model.on(\"msg:custom\", async (event) => {\n",
"        if (event.type === 'load') {\n",
"          if (!engines.has(model.model)) {\n",
"            engines.set(model.model, await webllm.CreateMLCEngine(model.model))\n",
"          }\n",
"          model.loaded = true\n",
"        } else if (event.type === 'completion') {\n",
"          const engine = engines.get(model.model)\n",
"          if (engine == null) {\n",
"            model.send_msg({'finish_reason': 'error'})\n",
"            return\n",
"          }\n",
"          // Stream each completion chunk back to Python\n",
"          const chunks = await engine.chat.completions.create({\n",
"            messages: event.messages,\n",
"            temperature: model.temperature,\n",
"            stream: true,\n",
"          })\n",
"          for await (const chunk of chunks) {\n",
"            model.send_msg(chunk.choices[0])\n",
"          }\n",
"        }\n",
"      })\n",
"    }\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, **params):\n",
"        super().__init__(**params)\n",
"        # Chunks streamed back from the browser are queued here\n",
"        self._buffer = []\n",
"\n",
"    @param.depends('load_model', watch=True)\n",
"    def _load_model(self):\n",
"        self.loading = True\n",
"        self._send_msg({'type': 'load'})\n",
"\n",
"    @param.depends('loaded', watch=True)\n",
"    def _loaded(self):\n",
"        self.loading = False\n",
"        self.param.load_model.constant = True\n",
"\n",
"    @param.depends('model', watch=True)\n",
"    def _update_load_model(self):\n",
"        self.param.load_model.constant = False\n",
"\n",
"    def _handle_msg(self, msg):\n",
"        self._buffer.insert(0, msg)\n",
"\n",
"    async def create_completion(self, msgs):\n",
"        self._send_msg({'type': 'completion', 'messages': msgs})\n",
"        while True:\n",
"            await asyncio.sleep(0.01)\n",
"            if not self._buffer:\n",
"                continue\n",
"            choice = self._buffer.pop()\n",
"            yield choice\n",
"            reason = choice['finish_reason']\n",
"            if reason == 'error':\n",
"                raise RuntimeError('Model not loaded')\n",
"            elif reason:\n",
"                return\n",
"\n",
"    async def callback(self, contents: str, user: str):\n",
"        if not self.loaded:\n",
"            if self.param.load_model.constant:\n",
"                yield f'Model `{self.model}` is loading.'\n",
"            else:\n",
"                yield 'Load the model first.'\n",
"            return\n",
"        message = ''\n",
"        async for chunk in self.create_completion([{'role': 'user', 'content': contents}]):\n",
"            message += chunk.get('delta', {}).get('content', '')\n",
"            yield message\n",
"\n",
"    def menu(self):\n",
"        return pn.Column(\n",
"            pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),\n",
"            pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),\n",
"            pn.widgets.Button.from_param(\n",
"                self.param.load_model, sizing_mode='stretch_width',\n",
"                loading=self.param.loading\n",
"            )\n",
"        )"
]
},
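{
"cell_type": "markdown",
"id": "d4e5f6a7-b8c9-4d0e-8f1a-2b3c4d5e6f7a",
"metadata": {},
"source": [
"A note on the streaming design: the browser pushes each completion chunk back with `model.send_msg`, which arrives in `_handle_msg` on the Python side and is queued in `self._buffer`. `create_completion` then polls that buffer in a short `asyncio.sleep` loop, yielding chunks until a `finish_reason` arrives, turning the browser-to-Python message stream into an async generator that `callback` can consume."
]
},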
{
"cell_type": "markdown",
"id": "a663e937-797b-468f-875d-5bb8c2af002b",
"metadata": {},
"source": [
"Having implemented the `WebLLM` component we can render the WebLLM UI:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58269444-868b-41e4-abe2-c4fcf031dc4b",
"metadata": {},
"outputs": [],
"source": [
"llm = WebLLM()\n",
"\n",
"pn.Column(llm.menu(), llm).servable(area='sidebar')"
]
},
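{
"cell_type": "markdown",
"id": "e5f6a7b8-c9d0-4e1f-9a2b-3c4d5e6f7a8b",
"metadata": {},
"source": [
"You can also drive the component directly, without the chat UI. The snippet below is a hypothetical usage sketch: it assumes you have already clicked the *Load model* button, otherwise `create_completion` raises `RuntimeError('Model not loaded')`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9-d0e1-4f2a-8b3c-4d5e6f7a8b9c",
"metadata": {},
"outputs": [],
"source": [
"async def demo_completion(prompt):\n",
"    # Stream chunks from the model and assemble the full reply\n",
"    message = ''\n",
"    async for choice in llm.create_completion([{'role': 'user', 'content': prompt}]):\n",
"        message += choice.get('delta', {}).get('content', '')\n",
"    return message\n",
"\n",
"# Uncomment to try it once the model has loaded:\n",
"# await demo_completion('Write a haiku about the browser.')"
]
},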
{
"cell_type": "markdown",
"id": "96229aa4-c5ed-4c4e-944a-789ee65d768f",
"metadata": {},
"source": [
"And connect it to a `ChatInterface`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f899068-8975-4cf4-9e1d-f3fdb5772a71",
"metadata": {},
"outputs": [],
"source": [
"chat_interface = pn.chat.ChatInterface(callback=llm.callback)\n",
"chat_interface.send(\n",
"    \"Load a model and start chatting.\",\n",
"    user=\"System\",\n",
"    respond=False,\n",
")\n",
"\n",
"chat_interface.servable(title='WebLLM')"
]
},
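{
"cell_type": "markdown",
"id": "a7b8c9d0-e1f2-4a3b-9c4d-5e6f7a8b9c0d",
"metadata": {},
"source": [
"To serve the app as a standalone dashboard, run e.g. `panel serve WebLLM.ipynb` (the filename is assumed here). The `servable(area='sidebar')` call places the model menu in the sidebar of the Material template, while the `ChatInterface` fills the main area."
]
}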
],
"metadata": {
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}