Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add simple completions endpoint #125

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/llama/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

from pycape.llms import Cape

# Client pointed at the hosted Cape API.
c = Cape(url="https://api.capeprivacy.com")

# CAPE_TOKEN is read from the environment; an empty string is passed to
# c.token() when the variable is unset.
token = c.token(os.getenv("CAPE_TOKEN", ""))

# The user turn carries Llama-style [INST]/<<SYS>> markup inline.
user_prompt = (
    "<s>[INST] <<SYS>>You are a helpful Assistant.<</SYS>>"
    "\n\nWhat is the Capital of France? [/INST]"
)

messages = [
    {"role": "user", "content": user_prompt},
    {"role": "system", "content": "you are a happy helpful assistant"},
]

# Print every item produced by the chat completion call as it arrives.
# NOTE(review): chat_completions is declared `async def` in pycape/llms/llms.py;
# confirm Cape exposes it through a blocking wrapper so a plain `for` works.
for msg in c.chat_completions(messages, token):
    print(msg)
62 changes: 58 additions & 4 deletions pycape/llms/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,53 @@ def token(self, token: Union[str, os.PathLike, tkn.Token]) -> tkn.Token:

raise TypeError(f"Expected token to be PathLike or str, found {type(token)}")

async def completions(
    self,
    prompt: str,
    token: str,
    stream=True,
    model="llama",
    pcrs=None,
):
    """Send a completion request over the websocket and yield the replies.

    Connects to ``/v1/cape/ws/completions``, generates a fresh 32-byte AES
    key, envelope-encrypts the request together with that key using
    ``self.ctx.public_key``, and sends it as a ``COMPLETIONS_REQUEST``
    message. Incoming messages are then consumed one by one.

    Args:
        prompt: Prompt text placed in the request payload.
        token: Passed through to ``self._connect``.
        stream: Placed in the request payload as ``"stream"``.
        model: Not referenced anywhere in this method body.
            NOTE(review): confirm whether it should be part of the request.
        pcrs: Passed through to ``self._connect``.

    Yields:
        For ``USAGE`` messages, ``msg.data`` unchanged. For ``STREAM_CHUNK``
        messages, the AES-decrypted ``data["data"]`` field decoded to ``str``.

    Raises:
        Exception: If a message is neither ``STREAM_CHUNK`` nor ``USAGE``.
    """
    await self._connect("/v1/cape/ws/completions", token, pcrs=pcrs)

    # Per-request symmetric key; the server receives it inside the encrypted
    # envelope and stream chunks are decrypted with it below.
    aes_key = os.urandom(32)
    payload = {
        "request": {"prompt": prompt, "stream": stream},
        "user_key": base64.b64encode(aes_key).decode(),
    }
    sealed = crypto.envelope_encrypt(self.ctx.public_key, payload)

    request = WSMessage(
        msg_type=WSMessageType.COMPLETIONS_REQUEST,
        data=base64.b64encode(sealed).decode(),
    )
    await self.ctx.websocket.send(request.model_dump_json())

    accepted = (WSMessageType.STREAM_CHUNK, WSMessageType.USAGE)
    async for raw in self.ctx.websocket:
        frame = WSMessage.model_validate_json(raw)
        kind = frame.msg_type

        if kind not in accepted:
            raise Exception(
                f"expected {WSMessageType.STREAM_CHUNK} or "
                f"{WSMessageType.USAGE} not {kind}"
            )

        if kind == WSMessageType.USAGE:
            # Usage data is yielded as received, without decryption.
            yield frame.data
        else:
            ciphertext = base64.b64decode(frame.data["data"].encode())
            text = crypto.aes_decrypt(ciphertext, aes_key).decode()
            yield text
            # A chunk containing "DONE" triggers a close of the context; the
            # loop itself is not exited here and keeps iterating the socket.
            if "DONE" in text:
                await self.ctx.close()

async def chat_completions(
self,
messages: Union[str, List[Dict[str, Any]]],
Expand Down Expand Up @@ -113,11 +160,16 @@ async def chat_completions(

async for msg in self.ctx.websocket:
msg = WSMessage.model_validate_json(msg)
if msg.msg_type != WSMessageType.STREAM_CHUNK:
if msg.msg_type not in [WSMessageType.STREAM_CHUNK, WSMessageType.USAGE]:
raise Exception(
f"expected {WSMessageType.STREAM_CHUNK} not {msg.msg_type}"
f"expected {WSMessageType.STREAM_CHUNK} or "
f"{WSMessageType.USAGE} not {msg.msg_type}"
)

if msg.msg_type == WSMessageType.USAGE:
yield msg.data
continue

dec = crypto.aes_decrypt(
base64.b64decode(msg.data["data"].encode()), aes_key
)
Expand Down Expand Up @@ -152,6 +204,8 @@ class WSMessageType(str, Enum):
ATTESTATION = "attestation"
STREAM_CHUNK = "stream_chunk"
CHAT_COMPLETIONS_REQUEST = "chat_completion_request"
COMPLETIONS_REQUEST = "completions_request"
USAGE = "usage"


class WSMessage(BaseModel):
Expand All @@ -162,7 +216,7 @@ class WSMessage(BaseModel):
class _Context:
"""A context managing a connection to a particular enclave instance."""

def __init__(self, endpoint, auth_token, root_cert):
def __init__(self, endpoint: str, auth_token: str, root_cert: str):
self._endpoint = _transform_url(endpoint)
self._auth_token = auth_token
self._root_cert = root_cert
Expand All @@ -175,7 +229,7 @@ async def bootstrap(self, pcrs: Optional[Dict[str, List[str]]] = None):
_logger.debug(f"* Dialing {self._endpoint}")
self._websocket = await client.connect(
self._endpoint,
subprotocols=[self._auth_token],
extra_headers={"Authorization": f"Bearer {self._auth_token}"},
max_size=None,
)
_logger.debug("* Websocket connection established")
Expand Down
Loading