Skip to content

Commit

Permalink
add simple completions endpoint (#125)
Browse files Browse the repository at this point in the history
* add simple completions endpoint

* add llama example

* lint

* add usage message
  • Loading branch information
justin1121 authored Sep 26, 2023
1 parent cdeb9d6 commit 8eaa845
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
20 changes: 20 additions & 0 deletions examples/llama/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

from pycape.llms import Cape

c = Cape(url="https://api.capeprivacy.com")

token = c.token(os.getenv("CAPE_TOKEN", ""))

for msg in c.chat_completions(
[
{
"role": "user",
"content": "<s>[INST] <<SYS>>You are a helpful Assistant.<</SYS>>"
"\n\nWhat is the Capital of France? [/INST]",
},
{"role": "system", "content": "you are a happy helpful assistant"},
],
token,
):
print(msg)
62 changes: 58 additions & 4 deletions pycape/llms/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,53 @@ def token(self, token: Union[str, os.PathLike, tkn.Token]) -> tkn.Token:

raise TypeError(f"Expected token to be PathLike or str, found {type(token)}")

async def completions(
self,
prompt: str,
token: str,
stream=True,
model="llama",
pcrs=None,
):
await self._connect("/v1/cape/ws/completions", token, pcrs=pcrs)

aes_key = os.urandom(32)
user_key = base64.b64encode(aes_key).decode()

data = crypto.envelope_encrypt(
self.ctx.public_key,
{"request": {"prompt": prompt, "stream": stream}, "user_key": user_key},
)
data = base64.b64encode(data).decode()

msg = WSMessage(
msg_type=WSMessageType.COMPLETIONS_REQUEST,
data=data,
).model_dump_json()

await self.ctx.websocket.send(msg)

async for msg in self.ctx.websocket:
msg = WSMessage.model_validate_json(msg)
if msg.msg_type not in [WSMessageType.STREAM_CHUNK, WSMessageType.USAGE]:
raise Exception(
f"expected {WSMessageType.STREAM_CHUNK} or "
f"{WSMessageType.USAGE} not {msg.msg_type}"
)

if msg.msg_type == WSMessageType.USAGE:
yield msg.data
continue

dec = crypto.aes_decrypt(
base64.b64decode(msg.data["data"].encode()), aes_key
)

content = dec.decode()
yield content
if "DONE" in content:
await self.ctx.close()

async def chat_completions(
self,
messages: Union[str, List[Dict[str, Any]]],
Expand Down Expand Up @@ -113,11 +160,16 @@ async def chat_completions(

async for msg in self.ctx.websocket:
msg = WSMessage.model_validate_json(msg)
if msg.msg_type != WSMessageType.STREAM_CHUNK:
if msg.msg_type not in [WSMessageType.STREAM_CHUNK, WSMessageType.USAGE]:
raise Exception(
f"expected {WSMessageType.STREAM_CHUNK} not {msg.msg_type}"
f"expected {WSMessageType.STREAM_CHUNK} or "
f"{WSMessageType.USAGE} not {msg.msg_type}"
)

if msg.msg_type == WSMessageType.USAGE:
yield msg.data
continue

dec = crypto.aes_decrypt(
base64.b64decode(msg.data["data"].encode()), aes_key
)
Expand Down Expand Up @@ -152,6 +204,8 @@ class WSMessageType(str, Enum):
ATTESTATION = "attestation"
STREAM_CHUNK = "stream_chunk"
CHAT_COMPLETIONS_REQUEST = "chat_completion_request"
COMPLETIONS_REQUEST = "completions_request"
USAGE = "usage"


class WSMessage(BaseModel):
Expand All @@ -162,7 +216,7 @@ class WSMessage(BaseModel):
class _Context:
"""A context managing a connection to a particular enclave instance."""

def __init__(self, endpoint, auth_token, root_cert):
def __init__(self, endpoint: str, auth_token: str, root_cert: str):
self._endpoint = _transform_url(endpoint)
self._auth_token = auth_token
self._root_cert = root_cert
Expand All @@ -175,7 +229,7 @@ async def bootstrap(self, pcrs: Optional[Dict[str, List[str]]] = None):
_logger.debug(f"* Dialing {self._endpoint}")
self._websocket = await client.connect(
self._endpoint,
subprotocols=[self._auth_token],
extra_headers={"Authorization": f"Bearer {self._auth_token}"},
max_size=None,
)
_logger.debug("* Websocket connection established")
Expand Down

0 comments on commit 8eaa845

Please sign in to comment.