Skip to content

Commit

Permalink
add simple completions endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
justin1121 committed Sep 26, 2023
1 parent cdeb9d6 commit 83bc79a
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions pycape/llms/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,48 @@ def token(self, token: Union[str, os.PathLike, tkn.Token]) -> tkn.Token:

raise TypeError(f"Expected token to be PathLike or str, found {type(token)}")

async def completions(
    self,
    prompt: str,
    token: str,
    stream=True,
    model="llama",
    pcrs=None,
):
    """Stream a completion for ``prompt`` over an encrypted websocket.

    Connects to the ``/v1/cape/ws/completions`` endpoint, generates a fresh
    random 32-byte AES key, and sends the request together with the
    base64-encoded key inside an envelope encrypted with the context's
    public key. Every chunk received back is base64-decoded, decrypted with
    that AES key, decoded to text and yielded to the caller.

    Args:
        prompt: The prompt text placed in the request payload.
        token: Passed to ``self._connect`` when opening the connection.
        stream: Forwarded as the ``"stream"`` field of the request.
        model: Accepted for interface compatibility; it is currently NOT
            included in the request payload.
        pcrs: Forwarded to ``self._connect``.
            NOTE(review): presumably expected attestation PCR values --
            confirm against ``_connect``.

    Yields:
        The decrypted text of each stream chunk, including the final chunk
        that contains ``"DONE"``.

    Raises:
        Exception: If a received message is not of type
            ``WSMessageType.STREAM_CHUNK``.
    """
    await self._connect("/v1/cape/ws/completions", token, pcrs=pcrs)
    ctx = self.ctx

    # Everything after a successful connect runs under try/finally so the
    # connection is released on every exit path: normal completion, an
    # unexpected message, a decryption error, or the consumer closing the
    # generator early.
    try:
        aes_key = os.urandom(32)
        user_key = base64.b64encode(aes_key).decode()

        # The per-request AES key travels inside the encrypted envelope so
        # that the streamed chunks can be decrypted with it below.
        envelope = crypto.envelope_encrypt(
            ctx.public_key,
            {"request": {"prompt": prompt, "stream": stream}, "user_key": user_key},
        )
        data = base64.b64encode(envelope).decode()

        request = WSMessage(
            msg_type=WSMessageType.COMPLETIONS_REQUEST,
            data=data,
        ).model_dump_json()

        await ctx.websocket.send(request)

        async for raw in ctx.websocket:
            msg = WSMessage.model_validate_json(raw)
            if msg.msg_type != WSMessageType.STREAM_CHUNK:
                raise Exception(
                    f"expected {WSMessageType.STREAM_CHUNK} not {msg.msg_type}"
                )

            dec = crypto.aes_decrypt(
                base64.b64decode(msg.data["data"].encode()), aes_key
            )

            content = dec.decode()
            yield content
            # Any chunk containing "DONE" is treated as the end of the
            # stream; stop explicitly instead of iterating a closed socket.
            if "DONE" in content:
                return
    finally:
        await ctx.close()

async def chat_completions(
self,
messages: Union[str, List[Dict[str, Any]]],
Expand Down Expand Up @@ -152,6 +194,7 @@ class WSMessageType(str, Enum):
ATTESTATION = "attestation"
STREAM_CHUNK = "stream_chunk"
CHAT_COMPLETIONS_REQUEST = "chat_completion_request"
COMPLETIONS_REQUEST = "completions_request"


class WSMessage(BaseModel):
Expand All @@ -162,7 +205,7 @@ class WSMessage(BaseModel):
class _Context:
"""A context managing a connection to a particular enclave instance."""

def __init__(self, endpoint, auth_token, root_cert):
def __init__(self, endpoint: str, auth_token: str, root_cert: str):
self._endpoint = _transform_url(endpoint)
self._auth_token = auth_token
self._root_cert = root_cert
Expand All @@ -175,7 +218,7 @@ async def bootstrap(self, pcrs: Optional[Dict[str, List[str]]] = None):
_logger.debug(f"* Dialing {self._endpoint}")
self._websocket = await client.connect(
self._endpoint,
subprotocols=[self._auth_token],
extra_headers={"Authorization": f"Bearer {self._auth_token}"},
max_size=None,
)
_logger.debug("* Websocket connection established")
Expand Down

0 comments on commit 83bc79a

Please sign in to comment.