-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add chat completions client Plus update some dependencies. * clean up docs * lint * close connection after done msg from server * add test for crypto.py * add test for envelope_encryption
- Loading branch information
1 parent
23333d0
commit cdeb9d6
Showing
9 changed files
with
382 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pycape.llms.llms import Cape | ||
|
||
__all__ = ["Cape"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import json | ||
import os | ||
from typing import Any | ||
from typing import Dict | ||
|
||
from cryptography.hazmat.primitives import hashes | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric import padding | ||
from cryptography.hazmat.primitives.ciphers import aead | ||
|
||
NONCE_SIZE = 12 | ||
|
||
|
||
def aes_decrypt(ctxt: bytes, key: bytes) -> bytes: | ||
nonce, ctxt = ctxt[:NONCE_SIZE], ctxt[NONCE_SIZE:] | ||
encryptor = aead.AESGCM(key) | ||
ptxt = encryptor.decrypt(nonce, ctxt, None) | ||
return ptxt | ||
|
||
|
||
def aes_encrypt(ptxt: bytes, key: bytes): | ||
encryptor = aead.AESGCM(key) | ||
nonce = os.urandom(NONCE_SIZE) | ||
ctxt = encryptor.encrypt(nonce, ptxt, None) | ||
return nonce + ctxt | ||
|
||
|
||
def envelope_encrypt(public_key: bytes, data: Dict[str, Any]) -> bytes: | ||
aes_key = os.urandom(32) | ||
s = json.dumps(data) | ||
|
||
enc_data = aes_encrypt(s.encode(), aes_key) | ||
|
||
pub = serialization.load_pem_public_key(public_key) | ||
|
||
enc_data_key = pub.encrypt( | ||
aes_key, | ||
padding=padding.OAEP( | ||
mgf=padding.MGF1(algorithm=hashes.SHA256()), | ||
algorithm=hashes.SHA256(), | ||
label=None, | ||
), | ||
) | ||
|
||
return enc_data_key + enc_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import json | ||
import os | ||
|
||
from cryptography.hazmat.primitives import hashes | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric import padding | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
|
||
from pycape.llms.crypto import aes_decrypt | ||
from pycape.llms.crypto import aes_encrypt | ||
from pycape.llms.crypto import envelope_encrypt | ||
|
||
KEY_PREFIX_LENGTH = 512 | ||
|
||
|
||
def _envelope_decrypt(ciphertext: bytes, priv_key: rsa.RSAPrivateKey): | ||
enc_data_key, encrypted_data = ( | ||
ciphertext[:KEY_PREFIX_LENGTH], | ||
ciphertext[KEY_PREFIX_LENGTH:], | ||
) | ||
|
||
data_key = priv_key.decrypt( | ||
enc_data_key, | ||
padding=padding.OAEP( | ||
mgf=padding.MGF1(algorithm=hashes.SHA256()), | ||
algorithm=hashes.SHA256(), | ||
label=None, | ||
), | ||
) | ||
|
||
return json.loads(aes_decrypt(encrypted_data, data_key)) | ||
|
||
|
||
def test_encrypt_decrypt(): | ||
expected = b"hi there" | ||
|
||
key = os.urandom(32) | ||
ciphertext = aes_encrypt(expected, key) | ||
|
||
assert expected == aes_decrypt(ciphertext, key) | ||
|
||
|
||
def test_envelope_encrypt(): | ||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) | ||
pem = private_key.public_key().public_bytes( | ||
encoding=serialization.Encoding.PEM, | ||
format=serialization.PublicFormat.SubjectPublicKeyInfo, | ||
) | ||
|
||
expected = {"hi": "hello"} | ||
|
||
ciphertext = envelope_encrypt(pem, expected) | ||
|
||
assert expected == _envelope_decrypt(ciphertext, private_key) |
Oops, something went wrong.