Skip to content

Commit

Permalink
Add support for local mode via Ollama (#6)
Browse files Browse the repository at this point in the history
Adds support for running against local models by supporting the OpenAI API
in addition to the Qiskit Code Assistant service API.

This allows users to input any OpenAI compatible API URL, such as Ollama,
instead of a Qiskit Code Assistant service URL and the server extension will
detect which API is set and call the correct endpoints.
  • Loading branch information
ajbozarth authored Nov 11, 2024
1 parent 2f3a7be commit 8b7d3a7
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 85 deletions.
7 changes: 5 additions & 2 deletions GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
## Requirements

- JupyterLab >= 4.3.0
- An IBM Quantum premium account
- Access to either:
- An IBM Quantum premium account
- A service exposing LLMs using OpenAI-compatible API endpoints

## Install

Expand Down Expand Up @@ -111,7 +113,8 @@ There are a few settings we recommend to edit in your user settings.
`Tab`, the inline completer has a default of 10 seconds.

3. If you want to change the instance of the Qiskit Code Assistant Service that the
extension should use you can edit the Qiskit Code Assistant setting `serviceUrl`
extension should use you can edit the Qiskit Code Assistant setting `serviceUrl`.
This can also be set to any service exposing LLMs using OpenAI-compatible API endpoints.

4. Keyboard shortcuts can be changed by searching for `completer` in the Keyboard Shortcuts
settings and adding new shortcuts for the relevant commands.
Expand Down
5 changes: 2 additions & 3 deletions README-PyPi.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Qiskit Code Assistant (Beta)

> This experimental feature is only available, as of today, to IBM Quantum premium users.
> If you are not part of the IBM Quantum premium plan, you can still install this extension; however you will not be able to use the assistant.
> The Qiskit Code Assistant is a beta release, subject to change.
Write and optimize Qiskit code with a generative AI code assistant.
Expand Down Expand Up @@ -117,7 +115,8 @@ There are a few settings we recommend to edit in your user settings.
`Tab`, the inline completer has a default of 10 seconds.

3. If you want to change the instance of the Qiskit Code Assistant Service that the
extension should use you can edit the Qiskit Code Assistant setting `serviceUrl`
extension should use you can edit the Qiskit Code Assistant setting `serviceUrl`.
This can also be set to any service exposing LLMs using OpenAI-compatible API endpoints.

4. Keyboard shortcuts can be changed by searching for `completer` in the Keyboard Shortcuts
settings and adding new shortcuts for the relevant commands.
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ for the frontend extension.
## Requirements

- JupyterLab >= 4.3.0
- An IBM Quantum premium account
- Access to either:
- An IBM Quantum premium account
- A service exposing LLMs using OpenAI-compatible API endpoints

## Install

Expand Down
221 changes: 159 additions & 62 deletions qiskit_code_assistant_jupyterlab/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
import os
from datetime import datetime
from pathlib import Path

import requests
Expand All @@ -24,7 +25,13 @@
from jupyter_server.utils import url_path_join
from qiskit_ibm_runtime import QiskitRuntimeService

runtime_configs = {"service_url": "http://localhost", "api_token": ""}
OPENAI_VERSION = "v1"

runtime_configs = {
"service_url": "http://localhost",
"api_token": "",
"is_openai": False,
}


def update_token(token):
Expand Down Expand Up @@ -55,11 +62,26 @@ def init_token():


def get_header():
return {
header = {
"Accept": "application/json",
"Content-Type": "application/json",
"X-Caller": "qiskit-code-assistant-jupyterlab",
"Authorization": f"Bearer {runtime_configs['api_token']}",
}
if not runtime_configs["is_openai"]:
header["Authorization"] = f"Bearer {runtime_configs['api_token']}"
return header


def convert_openai(model):
return {
"_id": model["id"],
"disclaimer": {"accepted": "true"},
"display_name": model["id"],
"doc_link": "",
"license": {"name": "", "link": ""},
"model_id": model["id"],
"prompt_type": 1,
"token_limit": 255
}


Expand All @@ -74,13 +96,20 @@ def post(self):

runtime_configs["service_url"] = json_payload["url"]

self.finish(json.dumps({"url": runtime_configs["service_url"]}))
try:
r = requests.get(url_path_join(runtime_configs["service_url"]), headers=get_header())
runtime_configs["is_openai"] = (r.json()["name"] != "qiskit-code-assistant")
except (requests.exceptions.JSONDecodeError, KeyError):
runtime_configs["is_openai"] = True
finally:
self.finish(json.dumps({"url": runtime_configs["service_url"]}))


class TokenHandler(APIHandler):
@tornado.web.authenticated
def get(self):
self.finish(json.dumps({"success": (runtime_configs["api_token"] != "")}))
self.finish(json.dumps({"success": (runtime_configs["api_token"] != ""
or runtime_configs["is_openai"])}))

@tornado.web.authenticated
def post(self):
Expand All @@ -94,93 +123,161 @@ def post(self):
class ModelsHandler(APIHandler):
@tornado.web.authenticated
def get(self):
url = url_path_join(runtime_configs["service_url"], "models")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "models")
models = []
try:
r = requests.get(url, headers=get_header())
r.raise_for_status()

if r.ok:
data = r.json()["data"]
models = list(map(convert_openai, data))
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps({"models": models}))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "models")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class ModelHandler(APIHandler):
@tornado.web.authenticated
def get(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id)

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "models", id)
model = {}
try:
r = requests.get(url, headers=get_header())
r.raise_for_status()

if r.ok:
model = convert_openai(r.json())
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(model))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id)

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class DisclaimerHandler(APIHandler):
@tornado.web.authenticated
def get(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
self.set_status(501, "Not implemented")
self.finish()
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class DisclaimerAcceptanceHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
url = url_path_join(
runtime_configs["service_url"], "disclaimer", id, "acceptance"
)

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
self.set_status(501, "Not implemented")
self.finish()
else:
self.finish(json.dumps(r.json()))
url = url_path_join(
runtime_configs["service_url"], "disclaimer", id, "acceptance"
)

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class PromptHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
url = url_path_join(runtime_configs["service_url"], OPENAI_VERSION, "completions")
result = {}
try:
r = requests.post(url,
headers=get_header(),
json={
"model": id,
"prompt": self.get_json_body()["input"]
})
r.raise_for_status()

if r.ok:
response = r.json()
result = {
"results": list(map(lambda c: {"generated_text": c["text"]},
response["choices"])),
"prompt_id": response["id"],
"created_at": datetime.fromtimestamp(int(response["created"])).isoformat()
}
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(result))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class PromptAcceptanceHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_openai"]:
self.finish(json.dumps({"success": "true"}))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class FeedbackHandler(APIHandler):
Expand All @@ -200,7 +297,7 @@ def post(self):

def setup_handlers(web_app):
host_pattern = ".*$"
id_regex = r"(?P<id>[\w\-]+)"
id_regex = r"(?P<id>[\w\-\_\.\:]+)" # valid chars: alphanum | "-" | "_" | "." | ":"
base_url = url_path_join(web_app.settings["base_url"], "qiskit-code-assistant")

handlers = [
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ const plugin: JupyterFrontEndPlugin<void> = {

postServiceUrl(settings.composite['serviceUrl'] as string);
settings.changed.connect(() =>
postServiceUrl(settings.composite['serviceUrl'] as string)
postServiceUrl(settings.composite['serviceUrl'] as string).then(() =>
refreshModelsList()
)
);

const provider = new QiskitCompletionProvider({ settings, app });
Expand Down
Loading

0 comments on commit 8b7d3a7

Please sign in to comment.