@@ -11,22 +11,41 @@ import {
11
11
InitProgressReport ,
12
12
CreateMLCEngine ,
13
13
ChatCompletionMessageParam ,
14
+ prebuiltAppConfig ,
14
15
} from "@mlc-ai/web-llm" ;
15
16
import { ProgressBar , Line } from "progressbar.js" ;
16
17
18
+ // modified setLabel to not throw error
19
+ function setLabel ( id : string , text : string ) {
20
+ const label = document . getElementById ( id ) ;
21
+ if ( label != null ) {
22
+ label . innerText = text ;
23
+ }
24
+ }
25
+
26
+ function getElementAndCheck ( id : string ) : HTMLElement {
27
+ const element = document . getElementById ( id ) ;
28
+ if ( element == null ) {
29
+ throw Error ( "Cannot find element " + id ) ;
30
+ }
31
+ return element ;
32
+ }
33
+
17
34
const sleep = ( ms : number ) => new Promise ( ( r ) => setTimeout ( r , ms ) ) ;
18
35
19
- const queryInput = document . getElementById ( "query-input" ) ! ;
20
- const submitButton = document . getElementById ( "submit-button" ) ! ;
36
+ const queryInput = getElementAndCheck ( "query-input" ) ! ;
37
+ const submitButton = getElementAndCheck ( "submit-button" ) ! ;
38
+ const modelName = getElementAndCheck ( "model-name" ) ;
21
39
22
40
let context = "" ;
23
- let isLoadingParams = false ;
41
+ let modelDisplayName = "" ;
24
42
43
+ // throws runtime.lastError if you refresh extension AND try to access a webpage that is already open
25
44
fetchPageContents ( ) ;
26
45
27
46
( < HTMLButtonElement > submitButton ) . disabled = true ;
28
47
29
- const progressBar : ProgressBar = new Line ( "#loadingContainer" , {
48
+ let progressBar : ProgressBar = new Line ( "#loadingContainer" , {
30
49
strokeWidth : 4 ,
31
50
easing : "easeInOut" ,
32
51
duration : 1400 ,
@@ -36,8 +55,10 @@ const progressBar: ProgressBar = new Line("#loadingContainer", {
36
55
svgStyle : { width : "100%" , height : "100%" } ,
37
56
} ) ;
38
57
39
- const initProgressCallback = ( report : InitProgressReport ) => {
40
- console . log ( report . text , report . progress ) ;
58
+ let isLoadingParams = true ;
59
+
60
+ let initProgressCallback = ( report : InitProgressReport ) => {
61
+ setLabel ( "init-label" , report . text ) ;
41
62
progressBar . animate ( report . progress , {
42
63
duration : 50 ,
43
64
} ) ;
@@ -46,29 +67,67 @@ const initProgressCallback = (report: InitProgressReport) => {
46
67
}
47
68
} ;
48
69
49
- // const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k";
50
- const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1-MLC" ;
70
+ // initially selected model
71
+ let selectedModel = "Qwen2-0.5B-Instruct-q4f16_1-MLC" ;
72
+
73
+ // populate model-selection
74
+ const modelSelector = getElementAndCheck (
75
+ "model-selection" ,
76
+ ) as HTMLSelectElement ;
77
+ for ( let i = 0 ; i < prebuiltAppConfig . model_list . length ; ++ i ) {
78
+ const model = prebuiltAppConfig . model_list [ i ] ;
79
+ const opt = document . createElement ( "option" ) ;
80
+ opt . value = model . model_id ;
81
+ opt . innerHTML = model . model_id ;
82
+ opt . selected = false ;
83
+
84
+ // set initial selection as the initially selected model
85
+ if ( model . model_id == selectedModel ) {
86
+ opt . selected = true ;
87
+ }
88
+
89
+ modelSelector . appendChild ( opt ) ;
90
+ }
91
+
92
+ modelName . innerText = "Loading initial model..." ;
51
93
const engine : MLCEngineInterface = await CreateMLCEngine ( selectedModel , {
52
94
initProgressCallback : initProgressCallback ,
53
95
} ) ;
54
- const chatHistory : ChatCompletionMessageParam [ ] = [ ] ;
96
+ modelName . innerText = "Now chatting with " + modelDisplayName ;
55
97
56
- isLoadingParams = true ;
98
+ let chatHistory : ChatCompletionMessageParam [ ] = [ ] ;
57
99
58
100
function enableInputs ( ) {
59
101
if ( isLoadingParams ) {
60
102
sleep ( 500 ) ;
61
- ( < HTMLButtonElement > submitButton ) . disabled = false ;
62
- const loadingBarContainer = document . getElementById ( "loadingContainer" ) ! ;
63
- loadingBarContainer . remove ( ) ;
64
- queryInput . focus ( ) ;
65
103
isLoadingParams = false ;
66
104
}
105
+
106
+ // remove loading bar and loading bar descriptors, if exists
107
+ const initLabel = document . getElementById ( "init-label" ) ;
108
+ initLabel ?. remove ( ) ;
109
+ const loadingBarContainer = document . getElementById ( "loadingContainer" ) ! ;
110
+ loadingBarContainer ?. remove ( ) ;
111
+ queryInput . focus ( ) ;
112
+
113
+ const modelNameArray = selectedModel . split ( "-" ) ;
114
+ modelDisplayName = modelNameArray [ 0 ] ;
115
+ let j = 1 ;
116
+ while ( j < modelNameArray . length && modelNameArray [ j ] [ 0 ] != "q" ) {
117
+ modelDisplayName = modelDisplayName + "-" + modelNameArray [ j ] ;
118
+ j ++ ;
119
+ }
67
120
}
68
121
122
+ let requestInProgress = false ;
123
+
69
124
// Disable submit button if input field is empty
70
125
queryInput . addEventListener ( "keyup" , ( ) => {
71
- if ( ( < HTMLInputElement > queryInput ) . value === "" ) {
126
+ if (
127
+ ( < HTMLInputElement > queryInput ) . value === "" ||
128
+ requestInProgress ||
129
+ isLoadingParams
130
+ ) {
72
131
( < HTMLButtonElement > submitButton ) . disabled = true ;
73
132
} else {
74
133
( < HTMLButtonElement > submitButton ) . disabled = false ;
@@ -85,6 +144,9 @@ queryInput.addEventListener("keyup", (event) => {
85
144
86
145
// Listen for clicks on submit button
87
146
async function handleClick ( ) {
147
+ requestInProgress = true ;
148
+ ( < HTMLButtonElement > submitButton ) . disabled = true ;
149
+
88
150
// Get the message from the input field
89
151
const message = ( < HTMLInputElement > queryInput ) . value ;
90
152
console . log ( "message" , message ) ;
@@ -123,9 +185,72 @@ async function handleClick() {
123
185
const response = await engine . getMessage ( ) ;
124
186
chatHistory . push ( { role : "assistant" , content : await engine . getMessage ( ) } ) ;
125
187
console . log ( "response" , response ) ;
188
+
189
+ requestInProgress = false ;
190
+ ( < HTMLButtonElement > submitButton ) . disabled = false ;
126
191
}
127
192
submitButton . addEventListener ( "click" , handleClick ) ;
128
193
194
+ // listen for changes in modelSelector
195
+ async function handleSelectChange ( ) {
196
+ if ( isLoadingParams ) {
197
+ return ;
198
+ }
199
+
200
+ modelName . innerText = "" ;
201
+
202
+ const initLabel = document . createElement ( "p" ) ;
203
+ initLabel . id = "init-label" ;
204
+ initLabel . innerText = "Initializing model..." ;
205
+ const loadingContainer = document . createElement ( "div" ) ;
206
+ loadingContainer . id = "loadingContainer" ;
207
+
208
+ const loadingBox = getElementAndCheck ( "loadingBox" ) ;
209
+ loadingBox . appendChild ( initLabel ) ;
210
+ loadingBox . appendChild ( loadingContainer ) ;
211
+
212
+ isLoadingParams = true ;
213
+ ( < HTMLButtonElement > submitButton ) . disabled = true ;
214
+
215
+ if ( requestInProgress ) {
216
+ engine . interruptGenerate ( ) ;
217
+ }
218
+ engine . resetChat ( ) ;
219
+ chatHistory = [ ] ;
220
+ await engine . unload ( ) ;
221
+
222
+ selectedModel = modelSelector . value ;
223
+
224
+ progressBar = new Line ( "#loadingContainer" , {
225
+ strokeWidth : 4 ,
226
+ easing : "easeInOut" ,
227
+ duration : 1400 ,
228
+ color : "#ffd166" ,
229
+ trailColor : "#eee" ,
230
+ trailWidth : 1 ,
231
+ svgStyle : { width : "100%" , height : "100%" } ,
232
+ } ) ;
233
+
234
+ initProgressCallback = ( report : InitProgressReport ) => {
235
+ setLabel ( "init-label" , report . text ) ;
236
+ progressBar . animate ( report . progress , {
237
+ duration : 50 ,
238
+ } ) ;
239
+ if ( report . progress == 1.0 ) {
240
+ enableInputs ( ) ;
241
+ }
242
+ } ;
243
+
244
+ engine . setInitProgressCallback ( initProgressCallback ) ;
245
+
246
+ requestInProgress = true ;
247
+ modelName . innerText = "Reloading with new model..." ;
248
+ await engine . reload ( selectedModel ) ;
249
+ requestInProgress = false ;
250
+ modelName . innerText = "Now chatting with " + modelDisplayName ;
251
+ }
252
+ modelSelector . addEventListener ( "change" , handleSelectChange ) ;
253
+
129
254
// Listen for messages from the background script
130
255
chrome . runtime . onMessage . addListener ( ( { answer, error } ) => {
131
256
if ( answer ) {
0 commit comments