Skip to content

Commit

Permalink
Add chat completions client (#124)
Browse files Browse the repository at this point in the history
* Add chat completions client

Plus update some dependencies.

* clean up docs

* lint

* close connection after done msg from server

* add test for crypto.py

* add test for envelope_encryption
  • Loading branch information
justin1121 authored Sep 22, 2023
1 parent 23333d0 commit cdeb9d6
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 21 deletions.
11 changes: 3 additions & 8 deletions pycape/cape.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"""
import base64
import contextlib
import json
import logging
import os
Expand Down Expand Up @@ -64,6 +65,7 @@
_synchronizer = synchronicity.Synchronizer(multiwrap_warning=True)


@_synchronizer.create_blocking
class Cape:
"""A websocket client for interacting with enclaves hosting Cape functions.
Expand Down Expand Up @@ -91,14 +93,12 @@ def __init__(
if verbose:
_logger.setLevel(logging.DEBUG)

@_synchronizer
async def close(self):
"""Closes the current enclave connection."""
if self._ctx is not None:
await self._ctx.close()
self._ctx = None

@_synchronizer
async def connect(
self,
function_ref: Union[str, os.PathLike, fref.FunctionRef],
Expand Down Expand Up @@ -134,7 +134,6 @@ async def connect(
token = self.token(token)
await self._request_connection(function_ref, token, pcrs)

@_synchronizer
async def encrypt(
self,
input: bytes,
Expand Down Expand Up @@ -235,8 +234,7 @@ def function(

raise ValueError("Unrecognized form of `identifier` argument: {identifier}.")

@_synchronizer
@_synchronizer.asynccontextmanager
@contextlib.asynccontextmanager
async def function_context(
self,
function_ref: Union[str, os.PathLike, fref.FunctionRef],
Expand Down Expand Up @@ -281,7 +279,6 @@ async def function_context(
finally:
await self.close()

@_synchronizer
async def invoke(
self, *args: Any, serde_hooks=None, use_serdio: bool = False, **kwargs: Any
) -> Any:
Expand Down Expand Up @@ -323,7 +320,6 @@ async def invoke(
serde_hooks = serdio.bundle_serde_hooks(serde_hooks)
return await self._request_invocation(serde_hooks, use_serdio, *args, **kwargs)

@_synchronizer
async def key(
self,
*,
Expand Down Expand Up @@ -390,7 +386,6 @@ async def key(
"account's Cape key."
)

@_synchronizer
async def run(
self,
function_ref: Union[str, os.PathLike, fref.FunctionRef],
Expand Down
3 changes: 3 additions & 0 deletions pycape/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pycape.llms.llms import Cape

__all__ = ["Cape"]
45 changes: 45 additions & 0 deletions pycape/llms/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import json
import os
from typing import Any
from typing import Dict

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import aead

NONCE_SIZE = 12


def aes_decrypt(ctxt: bytes, key: bytes) -> bytes:
nonce, ctxt = ctxt[:NONCE_SIZE], ctxt[NONCE_SIZE:]
encryptor = aead.AESGCM(key)
ptxt = encryptor.decrypt(nonce, ctxt, None)
return ptxt


def aes_encrypt(ptxt: bytes, key: bytes):
encryptor = aead.AESGCM(key)
nonce = os.urandom(NONCE_SIZE)
ctxt = encryptor.encrypt(nonce, ptxt, None)
return nonce + ctxt


def envelope_encrypt(public_key: bytes, data: Dict[str, Any]) -> bytes:
aes_key = os.urandom(32)
s = json.dumps(data)

enc_data = aes_encrypt(s.encode(), aes_key)

pub = serialization.load_pem_public_key(public_key)

enc_data_key = pub.encrypt(
aes_key,
padding=padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
),
)

return enc_data_key + enc_data
54 changes: 54 additions & 0 deletions pycape/llms/crypto_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
import os

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa

from pycape.llms.crypto import aes_decrypt
from pycape.llms.crypto import aes_encrypt
from pycape.llms.crypto import envelope_encrypt

KEY_PREFIX_LENGTH = 512


def _envelope_decrypt(ciphertext: bytes, priv_key: rsa.RSAPrivateKey):
enc_data_key, encrypted_data = (
ciphertext[:KEY_PREFIX_LENGTH],
ciphertext[KEY_PREFIX_LENGTH:],
)

data_key = priv_key.decrypt(
enc_data_key,
padding=padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None,
),
)

return json.loads(aes_decrypt(encrypted_data, data_key))


def test_encrypt_decrypt():
expected = b"hi there"

key = os.urandom(32)
ciphertext = aes_encrypt(expected, key)

assert expected == aes_decrypt(ciphertext, key)


def test_envelope_encrypt():
private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
pem = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

expected = {"hi": "hello"}

ciphertext = envelope_encrypt(pem, expected)

assert expected == _envelope_decrypt(ciphertext, private_key)
Loading

0 comments on commit cdeb9d6

Please sign in to comment.