Skip to content

Commit

Permalink
Add Amazon Bedrock Text vectorizer (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
bsbodden committed Dec 2, 2024
1 parent f280c64 commit ce27027
Show file tree
Hide file tree
Showing 8 changed files with 613 additions and 57 deletions.
7 changes: 7 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def gcp_location():
def gcp_project_id():
return os.getenv("GCP_PROJECT_ID")

@pytest.fixture
def aws_credentials():
return {
"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_region": os.getenv("AWS_REGION", "us-east-1")
}

@pytest.fixture
def sample_data():
Expand Down
10 changes: 10 additions & 0 deletions docs/api/vectorizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ CohereTextVectorizer
:show-inheritance:
:members:

BedrockTextVectorizer
=====================

.. _bedrocktextvectorizer_api:

.. currentmodule:: redisvl.utils.vectorize.text.bedrock

.. autoclass:: BedrockTextVectorizer
:show-inheritance:
:members:

CustomTextVectorizer
====================
Expand Down
72 changes: 71 additions & 1 deletion docs/user_guide/vectorizers_04.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"3. Vertex AI\n",
"4. Cohere\n",
"5. Mistral AI\n",
"6. Bringing your own vectorizer\n",
"6. Amazon Bedrock\n",
"7. Bringing your own vectorizer\n",
"\n",
"Before running this notebook, be sure to\n",
"1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
Expand Down Expand Up @@ -541,6 +542,75 @@
"# print(test[:10])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Amazon Bedrock\n",
"\n",
"Amazon Bedrock provides fully managed foundation models for text embeddings. Install the required dependencies:\n",
"\n",
"```bash\n",
"pip install 'redisvl[bedrock]' # Installs boto3\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Configure AWS credentials:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"# Either set environment variables AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION\n",
"# Or configure directly:\n",
"os.environ[\"AWS_ACCESS_KEY_ID\"] = getpass.getpass(\"Enter AWS Access Key ID: \")\n",
"os.environ[\"AWS_SECRET_ACCESS_KEY\"] = getpass.getpass(\"Enter AWS Secret Key: \")\n",
"os.environ[\"AWS_REGION\"] = \"us-east-1\" # Change as needed"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create embeddings:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from redisvl.utils.vectorize import BedrockTextVectorizer\n",
"\n",
"bedrock = BedrockTextVectorizer(\n",
" model=\"amazon.titan-embed-text-v2:0\"\n",
")\n",
"\n",
"# Single embedding\n",
"text = \"This is a test sentence.\"\n",
"embedding = bedrock.embed(text)\n",
"print(f\"Vector dimensions: {len(embedding)}\")\n",
"\n",
"# Multiple embeddings\n",
"sentences = [\n",
" \"That is a happy dog\",\n",
" \"That is a happy person\",\n",
" \"Today is a sunny day\"\n",
"]\n",
"embeddings = bedrock.embed_many(sentences)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
318 changes: 278 additions & 40 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
google-cloud-aiplatform = { version = ">=1.26", optional = true }
cohere = { version = ">=4.44", optional = true }
mistralai = { version = ">=0.2.0", optional = true }
boto3 = { version = ">=1.34.0", optional = true }

[tool.poetry.extras]
openai = ["openai"]
sentence-transformers = ["sentence-transformers"]
google_cloud_aiplatform = ["google_cloud_aiplatform"]
cohere = ["cohere"]
mistralai = ["mistralai"]
bedrock = ["boto3"]

[tool.poetry.group.dev.dependencies]
black = ">=20.8b1"
Expand Down
4 changes: 3 additions & 1 deletion redisvl/utils/vectorize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers
from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer
from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer
from redisvl.utils.vectorize.text.cohere import CohereTextVectorizer
from redisvl.utils.vectorize.text.custom import CustomTextVectorizer
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
Expand All @@ -8,14 +9,15 @@
from redisvl.utils.vectorize.text.vertexai import VertexAITextVectorizer

__all__ = [
"BaseVectrorizer",
"BaseVectorizer",
"CohereTextVectorizer",
"HFTextVectorizer",
"OpenAITextVectorizer",
"VertexAITextVectorizer",
"AzureOpenAITextVectorizer",
"MistralAITextVectorizer",
"CustomTextVectorizer",
"BedrockTextVectorizer",
]


Expand Down
206 changes: 206 additions & 0 deletions redisvl/utils/vectorize/text/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import json
import os
from typing import Any, Callable, Dict, List, Optional

from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from redisvl.utils.vectorize.base import BaseVectorizer


class BedrockTextVectorizer(BaseVectorizer):
"""The AmazonBedrockTextVectorizer class utilizes Amazon Bedrock's API to generate
embeddings for text data.
This vectorizer is designed to interact with Amazon Bedrock API,
requiring AWS credentials for authentication. The credentials can be provided
directly in the `api_config` dictionary or through environment variables:
- AWS_ACCESS_KEY_ID
- AWS_SECRET_ACCESS_KEY
- AWS_REGION (defaults to us-east-1)
The vectorizer supports synchronous operations with batch processing and
preprocessing capabilities.
.. code-block:: python
# Initialize with explicit credentials
vectorizer = AmazonBedrockTextVectorizer(
model="amazon.titan-embed-text-v2:0",
api_config={
"aws_access_key_id": "your_access_key",
"aws_secret_access_key": "your_secret_key",
"aws_region": "us-east-1"
}
)
# Initialize using environment variables
vectorizer = AmazonBedrockTextVectorizer()
# Generate embeddings
embedding = vectorizer.embed("Hello, world!")
embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2)
"""

_client: Any = PrivateAttr()

def __init__(
self,
model: str = "amazon.titan-embed-text-v2:0",
api_config: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize the AWS Bedrock Vectorizer.
Args:
model (str): The Bedrock model ID to use. Defaults to amazon.titan-embed-text-v2:0
api_config (Optional[Dict[str, str]]): AWS credentials and config.
Can include: aws_access_key_id, aws_secret_access_key, aws_region
If not provided, will use environment variables.
Raises:
ValueError: If credentials are not provided in config or environment.
ImportError: If boto3 is not installed.
"""
try:
import boto3 # type: ignore
except ImportError:
raise ImportError(
"Amazon Bedrock vectorizer requires boto3. "
"Please install with `pip install boto3`"
)

if api_config is None:
api_config = {}

aws_access_key_id = api_config.get(
"aws_access_key_id", os.getenv("AWS_ACCESS_KEY_ID")
)
aws_secret_access_key = api_config.get(
"aws_secret_access_key", os.getenv("AWS_SECRET_ACCESS_KEY")
)
aws_region = api_config.get("aws_region", os.getenv("AWS_REGION", "us-east-1"))

if not aws_access_key_id or not aws_secret_access_key:
raise ValueError(
"AWS credentials required. Provide via api_config or environment variables "
"AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
)

self._client = boto3.client(
"bedrock-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region,
)

super().__init__(model=model, dims=self._set_model_dims(model))

def _set_model_dims(self, model: str) -> int:
"""Initialize model and determine embedding dimensions."""
try:
response = self._client.invoke_model(
modelId=model, body=json.dumps({"inputText": "dimension test"})
)
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]
return len(embedding)
except Exception as e:
raise ValueError(f"Error initializing Bedrock model: {str(e)}")

@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_not_exception_type(TypeError),
)
def embed(
self,
text: str,
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
"""Embed a chunk of text using Amazon Bedrock.
Args:
text (str): Text to embed.
preprocess (Optional[Callable]): Optional preprocessing function.
as_buffer (bool): Whether to return as byte buffer.
Returns:
List[float]: The embedding vector.
Raises:
TypeError: If text is not a string.
"""
if not isinstance(text, str):
raise TypeError("Text must be a string")

if preprocess:
text = preprocess(text)

response = self._client.invoke_model(
modelId=self.model, body=json.dumps({"inputText": text})
)
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]

dtype = kwargs.pop("dtype", None)
return self._process_embedding(embedding, as_buffer, dtype)

@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_not_exception_type(TypeError),
)
def embed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
"""Embed multiple texts using Amazon Bedrock.
Args:
texts (List[str]): List of texts to embed.
preprocess (Optional[Callable]): Optional preprocessing function.
batch_size (int): Size of batches for processing.
as_buffer (bool): Whether to return as byte buffers.
Returns:
List[List[float]]: List of embedding vectors.
Raises:
TypeError: If texts is not a list of strings.
"""
if not isinstance(texts, list):
raise TypeError("Texts must be a list of strings")
if texts and not isinstance(texts[0], str):
raise TypeError("Texts must be a list of strings")

embeddings: List[List[float]] = []
dtype = kwargs.pop("dtype", None)

for batch in self.batchify(texts, batch_size, preprocess):
# Send batch request
response = self._client.invoke_model(
modelId=self.model, body=json.dumps({"inputText": batch})
)
response_body = json.loads(response["body"].read())

# Extract embeddings from response
batch_embeddings = response_body["embeddings"]
embeddings.extend(
[
self._process_embedding(embedding, as_buffer, dtype)
for embedding in batch_embeddings
]
)

return embeddings

@property
def type(self) -> str:
return "bedrock"
Loading

0 comments on commit ce27027

Please sign in to comment.