Skip to content

Commit

Permalink
feat: add encryption support (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE authored Aug 31, 2024
1 parent cde9f66 commit eb406b6
Show file tree
Hide file tree
Showing 8 changed files with 509 additions and 335 deletions.
30 changes: 25 additions & 5 deletions crab/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import json
import logging
from typing import Any

from httpx import Client

from crab.utils import decrypt_message, encrypt_message, generate_key_from_env

from .exceptions import ActionNotFound
from .models import Action, ClosedAction, EnvironmentConfig

Expand Down Expand Up @@ -89,6 +92,8 @@ def __init__(
for key, value in extra_attributes.items():
setattr(self, key, value)

self._enc_key = generate_key_from_env()

def step(
self,
action_name: str,
Expand Down Expand Up @@ -210,15 +215,30 @@ def observation_space(self) -> list[ClosedAction]:
def _action_endpoint(self, action: Action, parameters: dict[str, Any]):
"""Rewrite to support different environments."""
if self._client is not None and not action.local:
data = json.dumps(
{
"action": action.to_raw_action(),
"parameters": action.parameters(**parameters).model_dump(),
}
)
content_type = "application/json"
if self._enc_key is not None:
data = encrypt_message(data, self._enc_key)
content_type = "text/plain"

# send action to remote machine
response = self._client.post(
"/raw_action",
json={
"action": action.to_raw_action(),
"parameters": action.parameters(**parameters).model_dump(),
},
content=data,
headers={"Content-Type": content_type},
)
return response.json()["action_returns"]

resp_content = response.content.decode("utf-8")
if self._enc_key is not None:
resp_content = decrypt_message(resp_content, self._enc_key)

resp_json = json.loads(resp_content)
return resp_json["action_returns"]
else:
# or directly execute it
action = action.set_kept_param(env=self)
Expand Down
26 changes: 23 additions & 3 deletions crab/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import json

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse, PlainTextResponse

from crab.utils.common import base64_to_callable
from crab.utils import (
base64_to_callable,
decrypt_message,
encrypt_message,
generate_key_from_env,
)

from .logger import crab_logger as logger

Expand All @@ -23,12 +31,24 @@
@api_router.post("/raw_action")
async def raw_action(request: Request):
"""Perform the specified action with given parameters."""
enc_key = generate_key_from_env()
# Extract query parameters as a dictionary
request_json = await request.json()
request_content = await request.body()
request_content = request_content.decode("utf-8")
if enc_key is not None:
request_content = decrypt_message(request_content, enc_key)
request_json = json.loads(request_content)

action = request_json["action"]
parameters = request_json["parameters"]
entry = base64_to_callable(action["dumped_entry"])
logger.info(f"remote action: {action['name']} received. parameters: {parameters}")
if "env" in action["kept_params"]:
parameters["env"] = request.app.environment
return {"action_returns": entry(**parameters)}

resp_data = {"action_returns": entry(**parameters)}
if enc_key is None:
return JSONResponse(content=resp_data)
else:
encrypted = encrypt_message(json.dumps(resp_data), enc_key)
return PlainTextResponse(content=encrypted)
35 changes: 35 additions & 0 deletions crab/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========

from crab.utils.common import (
base64_to_callable,
base64_to_image,
callable_to_base64,
image_to_base64,
)
from crab.utils.encryption import (
decrypt_message,
encrypt_message,
generate_key_from_env,
)

__all__ = [
"base64_to_image",
"image_to_base64",
"callable_to_base64",
"base64_to_callable",
"decrypt_message",
"encrypt_message",
"generate_key_from_env",
]
77 changes: 77 additions & 0 deletions crab/utils/encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import base64
import hashlib
import logging
import os
from typing import Optional

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

logger = logging.getLogger("encryption")


def encrypt_message(plaintext: str, key: bytes) -> str:
"""Encrypts a message using a key with AES 256 encryption.
Args:
plaintext (str): The message to encrypt.
key (bytes): The encryption key, should be 256 bits.
Returns:
str: The encrypted message encoded in base64.
"""
nounce = os.urandom(12)
cipher = Cipher(algorithms.AES(key), modes.GCM(nounce), backend=default_backend())
encryptor = cipher.encryptor()
ciphertext = encryptor.update(plaintext.encode()) + encryptor.finalize()
return base64.b64encode(nounce + ciphertext + encryptor.tag).decode("utf-8")


def decrypt_message(encrypted: str, key: bytes) -> str:
"""Decrypts an encrypted message using a key with AES 256 encryption.
Args:
encrypted (str): The encrypted message encoded in base64.
key (bytes): The encryption key, should be 256 bits.
Returns:
str: The decrypted message.
"""
encrypted = base64.b64decode(encrypted)
nounce = encrypted[:12]
ciphertext = encrypted[12:-16]
tag = encrypted[-16:]
cipher = Cipher(
algorithms.AES(key), modes.GCM(nounce, tag), backend=default_backend()
)
decryptor = cipher.decryptor()
return (decryptor.update(ciphertext) + decryptor.finalize()).decode("utf-8")


def generate_key_from_env() -> Optional[bytes]:
"""Generate the encryption key from the environment variable `CRAB_ENC_KEY`.
Returns:
Optional[bytes]: The encryption key. If the environment variable is not set or
empty, return None.
"""
enc_key = os.environ.get("CRAB_ENC_KEY")
# don't encrypt as long as the key is an empty value
if not enc_key:
logger.warning("CRAB_ENC_KEY is not set, connection will not be encrypted.")
return None

return hashlib.sha256(enc_key.encode("utf-8")).digest()
Loading

0 comments on commit eb406b6

Please sign in to comment.