Skip to content

Commit

Permalink
Lazify deps
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Nov 7, 2024
1 parent 7d072e0 commit c3a5987
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 54 deletions.
18 changes: 4 additions & 14 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,15 @@

import httpx
from httpx import URL, SyncByteStream, ByteStream
from tokenizers import Tokenizer # type: ignore

from . import GenerateStreamedResponse, Generation, \
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore


try:
import boto3 # type: ignore
from botocore.auth import SigV4Auth # type: ignore
from botocore.awsrequest import AWSRequest # type: ignore
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False

class AwsClient(Client):
def __init__(
self,
Expand All @@ -33,8 +25,6 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
if not AWS_DEPS_AVAILABLE:
raise ImportError("AWS dependencies not available. Please install boto3 and botocore.")
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
Expand Down Expand Up @@ -183,14 +173,14 @@ def map_request_to_bedrock(
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
session = boto3.Session(
session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
credentials = session.get_credentials()
signer = SigV4Auth(credentials, service, session.region_name)
signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)

def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
Expand Down Expand Up @@ -220,7 +210,7 @@ def _event_hook(request: httpx.Request) -> None:
request._content = new_body
headers["content-length"] = str(len(new_body))

aws_request = AWSRequest(
aws_request = lazy_botocore().awsrequest.AWSRequest(
method=request.method,
url=url,
headers=headers,
Expand Down
68 changes: 28 additions & 40 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,7 @@
from .summary import Summary
from .mode import Mode
import typing

# Try to import sagemaker and related modules
try:
import sagemaker as sage
from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url
import boto3
from botocore.exceptions import (
ClientError, EndpointConnectionError, ParamValidationError)
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False
from ..lazy_aws_deps import lazy_boto3, lazy_botocore, lazy_sagemaker

class Client:
def __init__(
Expand All @@ -37,21 +27,19 @@ def __init__(
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with `region_name` parameter.
"""
if not AWS_DEPS_AVAILABLE:
raise CohereError("AWS dependencies not available. Please install boto3 and sagemaker.")
self._client = boto3.client("sagemaker-runtime", region_name=aws_region)
self._service_client = boto3.client("sagemaker", region_name=aws_region)
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = sage.Session(sagemaker_client=self._service_client)
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
self.mode = Mode.SAGEMAKER



def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
self._service_client.describe_endpoint(EndpointName=endpoint_name)
except ClientError:
except lazy_botocore().ClientError:
return False
return True

Expand Down Expand Up @@ -87,7 +75,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
# Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz
s3_tar_models = [
s3_path
for s3_path in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
for s3_path in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
if (
s3_path.endswith(".tar.gz") # only .tar.gz files
and (s3_path.split("/")[-1] != "models.tar.gz") # exclude the .tar.gz file we are creating
Expand All @@ -109,7 +97,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
# Download and extract all fine-tuned models
for s3_tar_model in s3_tar_models:
print(f"Adding fine-tuned model: {s3_tar_model}")
S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess)
lazy_sagemaker().s3.S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess)
with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar:
tar.extractall(local_models_dir)

Expand All @@ -120,10 +108,10 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:

# Upload the new tarfile containing all models to s3
# Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload
model_tar_s3 = S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess)
model_tar_s3 = lazy_sagemaker().s3.S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess)

# sanity check
assert s3_models_dir + "models.tar.gz" in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
assert s3_models_dir + "models.tar.gz" in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)

return model_tar_s3

Expand Down Expand Up @@ -180,17 +168,17 @@ def create_endpoint(
# Otherwise it might block deployment
try:
self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
except ClientError:
except lazy_botocore().ClientError:
pass

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

model = sage.ModelPackage(
model = lazy_sagemaker().ModelPackage(
role=role,
model_data=model_data,
sagemaker_session=self._sess, # makes sure the right region is used
Expand All @@ -204,7 +192,7 @@ def create_endpoint(
endpoint_name=endpoint_name,
**validation_params
)
except ParamValidationError:
except lazy_botocore().ParamValidationError:
# For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
self.connect_to_endpoint(endpoint_name)
Expand Down Expand Up @@ -366,7 +354,7 @@ def _sagemaker_chat(self, json_params: Dict[str, Any], variant: str) :
else:
result = self._client.invoke_endpoint(**params)
return Chat.from_dict(json.loads(result['Body'].read().decode()))
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -398,7 +386,7 @@ def _bedrock_chat(self, json_params: Dict[str, Any], model_id: str) :
result = self._client.invoke_model(**params)
return Chat.from_dict(
json.loads(result['body'].read().decode()))
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -473,7 +461,7 @@ def _sagemaker_generations(self, json_params: Dict[str, Any], variant: str) :
result = self._client.invoke_endpoint(**params)
return Generations(
json.loads(result['Body'].read().decode())['generations'])
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -498,7 +486,7 @@ def _bedrock_generations(self, json_params: Dict[str, Any], model_id: str) :
result = self._client.invoke_model(**params)
return Generations(
json.loads(result['body'].read().decode())['generations'])
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -546,7 +534,7 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
try:
result = self._client.invoke_endpoint(**params)
response = json.loads(result['Body'].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -567,7 +555,7 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
try:
result = self._client.invoke_model(**params)
response = json.loads(result['body'].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -631,7 +619,7 @@ def rerank(self,
reranking = Reranking(response)
for rank in reranking.results:
rank.document = parsed_docs[rank.index]
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -658,7 +646,7 @@ def classify(self, input: List[str], name: str) -> Classifications:
try:
result = self._client.invoke_endpoint(**params)
response = json.loads(result["Body"].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -705,13 +693,13 @@ def create_finetune(

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

training_parameters.update({"name": name})
estimator = sage.algorithm.AlgorithmEstimator(
estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
algorithm_arn=arn,
role=role,
instance_count=1,
Expand All @@ -734,7 +722,7 @@ def create_finetune(

current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz"

s3_resource = boto3.resource("s3")
s3_resource = lazy_boto3().resource("s3")

# Copy new model to root of output_model_dir
bucket, old_key = parse_s3_url(current_filepath)
Expand Down Expand Up @@ -774,14 +762,14 @@ def export_finetune(

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

export_parameters = {"name": name}

estimator = sage.algorithm.AlgorithmEstimator(
estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
algorithm_arn=arn,
role=role,
instance_count=1,
Expand All @@ -800,7 +788,7 @@ def export_finetune(
job_name = estimator.latest_training_job.name
current_filepath = f"{s3_output_dir}{job_name}/output/model.tar.gz"

s3_resource = boto3.resource("s3")
s3_resource = lazy_boto3().resource("s3")

# Copy the exported TensorRT-LLM engine to the root of s3_output_dir
bucket, old_key = parse_s3_url(current_filepath)
Expand Down Expand Up @@ -940,7 +928,7 @@ def summarize(
result = self._client.invoke_endpoint(**params)
response = json.loads(result['Body'].read().decode())
summary = Summary(response)
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down
23 changes: 23 additions & 0 deletions src/cohere/manually_maintained/lazy_aws_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@


def lazy_sagemaker():
    """Import and return the ``sagemaker`` package on first use.

    Deferring the import keeps ``sagemaker`` an optional dependency: it is
    only required once a SageMaker code path is actually executed.

    Returns:
        The imported ``sagemaker`` module.

    Raises:
        ImportError: if ``sagemaker`` is not installed.
    """
    try:
        import sagemaker as sage  # type: ignore
    except ImportError as err:
        # Raise a builtin exception: this module imports nothing else, so any
        # project-specific error type would be an undefined name here and the
        # install hint would be masked by a NameError.
        raise ImportError("Sagemaker not available. Please install sagemaker.") from err
    return sage

def lazy_boto3():
    """Import and return the ``boto3`` package on first use.

    Deferring the import keeps ``boto3`` an optional dependency: it is only
    required once an AWS code path is actually executed.

    Returns:
        The imported ``boto3`` module.

    Raises:
        ImportError: if ``boto3`` is not installed.
    """
    try:
        import boto3  # type: ignore
    except ImportError as err:
        # Raise a builtin exception: this module imports nothing else, so any
        # project-specific error type would be an undefined name here and the
        # install hint would be masked by a NameError.
        # The hint names the package to install ("boto3"), not this helper.
        raise ImportError("Boto3 not available. Please install boto3.") from err
    return boto3

def lazy_botocore():
    """Import and return the ``botocore`` package on first use.

    The ``auth``, ``awsrequest`` and ``exceptions`` submodules are imported
    explicitly, because importing a package does not by itself make its
    submodules available as attributes. After this call,
    ``lazy_botocore().auth``, ``.awsrequest`` and ``.exceptions`` are all
    guaranteed to resolve. Exception classes such as ``ClientError`` are
    reached through the ``exceptions`` submodule.

    Returns:
        The imported ``botocore`` module.

    Raises:
        ImportError: if ``botocore`` is not installed.
    """
    try:
        # Each of these statements binds the top-level name ``botocore`` and
        # loads the named submodule as an attribute of it.
        import botocore.auth  # type: ignore
        import botocore.awsrequest  # type: ignore
        import botocore.exceptions  # type: ignore
    except ImportError as err:
        # Raise a builtin exception: this module imports nothing else, so any
        # project-specific error type would be an undefined name here and the
        # install hint would be masked by a NameError.
        raise ImportError("Botocore not available. Please install botocore.") from err
    return botocore

0 comments on commit c3a5987

Please sign in to comment.