From c3a5987f1181079430363a9e57fe6bc5e2fed131 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Thu, 7 Nov 2024 14:08:43 +0000 Subject: [PATCH] Lazify deps --- src/cohere/aws_client.py | 18 ++--- .../manually_maintained/cohere_aws/client.py | 68 ++++++++----------- .../manually_maintained/lazy_aws_deps.py | 23 +++++++ 3 files changed, 55 insertions(+), 54 deletions(-) create mode 100644 src/cohere/manually_maintained/lazy_aws_deps.py diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index f34d60a54..cdbeeedbe 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -5,23 +5,15 @@ import httpx from httpx import URL, SyncByteStream, ByteStream -from tokenizers import Tokenizer # type: ignore from . import GenerateStreamedResponse, Generation, \ NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \ ApiMetaBilledUnits from .client import Client, ClientEnvironment from .core import construct_type +from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore -try: - import boto3 # type: ignore - from botocore.auth import SigV4Auth # type: ignore - from botocore.awsrequest import AWSRequest # type: ignore - AWS_DEPS_AVAILABLE = True -except ImportError: - AWS_DEPS_AVAILABLE = False - class AwsClient(Client): def __init__( self, @@ -33,8 +25,6 @@ def __init__( timeout: typing.Optional[float] = None, service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]], ): - if not AWS_DEPS_AVAILABLE: - raise ImportError("AWS dependencies not available. 
Please install boto3 and botocore.") Client.__init__( self, base_url="https://api.cohere.com", # this url is unused for BedrockClient @@ -183,14 +173,14 @@ def map_request_to_bedrock( aws_session_token: typing.Optional[str] = None, aws_region: typing.Optional[str] = None, ) -> EventHook: - session = boto3.Session( + session = lazy_boto3().Session( region_name=aws_region, aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key, aws_session_token=aws_session_token, ) credentials = session.get_credentials() - signer = SigV4Auth(credentials, service, session.region_name) + signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name) def _event_hook(request: httpx.Request) -> None: headers = request.headers.copy() @@ -220,7 +210,7 @@ def _event_hook(request: httpx.Request) -> None: request._content = new_body headers["content-length"] = str(len(new_body)) - aws_request = AWSRequest( + aws_request = lazy_botocore().awsrequest.AWSRequest( method=request.method, url=url, headers=headers, diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 93be85e87..6e7f53a60 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -16,17 +16,7 @@ from .summary import Summary from .mode import Mode import typing - -# Try to import sagemaker and related modules -try: - import sagemaker as sage - from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url - import boto3 - from botocore.exceptions import ( - ClientError, EndpointConnectionError, ParamValidationError) - AWS_DEPS_AVAILABLE = True -except ImportError: - AWS_DEPS_AVAILABLE = False +from ..lazy_aws_deps import lazy_boto3, lazy_botocore, lazy_sagemaker class Client: def __init__( @@ -37,13 +27,11 @@ def __init__( By default we assume region configured in AWS CLI (`aws configure get region`). 
You can change the region with `aws configure set region us-west-2` or override it with `region_name` parameter. """ - if not AWS_DEPS_AVAILABLE: - raise CohereError("AWS dependencies not available. Please install boto3 and sagemaker.") - self._client = boto3.client("sagemaker-runtime", region_name=aws_region) - self._service_client = boto3.client("sagemaker", region_name=aws_region) + self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) + self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) if os.environ.get('AWS_DEFAULT_REGION') is None: os.environ['AWS_DEFAULT_REGION'] = aws_region - self._sess = sage.Session(sagemaker_client=self._service_client) + self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) self.mode = Mode.SAGEMAKER @@ -51,7 +39,7 @@ def __init__( def _does_endpoint_exist(self, endpoint_name: str) -> bool: try: self._service_client.describe_endpoint(EndpointName=endpoint_name) - except ClientError: + except lazy_botocore().ClientError: return False return True @@ -87,7 +75,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: # Links of all fine-tuned models in s3_models_dir. 
Their format should be .tar.gz s3_tar_models = [ s3_path - for s3_path in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) + for s3_path in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) if ( s3_path.endswith(".tar.gz") # only .tar.gz files and (s3_path.split("/")[-1] != "models.tar.gz") # exclude the .tar.gz file we are creating @@ -109,7 +97,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: # Download and extract all fine-tuned models for s3_tar_model in s3_tar_models: print(f"Adding fine-tuned model: {s3_tar_model}") - S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess) + lazy_sagemaker().s3.S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess) with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar: tar.extractall(local_models_dir) @@ -120,10 +108,10 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: # Upload the new tarfile containing all models to s3 # Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload - model_tar_s3 = S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess) + model_tar_s3 = lazy_sagemaker().s3.S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess) # sanity check - assert s3_models_dir + "models.tar.gz" in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) + assert s3_models_dir + "models.tar.gz" in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) return model_tar_s3 @@ -180,17 +168,17 @@ def create_endpoint( # Otherwise it might block deployment try: self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name) - except ClientError: + except lazy_botocore().ClientError: pass if role is None: try: - role = sage.get_execution_role() + role = lazy_sagemaker().get_execution_role() except ValueError: 
print("Using default role: 'ServiceRoleSagemaker'.") role = "ServiceRoleSagemaker" - model = sage.ModelPackage( + model = lazy_sagemaker().ModelPackage( role=role, model_data=model_data, sagemaker_session=self._sess, # makes sure the right region is used @@ -204,7 +192,7 @@ def create_endpoint( endpoint_name=endpoint_name, **validation_params ) - except ParamValidationError: + except lazy_botocore().ParamValidationError: # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params model.deploy(n_instances, instance_type, endpoint_name=endpoint_name) self.connect_to_endpoint(endpoint_name) @@ -366,7 +354,7 @@ def _sagemaker_chat(self, json_params: Dict[str, Any], variant: str) : else: result = self._client.invoke_endpoint(**params) return Chat.from_dict(json.loads(result['Body'].read().decode())) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -398,7 +386,7 @@ def _bedrock_chat(self, json_params: Dict[str, Any], model_id: str) : result = self._client.invoke_model(**params) return Chat.from_dict( json.loads(result['body'].read().decode())) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -473,7 +461,7 @@ def _sagemaker_generations(self, json_params: Dict[str, Any], variant: str) : result = self._client.invoke_endpoint(**params) return Generations( json.loads(result['Body'].read().decode())['generations']) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? 
@@ -498,7 +486,7 @@ def _bedrock_generations(self, json_params: Dict[str, Any], model_id: str) : result = self._client.invoke_model(**params) return Generations( json.loads(result['body'].read().decode())['generations']) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -546,7 +534,7 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str): try: result = self._client.invoke_endpoint(**params) response = json.loads(result['Body'].read().decode()) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -567,7 +555,7 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str): try: result = self._client.invoke_model(**params) response = json.loads(result['body'].read().decode()) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -631,7 +619,7 @@ def rerank(self, reranking = Reranking(response) for rank in reranking.results: rank.document = parsed_docs[rank.index] - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? @@ -658,7 +646,7 @@ def classify(self, input: List[str], name: str) -> Classifications: try: result = self._client.invoke_endpoint(**params) response = json.loads(result["Body"].read().decode()) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? 
@@ -705,13 +693,13 @@ def create_finetune( if role is None: try: - role = sage.get_execution_role() + role = lazy_sagemaker().get_execution_role() except ValueError: print("Using default role: 'ServiceRoleSagemaker'.") role = "ServiceRoleSagemaker" training_parameters.update({"name": name}) - estimator = sage.algorithm.AlgorithmEstimator( + estimator = lazy_sagemaker().algorithm.AlgorithmEstimator( algorithm_arn=arn, role=role, instance_count=1, @@ -734,7 +722,7 @@ def create_finetune( current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz" - s3_resource = boto3.resource("s3") + s3_resource = lazy_boto3().resource("s3") # Copy new model to root of output_model_dir bucket, old_key = parse_s3_url(current_filepath) @@ -774,14 +762,14 @@ def export_finetune( if role is None: try: - role = sage.get_execution_role() + role = lazy_sagemaker().get_execution_role() except ValueError: print("Using default role: 'ServiceRoleSagemaker'.") role = "ServiceRoleSagemaker" export_parameters = {"name": name} - estimator = sage.algorithm.AlgorithmEstimator( + estimator = lazy_sagemaker().algorithm.AlgorithmEstimator( algorithm_arn=arn, role=role, instance_count=1, @@ -800,7 +788,7 @@ def export_finetune( job_name = estimator.latest_training_job.name current_filepath = f"{s3_output_dir}{job_name}/output/model.tar.gz" - s3_resource = boto3.resource("s3") + s3_resource = lazy_boto3().resource("s3") # Copy the exported TensorRT-LLM engine to the root of s3_output_dir bucket, old_key = parse_s3_url(current_filepath) @@ -940,7 +928,7 @@ def summarize( result = self._client.invoke_endpoint(**params) response = json.loads(result['Body'].read().decode()) summary = Summary(response) - except EndpointConnectionError as e: + except lazy_botocore().EndpointConnectionError as e: raise CohereError(str(e)) except Exception as e: # TODO should be client error - distinct type from CohereError? 
diff --git a/src/cohere/manually_maintained/lazy_aws_deps.py b/src/cohere/manually_maintained/lazy_aws_deps.py
new file mode 100644
index 000000000..7373f77d3
--- /dev/null
+++ b/src/cohere/manually_maintained/lazy_aws_deps.py
@@ -0,0 +1,52 @@
+"""Lazy importers for the optional AWS dependencies.
+
+Each helper performs its import when it is called rather than when this
+module is imported, and raises ImportError naming the missing package.
+"""
+
+
+def lazy_sagemaker():
+    """Import the ``sagemaker`` package on demand and return the module.
+
+    Raises:
+        ImportError: if ``sagemaker`` cannot be imported.
+    """
+    try:
+        import sagemaker as sage
+    except ImportError as err:
+        raise ImportError("Sagemaker not available. Please install sagemaker.") from err
+    return sage
+
+
+def lazy_boto3():
+    """Import the ``boto3`` package on demand and return the module.
+
+    Raises:
+        ImportError: if ``boto3`` cannot be imported.
+    """
+    try:
+        import boto3
+    except ImportError as err:
+        raise ImportError("Boto3 not available. Please install boto3.") from err
+    return boto3
+
+
+def lazy_botocore():
+    """Import the ``botocore`` package on demand and return the module.
+
+    The ``auth``, ``awsrequest`` and ``exceptions`` submodules are imported
+    explicitly so that attribute access such as ``botocore.auth`` works on
+    the returned module.
+
+    Raises:
+        ImportError: if ``botocore`` cannot be imported.
+    """
+    try:
+        import botocore
+        import botocore.auth
+        import botocore.awsrequest
+        import botocore.exceptions
+    except ImportError as err:
+        raise ImportError("Botocore not available. Please install botocore.") from err
+    return botocore