From 6b93e7a5fc13ad9d4c804e255147b8e137b48405 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Mon, 25 Nov 2024 12:16:00 +0100 Subject: [PATCH] feat: Add async support to `InMemoryBM25Retriever` and `InMemoryEmbeddingRetriever` (#138) * feat: Add async support to `InMemoryBM25Retriever` and `InMemoryEmbeddingRetriever` * fix: Incorrect example code for `DocumentWriter` * Fix lints --- docs/pydoc/config/retrievers_api.yml | 2 + .../retrievers/in_memory/__init__.py | 8 + .../retrievers/in_memory/bm25_retriever.py | 125 ++++++++++++++ .../in_memory/embedding_retriever.py | 153 ++++++++++++++++++ .../components/writers/document_writer.py | 7 +- .../test_in_memory_bm25_retriever.py | 92 +++++++++++ .../test_in_memory_embedding_retriever.py | 67 ++++++++ 7 files changed, 451 insertions(+), 3 deletions(-) create mode 100644 haystack_experimental/components/retrievers/in_memory/__init__.py create mode 100644 haystack_experimental/components/retrievers/in_memory/bm25_retriever.py create mode 100644 haystack_experimental/components/retrievers/in_memory/embedding_retriever.py create mode 100644 test/components/retrievers/test_in_memory_bm25_retriever.py create mode 100644 test/components/retrievers/test_in_memory_embedding_retriever.py diff --git a/docs/pydoc/config/retrievers_api.yml b/docs/pydoc/config/retrievers_api.yml index 1a449671..fc4f76ec 100644 --- a/docs/pydoc/config/retrievers_api.yml +++ b/docs/pydoc/config/retrievers_api.yml @@ -5,6 +5,8 @@ loaders: [ "haystack_experimental.components.retrievers.auto_merging_retriever", "haystack_experimental.components.retrievers.chat_message_retriever", + "haystack_experimental.components.retrievers.in_memory.bm25_retriever", + "haystack_experimental.components.retrievers.in_memory.embedding_retriever", "haystack_experimental.components.retrievers.opensearch.bm25_retriever", "haystack_experimental.components.retrievers.opensearch.embedding_retriever", ] diff --git 
a/haystack_experimental/components/retrievers/in_memory/__init__.py b/haystack_experimental/components/retrievers/in_memory/__init__.py new file mode 100644 index 00000000..e58e5bfb --- /dev/null +++ b/haystack_experimental/components/retrievers/in_memory/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .bm25_retriever import InMemoryBM25Retriever +from .embedding_retriever import InMemoryEmbeddingRetriever + +__all__ = ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"] diff --git a/haystack_experimental/components/retrievers/in_memory/bm25_retriever.py b/haystack_experimental/components/retrievers/in_memory/bm25_retriever.py new file mode 100644 index 00000000..38a30d90 --- /dev/null +++ b/haystack_experimental/components/retrievers/in_memory/bm25_retriever.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +from haystack import ( + Document, + component, +) +from haystack.components.retrievers.in_memory import ( + InMemoryBM25Retriever as InMemoryBM25RetrieverBase, +) +from haystack.document_stores.types import FilterPolicy + +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + +@component +class InMemoryBM25Retriever(InMemoryBM25RetrieverBase): + """ + Retrieves documents that are most similar to the query using a keyword-based algorithm. + + Use this retriever with the InMemoryDocumentStore. 
+ + ### Usage example + + ```python + from haystack import Document + from haystack_experimental.components.retrievers.in_memory import InMemoryBM25Retriever + from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + docs = [ + Document(content="Python is a popular programming language"), + Document(content="python ist eine beliebte Programmiersprache"), + ] + + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs) + retriever = InMemoryBM25Retriever(doc_store) + + result = retriever.run(query="Programmiersprache") + + print(result["documents"]) + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + document_store: InMemoryDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = False, + filter_policy: FilterPolicy = FilterPolicy.REPLACE, + ): + """ + Create the InMemoryBM25Retriever component. + + :param document_store: + An instance of InMemoryDocumentStore where the retriever should search for relevant documents. + :param filters: + A dictionary with filters to narrow down the retriever's search space in the document store. + :param top_k: + The maximum number of documents to retrieve. + :param scale_score: + When `True`, scales the score of retrieved documents to a range of 0 to 1, where 1 means extremely relevant. + When `False`, uses raw similarity scores. + :param filter_policy: The filter policy to apply during retrieval. + Filter policy determines how filters are applied when retrieving documents. You can choose: + - `REPLACE` (default): Overrides the initialization filters with the filters specified at runtime. + Use this policy to dynamically change filtering for specific queries. + - `MERGE`: Combines runtime filters with initialization filters to narrow down the search. + :raises ValueError: + If the specified `top_k` is not > 0 or `document_store` is not an InMemoryDocumentStore instance. 
+ """ + if not isinstance(document_store, InMemoryDocumentStore): + raise ValueError("document_store must be an instance of InMemoryDocumentStore") + + super(InMemoryBM25Retriever, self).__init__( + document_store=document_store, + filters=filters, + top_k=top_k, + scale_score=scale_score, + filter_policy=filter_policy, + ) + + @component.output_types(documents=List[Document]) + async def run_async( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + ): + """ + Run the InMemoryBM25Retriever on the given input data. + + :param query: + The query string for the Retriever. + :param filters: + A dictionary with filters to narrow down the search space when retrieving documents. + :param top_k: + The maximum number of documents to return. + :param scale_score: + When `True`, scales the score of retrieved documents to a range of 0 to 1, where 1 means extremely relevant. + When `False`, uses raw similarity scores. + :returns: + The retrieved documents. + + :raises ValueError: + If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance. 
+ """ + if self.filter_policy == FilterPolicy.MERGE and filters: + filters = {**(self.filters or {}), **filters} + else: + filters = filters or self.filters + if top_k is None: + top_k = self.top_k + if scale_score is None: + scale_score = self.scale_score + + docs = await self.document_store.bm25_retrieval_async( + query=query, filters=filters, top_k=top_k, scale_score=scale_score + ) + return {"documents": docs} diff --git a/haystack_experimental/components/retrievers/in_memory/embedding_retriever.py b/haystack_experimental/components/retrievers/in_memory/embedding_retriever.py new file mode 100644 index 00000000..dd8a4205 --- /dev/null +++ b/haystack_experimental/components/retrievers/in_memory/embedding_retriever.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +from haystack import ( + Document, + component, +) +from haystack.components.retrievers.in_memory import ( + InMemoryEmbeddingRetriever as InMemoryEmbeddingRetrieverBase, +) +from haystack.document_stores.types import FilterPolicy + +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + +@component +class InMemoryEmbeddingRetriever(InMemoryEmbeddingRetrieverBase): + """ + Retrieves documents that are most semantically similar to the query. + + Use this retriever with the InMemoryDocumentStore. + + When using this retriever, make sure it has query and document embeddings available. + In indexing pipelines, use a DocumentEmbedder to embed documents. + In query pipelines, use a TextEmbedder to embed queries and send them to the retriever. 
+ + ### Usage example + ```python + from haystack import Document + from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder + from haystack_experimental.components.retrievers.in_memory import InMemoryEmbeddingRetriever + from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + docs = [ + Document(content="Python is a popular programming language"), + Document(content="python ist eine beliebte Programmiersprache"), + ] + doc_embedder = SentenceTransformersDocumentEmbedder() + doc_embedder.warm_up() + docs_with_embeddings = doc_embedder.run(docs)["documents"] + + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs_with_embeddings) + retriever = InMemoryEmbeddingRetriever(doc_store) + + query="Programmiersprache" + text_embedder = SentenceTransformersTextEmbedder() + text_embedder.warm_up() + query_embedding = text_embedder.run(query)["embedding"] + + result = retriever.run(query_embedding=query_embedding) + + print(result["documents"]) + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + document_store: InMemoryDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = False, + return_embedding: bool = False, + filter_policy: FilterPolicy = FilterPolicy.REPLACE, + ): + """ + Create the InMemoryEmbeddingRetriever component. + + :param document_store: + An instance of InMemoryDocumentStore where the retriever should search for relevant documents. + :param filters: + A dictionary with filters to narrow down the retriever's search space in the document store. + :param top_k: + The maximum number of documents to retrieve. + :param scale_score: + When `True`, scales the score of retrieved documents to a range of 0 to 1, where 1 means extremely relevant. + When `False`, uses raw similarity scores. + :param return_embedding: + When `True`, returns the embedding of the retrieved documents. 
+ When `False`, returns just the documents, without their embeddings. + :param filter_policy: The filter policy to apply during retrieval. + Filter policy determines how filters are applied when retrieving documents. You can choose: + - `REPLACE` (default): Overrides the initialization filters with the filters specified at runtime. + Use this policy to dynamically change filtering for specific queries. + - `MERGE`: Combines runtime filters with initialization filters to narrow down the search. + :raises ValueError: + If the specified `top_k` is not > 0 or `document_store` is not an InMemoryDocumentStore instance. + """ + if not isinstance(document_store, InMemoryDocumentStore): + raise ValueError("document_store must be an instance of InMemoryDocumentStore") + + super(InMemoryEmbeddingRetriever, self).__init__( + document_store=document_store, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + filter_policy=filter_policy, + ) + + @component.output_types(documents=List[Document]) + async def run_async( # pylint: disable=too-many-positional-arguments + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + """ + Run the InMemoryEmbeddingRetriever on the given input data. + + :param query_embedding: + Embedding of the query. + :param filters: + A dictionary with filters to narrow down the search space when retrieving documents. + :param top_k: + The maximum number of documents to return. + :param scale_score: + When `True`, scales the score of retrieved documents to a range of 0 to 1, where 1 means extremely relevant. + When `False`, uses raw similarity scores. + :param return_embedding: + When `True`, returns the embedding of the retrieved documents. + When `False`, returns just the documents, without their embeddings. + :returns: + The retrieved documents. 
+ + :raises ValueError: + If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance. + """ + if self.filter_policy == FilterPolicy.MERGE and filters: + filters = {**(self.filters or {}), **filters} + else: + filters = filters or self.filters + if top_k is None: + top_k = self.top_k + if scale_score is None: + scale_score = self.scale_score + if return_embedding is None: + return_embedding = self.return_embedding + + docs = await self.document_store.embedding_retrieval_async( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + ) + + return {"documents": docs} diff --git a/haystack_experimental/components/writers/document_writer.py b/haystack_experimental/components/writers/document_writer.py index 9dcc1fe3..d90f3d73 100644 --- a/haystack_experimental/components/writers/document_writer.py +++ b/haystack_experimental/components/writers/document_writer.py @@ -21,14 +21,15 @@ class DocumentWriter(DocumentWriterBase): ### Usage example ```python from haystack import Document - from haystack.components.writers import DocumentWriter - from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack_experimental.components.writers import DocumentWriter + from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore docs = [ Document(content="Python is a popular programming language"), ] doc_store = InMemoryDocumentStore() - doc_store.write_documents(docs) + writer = DocumentWriter(document_store=doc_store) + writer.run(docs) ``` """ diff --git a/test/components/retrievers/test_in_memory_bm25_retriever.py b/test/components/retrievers/test_in_memory_bm25_retriever.py new file mode 100644 index 00000000..b6c000cf --- /dev/null +++ b/test/components/retrievers/test_in_memory_bm25_retriever.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, 
Any + +import pytest + +from haystack_experimental.core import AsyncPipeline, run_async_pipeline +from haystack_experimental.components.retrievers.in_memory import InMemoryBM25Retriever +from haystack.dataclasses import Document +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + +@pytest.fixture() +def mock_docs(): + return [ + Document(content="Javascript is a popular programming language"), + Document(content="Java is a popular programming language"), + Document(content="Python is a popular programming language"), + Document(content="Ruby is a popular programming language"), + Document(content="PHP is a popular programming language"), + ] + + +class TestMemoryBM25RetrieverAsync: + @pytest.mark.asyncio + async def test_retriever_valid_run(self, mock_docs): + ds = InMemoryDocumentStore() + ds.write_documents(mock_docs) + + retriever = InMemoryBM25Retriever(ds, top_k=5) + result = await retriever.run_async(query="PHP") + + assert "documents" in result + assert len(result["documents"]) == 5 + assert result["documents"][0].content == "PHP is a popular programming language" + + @pytest.mark.asyncio + @pytest.mark.integration + @pytest.mark.parametrize( + "query, query_result", + [ + ("Javascript", "Javascript is a popular programming language"), + ("Java", "Java is a popular programming language"), + ], + ) + async def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): + ds = InMemoryDocumentStore() + await ds.write_documents_async(mock_docs) + retriever = InMemoryBM25Retriever(ds) + + pipeline = AsyncPipeline() + pipeline.add_component("retriever", retriever) + result: Dict[str, Any] = await run_async_pipeline( + pipeline, data={"retriever": {"query": query}} + ) + + assert result + assert "retriever" in result + results_docs = result["retriever"]["documents"] + assert results_docs + assert results_docs[0].content == query_result + + @pytest.mark.asyncio + @pytest.mark.integration + @pytest.mark.parametrize( + 
"query, query_result, top_k", + [ + ("Javascript", "Javascript is a popular programming language", 1), + ("Java", "Java is a popular programming language", 2), + ("Ruby", "Ruby is a popular programming language", 3), + ], + ) + async def test_run_with_pipeline_and_top_k( + self, mock_docs, query: str, query_result: str, top_k: int + ): + ds = InMemoryDocumentStore() + ds.write_documents(mock_docs) + retriever = InMemoryBM25Retriever(ds) + + pipeline = AsyncPipeline() + pipeline.add_component("retriever", retriever) + result: Dict[str, Any] = await run_async_pipeline( + pipeline, data={"retriever": {"query": query, "top_k": top_k}} + ) + + assert result + assert "retriever" in result + results_docs = result["retriever"]["documents"] + assert results_docs + assert len(results_docs) == top_k + assert results_docs[0].content == query_result diff --git a/test/components/retrievers/test_in_memory_embedding_retriever.py b/test/components/retrievers/test_in_memory_embedding_retriever.py new file mode 100644 index 00000000..c1488ec6 --- /dev/null +++ b/test/components/retrievers/test_in_memory_embedding_retriever.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, Any + +import pytest + +from haystack_experimental.core import AsyncPipeline, run_async_pipeline +from haystack_experimental.components.retrievers.in_memory.embedding_retriever import ( + InMemoryEmbeddingRetriever, +) +from haystack.dataclasses import Document +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + +class TestMemoryEmbeddingRetrieverAsync: + @pytest.mark.asyncio + async def test_valid_run(self): + top_k = 3 + ds = InMemoryDocumentStore(embedding_similarity_function="cosine") + docs = [ + Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), + Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="third document", embedding=[0.5, 0.7, 
0.5, 0.7]), + ] + await ds.write_documents_async(docs) + + retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k) + result = await retriever.run_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True + ) + + assert "documents" in result + assert len(result["documents"]) == top_k + assert result["documents"][0].embedding == [1.0, 1.0, 1.0, 1.0] + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_run_with_pipeline(self): + ds = InMemoryDocumentStore(embedding_similarity_function="cosine") + top_k = 2 + docs = [ + Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), + Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]), + ] + ds.write_documents(docs) + retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k) + + pipeline = AsyncPipeline() + pipeline.add_component("retriever", retriever) + result: Dict[str, Any] = await run_async_pipeline( + pipeline, + data={ + "retriever": { + "query_embedding": [0.1, 0.1, 0.1, 0.1], + "return_embedding": True, + } + }, + ) + + assert result + assert "retriever" in result + results_docs = result["retriever"]["documents"] + assert results_docs + assert len(results_docs) == top_k + assert results_docs[0].embedding == [1.0, 1.0, 1.0, 1.0]