Skip to content

Commit

Permalink
Gate deps
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Oct 18, 2024
1 parent 5c58577 commit 8fdf8d9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
13 changes: 10 additions & 3 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import re
import typing

import boto3 # type: ignore
import httpx
from botocore.auth import SigV4Auth # type: ignore
from botocore.awsrequest import AWSRequest # type: ignore
from httpx import URL, SyncByteStream, ByteStream
from tokenizers import Tokenizer # type: ignore

Expand All @@ -17,6 +14,14 @@
from .core import construct_type


try:
import boto3 # type: ignore
from botocore.auth import SigV4Auth # type: ignore
from botocore.awsrequest import AWSRequest # type: ignore
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False

class AwsClient(Client):
def __init__(
self,
Expand All @@ -28,6 +33,8 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
if not AWS_DEPS_AVAILABLE:
raise ImportError("AWS dependencies not available. Please install boto3 and botocore.")
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
Expand Down
3 changes: 0 additions & 3 deletions src/cohere/bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import typing

import boto3 # type: ignore
from botocore.auth import SigV4Auth # type: ignore
from botocore.awsrequest import AWSRequest # type: ignore
from tokenizers import Tokenizer # type: ignore

from .aws_client import AwsClient
Expand Down
23 changes: 14 additions & 9 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import boto3
import sagemaker as sage
from botocore.exceptions import (ClientError, EndpointConnectionError,
ParamValidationError)
from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url

from .classification import Classification, Classifications
from .embeddings import Embeddings
from .error import CohereError
Expand All @@ -23,7 +17,17 @@
from .mode import Mode
import typing

class Client:
# Try to import sagemaker and related modules
try:
import sagemaker as sage
from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url
import boto3
from botocore.exceptions import (
ClientError, EndpointConnectionError, ParamValidationError)
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False

def __init__(
self,
aws_region: typing.Optional[str] = None,
Expand All @@ -32,8 +36,9 @@ def __init__(
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with `region_name` parameter.
"""
self._client = boto3.client("sagemaker-runtime", region_name=aws_region)
self._service_client = boto3.client("sagemaker", region_name=aws_region)
if not AWS_DEPS_AVAILABLE:
raise CohereError("AWS dependencies not available. Please install boto3 and sagemaker.")
self._client = boto3.client(
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = sage.Session(sagemaker_client=self._service_client)
Expand Down

0 comments on commit 8fdf8d9

Please sign in to comment.