Add support for local mode via Ollama #6

Merged on Nov 11, 2024 (7 commits). Changes shown from 3 of 7 commits.
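Context for readers: "local mode" here means pointing the extension at any OpenAI-compatible completion server, which is how Ollama exposes locally pulled models. As a rough sketch of the server side this PR targets (the URL, default port 11434, and model name below describe a stock Ollama install and are illustrative, not part of this PR):

    # Ollama serves an OpenAI-compatible API under /v1 (default port 11434).
    import requests

    base_url = "http://localhost:11434"  # illustrative local Ollama instance

    # List installed models -- this is what the new ModelsHandler branch consumes.
    data = requests.get(f"{base_url}/v1/models").json()["data"]
    print([m["id"] for m in data])  # e.g. ['granite-code:8b'] if that model is pulled

    # Request a completion -- this is what the new PromptHandler branch calls.
    resp = requests.post(
        f"{base_url}/v1/completions",
        json={"model": "granite-code:8b", "prompt": "def hello():"},
    ).json()
    print(resp["choices"][0]["text"])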
221 changes: 159 additions & 62 deletions qiskit_code_assistant_jupyterlab/handlers.py
@@ -16,6 +16,7 @@
 
 import json
 import os
+from datetime import datetime
 from pathlib import Path
 
 import requests
@@ -24,7 +25,13 @@
 from jupyter_server.utils import url_path_join
 from qiskit_ibm_runtime import QiskitRuntimeService
 
-runtime_configs = {"service_url": "http://localhost", "api_token": ""}
+OPENAI_VERSION = "v1"
+
+runtime_configs = {
+    "service_url": "http://localhost",
+    "api_token": "",
+    "is_openai": False,
+}
 
 
 def update_token(token):
@@ -55,11 +62,26 @@ def init_token():
 
 
 def get_header():
-    return {
+    header = {
         "Accept": "application/json",
         "Content-Type": "application/json",
         "X-Caller": "qiskit-code-assistant-jupyterlab",
-        "Authorization": f"Bearer {runtime_configs['api_token']}",
     }
+    if not runtime_configs["is_openai"]:
+        header["Authorization"] = f"Bearer {runtime_configs['api_token']}"
+    return header
+
+
+def convert_openai(model):
+    return {
+        "_id": model["id"],
+        "disclaimer": {"accepted": "true"},
+        "display_name": model["id"],
+        "doc_link": "",
+        "license": {"name": "", "link": ""},
+        "model_id": model["id"],
+        "prompt_type": 1,
+        "token_limit": 255
+    }
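Note: convert_openai adapts one OpenAI-style model record into the model schema the extension already expects from the Qiskit Code Assistant service. A hypothetical round trip (the model id is illustrative):

    # One entry from GET /v1/models' "data" array (shape per the OpenAI API).
    ollama_model = {"id": "granite-code:8b", "object": "model", "owned_by": "library"}

    convert_openai(ollama_model)
    # => {"_id": "granite-code:8b",
    #     "disclaimer": {"accepted": "true"},   # local models skip the disclaimer flow
    #     "display_name": "granite-code:8b",
    #     "doc_link": "",
    #     "license": {"name": "", "link": ""},
    #     "model_id": "granite-code:8b",
    #     "prompt_type": 1,
    #     "token_limit": 255}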


@@ -74,13 +96,20 @@ def post(self):
 
         runtime_configs["service_url"] = json_payload["url"]
 
-        self.finish(json.dumps({"url": runtime_configs["service_url"]}))
+        try:
+            r = requests.get(url_path_join(runtime_configs["service_url"]), headers=get_header())
+            runtime_configs["is_openai"] = (r.json()["name"] != "qiskit-code-assistant")
+        except (requests.exceptions.JSONDecodeError, KeyError):
+            runtime_configs["is_openai"] = True
+        finally:
+            self.finish(json.dumps({"url": runtime_configs["service_url"]}))
 
 
 class TokenHandler(APIHandler):
     @tornado.web.authenticated
     def get(self):
-        self.finish(json.dumps({"success": (runtime_configs["api_token"] != "")}))
+        self.finish(json.dumps({"success": (runtime_configs["api_token"] != ""
+                                            or runtime_configs["is_openai"])}))
 
     @tornado.web.authenticated
     def post(self):
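Note: the probe above decides which mode every other handler runs in. The Qiskit Code Assistant service identifies itself with a JSON body whose "name" is "qiskit-code-assistant"; anything else flips is_openai to True (for example, Ollama's root endpoint answers with the plain text "Ollama is running", so JSON decoding fails). That same flag is why TokenHandler.get now reports success without a token: local servers need no API token. A standalone sketch of the same logic (the function name is ours, not the PR's):

    import requests

    def detect_is_openai(service_url: str) -> bool:
        try:
            r = requests.get(service_url)
            # The Qiskit Code Assistant service returns {"name": "qiskit-code-assistant", ...}
            return r.json()["name"] != "qiskit-code-assistant"
        except (requests.exceptions.JSONDecodeError, KeyError):
            # Non-JSON or differently shaped response: assume an
            # OpenAI-compatible server such as Ollama.
            return True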
@@ -94,98 +123,166 @@ def post(self):
 class ModelsHandler(APIHandler):
     @tornado.web.authenticated
     def get(self):
-        url = url_path_join(runtime_configs["service_url"], "models")
-
-        try:
-            r = requests.get(url, headers=get_header())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "models")
+            models = []
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+
+                if r.ok:
+                    data = r.json()["data"]
+                    models = list(map(convert_openai, data))
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps({"models": models}))
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(runtime_configs["service_url"], "models")
+
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
 
 
 class ModelHandler(APIHandler):
     @tornado.web.authenticated
     def get(self, id):
-        url = url_path_join(runtime_configs["service_url"], "model", id)
-
-        try:
-            r = requests.get(url, headers=get_header())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "models", id)
+            model = {}
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+
+                if r.ok:
+                    model = convert_openai(r.json())
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(model))
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(runtime_configs["service_url"], "model", id)
+
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
 
 
 class DisclaimerHandler(APIHandler):
     @tornado.web.authenticated
     def get(self, id):
-        url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")
-
-        try:
-            r = requests.get(url, headers=get_header())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            self.set_status(501, "Not implemented")
+            self.finish()
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")
+
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
 
 
 class DisclaimerAcceptanceHandler(APIHandler):
     @tornado.web.authenticated
     def post(self, id):
-        url = url_path_join(
-            runtime_configs["service_url"], "disclaimer", id, "acceptance"
-        )
-
-        try:
-            r = requests.post(url, headers=get_header(), json=self.get_json_body())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            self.set_status(501, "Not implemented")
+            self.finish()
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(
+                runtime_configs["service_url"], "disclaimer", id, "acceptance"
+            )
+
+            try:
+                r = requests.post(url, headers=get_header(), json=self.get_json_body())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
 
 
 class PromptHandler(APIHandler):
     @tornado.web.authenticated
     def post(self, id):
-        url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")
-
-        try:
-            r = requests.post(url, headers=get_header(), json=self.get_json_body())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "completions")
+            result = {}
+            try:
+                r = requests.post(url,
+                                  headers=get_header(),
+                                  json={
+                                      "model": id,
+                                      "prompt": self.get_json_body()["input"]
+                                  })
+                r.raise_for_status()
+
+                if r.ok:
+                    response = r.json()
+                    result = {
+                        "results": list(map(lambda c: {"generated_text": c["text"]},
+                                            response["choices"])),
+                        "prompt_id": response["id"],
+                        "created_at": datetime.fromtimestamp(int(response["created"])).isoformat()
+                    }
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(result))
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")
+
+            try:
+                r = requests.post(url, headers=get_header(), json=self.get_json_body())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
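Note: the OpenAI branch above reshapes a v1/completions response into the result format the frontend already understands. A worked example of that mapping (response values are invented):

    from datetime import datetime

    response = {  # hypothetical OpenAI-style completion response
        "id": "cmpl-123",
        "created": 1731283200,
        "choices": [{"index": 0, "text": "    return 'hello'"}],
    }

    result = {
        "results": [{"generated_text": c["text"]} for c in response["choices"]],
        "prompt_id": response["id"],
        "created_at": datetime.fromtimestamp(int(response["created"])).isoformat(),
    }
    # result["results"]    == [{"generated_text": "    return 'hello'"}]
    # result["prompt_id"]  == "cmpl-123"
    # result["created_at"] is an ISO timestamp in local time, e.g. "2024-11-11T..."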


 class PromptAcceptanceHandler(APIHandler):
     @tornado.web.authenticated
     def post(self, id):
-        url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")
-
-        try:
-            r = requests.post(url, headers=get_header(), json=self.get_json_body())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_openai"]:
+            self.finish(json.dumps({"success": "true"}))
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")
+
+            try:
+                r = requests.post(url, headers=get_header(), json=self.get_json_body())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))
 
 
 def setup_handlers(web_app):
     host_pattern = ".*$"
-    id_regex = r"(?P<id>[\w\-]+)"
+    id_regex = r"(?P<id>[\w\-\_\.\:]+)"  # valid chars: alphanum | "-" | "_" | "." | ":"
     base_url = url_path_join(web_app.settings["base_url"], "qiskit-code-assistant")
 
     handlers = [
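Note: the widened id_regex matters for local mode because Ollama model ids carry tags with dots and colons (for example "granite-code:8b"), which the old pattern rejected. A quick check:

    import re

    OLD = r"(?P<id>[\w\-]+)"
    NEW = r"(?P<id>[\w\-\_\.\:]+)"

    re.fullmatch(OLD, "granite-code:8b")  # None -- ":" was not allowed
    re.fullmatch(NEW, "granite-code:8b")  # <re.Match ...> -- now routes to the handlers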
4 changes: 3 additions & 1 deletion src/index.ts

@@ -68,7 +68,9 @@ const plugin: JupyterFrontEndPlugin<void> = {
 
     postServiceUrl(settings.composite['serviceUrl'] as string);
     settings.changed.connect(() =>
-      postServiceUrl(settings.composite['serviceUrl'] as string)
+      postServiceUrl(settings.composite['serviceUrl'] as string).then(() =>
+        refreshModelsList()
+      )
     );
 
     const provider = new QiskitCompletionProvider({ settings });
28 changes: 12 additions & 16 deletions src/service/autocomplete.ts

@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import { getModel, postModelPrompt } from './api';
+import { postModelPrompt } from './api';
 import { showDisclaimer } from './disclaimer';
 import { getCurrentModel } from './modelHandler';
 import { checkAPIToken } from './token';
@@ -51,25 +51,21 @@ export async function autoComplete(text: string): Promise<ICompletionReturn> {
       const requestText = text.slice(startingOffset, text.length);
       const model = getCurrentModel();
 
-      return await getModel(model?._id || '')
-        .then(async model => {
-          if (model.disclaimer?.accepted) {
+      if (model === undefined) {
+        console.error('Failed to send prompt', 'No model selected');
+        return emptyReturn;
+      } else if (model.disclaimer?.accepted) {
+        return await promptPromise(model._id, requestText);
+      } else {
+        return await showDisclaimer(model._id).then(async accepted => {
+          if (accepted) {
             return await promptPromise(model._id, requestText);
           } else {
-            return await showDisclaimer(model._id).then(async accepted => {
-              if (accepted) {
-                return await promptPromise(model._id, requestText);
-              } else {
-                console.error('Disclaimer not accepted');
-                return emptyReturn;
-              }
-            });
+            console.error('Disclaimer not accepted');
+            return emptyReturn;
           }
-        })
-        .catch(reason => {
-          console.error('Failed to send prompt', reason);
-          return emptyReturn;
         });
+      }
     })
     .catch(reason => {
       console.error('Failed to send prompt', reason);