From fbca09a6b020367d5ab47b2b4bf8e772fff2457b Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 11:28:56 +0330 Subject: [PATCH 1/9] feat: add summary retriever for all forum based platforms! Such as discord & discourse. --- .../src/retrievers/__init__.py | 0 .../src/retrievers/forum_summary_retriever.py | 71 ++++++++++++++++ .../src/retrievers/summary_retriever_base.py | 72 ++++++++++++++++ .../unit/test_discord_summary_retriever.py | 85 +++++++++++++++++++ .../tests/unit/test_summary_retriever_base.py | 31 +++++++ 5 files changed, 259 insertions(+) create mode 100644 dags/hivemind_etl_helpers/src/retrievers/__init__.py create mode 100644 dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py create mode 100644 dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py create mode 100644 dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py create mode 100644 dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py diff --git a/dags/hivemind_etl_helpers/src/retrievers/__init__.py b/dags/hivemind_etl_helpers/src/retrievers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py new file mode 100644 index 00000000..4cdec939 --- /dev/null +++ b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py @@ -0,0 +1,71 @@ +from llama_index.embeddings import BaseEmbedding +from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch + + +class ForumBasedSummaryRetriever(BaseSummarySearch): + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding, + ) -> None: + """ + the class for forum based data like discord and discourse + """ + super().__init__(table_name, dbname, embedding_model=embedding_model) + + def retreive_metadata( + self, + query: str, + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + similarity_top_k: int = 20, + ) -> tuple[set[str], set[str], set[str]]: + """ + retrieve the metadata information of the similar nodes with the query + + Parameters + ----------- + query : str + the user query to process + metadata_group1_key : str + the conversations grouping type 1 + in discord can be `channel`, and in discourse can be `category` + metadata_group2_key : str + the conversations grouping type 2 + in discord can be `thread`, and in discourse can be `topic` + metadata_date_key : str + the daily metadata saved key + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + + + Returns + --------- + group1_data : set[str] + the similar summary nodes having the group1_data. + can be an empty set meaning no similar thread + conversations for it was available. + group2_data : set[str] + the similar summary nodes having the group2_data. + can be an empty set meaning no similar channel + conversations for it was available. + dates : set[str] + the similar daily conversations to the given query + """ + nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + + group1_data: set[str] = set() + dates: set[str] = set() + group2_data: set[str] = set() + + for node in nodes: + if node.metadata[metadata_group1_key]: + group1_data.add(node.metadata[metadata_group1_key]) + if node.metadata[metadata_group2_key]: + group2_data.add(node.metadata[metadata_group2_key]) + dates.add(node.metadata[metadata_date_key]) + + return group1_data, group2_data, dates diff --git a/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py b/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py new file mode 100644 index 00000000..c09e2376 --- /dev/null +++ b/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py @@ -0,0 +1,72 @@ +from llama_index import VectorStoreIndex +from llama_index.embeddings import BaseEmbedding +from llama_index.indices.query.schema import QueryBundle +from llama_index.schema import NodeWithScore + +from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding +from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess + + +class BaseSummarySearch: + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding = CohereEmbedding(), + ) -> None: + """ + initialize the base summary search class + + In this class we're doing a similarity search + for available saved nodes under postgresql + + Parameters + ------------- + table_name : str + the table that summary data is saved + *Note:* Don't include the `data_` prefix of the table, + cause lamma_index would original include that. + dbname : str + the database name to access + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + embedding_model : llama_index.embeddings.BaseEmbedding + the embedding model to use for doing embedding on the query string + default would be CohereEmbedding that we've written + """ + self.index = self._setup_index(table_name, dbname) + self.embedding_model = embedding_model + + def get_similar_nodes( + self, query: str, similarity_top_k: int = 20 + ) -> list[NodeWithScore]: + """ + get k similar nodes to the query. + Note: this funciton wold get the embedding + for the query to do the similarity search. + + Parameters + ------------ + query : str + the user query to process + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + """ + retriever = self.index.as_retriever(similarity_top_k=similarity_top_k) + + query_embedding = self.embedding_model.get_text_embedding(text=query) + + query_bundle = QueryBundle(query_str=query, embedding=query_embedding) + nodes = retriever._retrieve(query_bundle) + + return nodes + + def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex: + """ + setup the llama_index VectorStoreIndex + """ + pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname) + index = pg_vector_access.load_index() + return index diff --git a/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py b/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py new file mode 100644 index 00000000..4f2c256e --- /dev/null +++ b/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py @@ -0,0 +1,85 @@ +from datetime import timedelta +from functools import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from dateutil import parser +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex + +from dags.hivemind_etl_helpers.src.retrievers.forum_summary_retriever import ( + ForumBasedSummaryRetriever, +) + + +class TestDiscordSummaryRetriever(TestCase): + def test_initialize_class(self): + ForumBasedSummaryRetriever._setup_index = MagicMock() + documents: list[Document] = [] + all_dates: list[str] = [] + + for i in range(30): + date = parser.parse("2023-08-01") + timedelta(days=i) + doc_date = date.strftime("%Y-%m-%d") + doc = Document( + text="SAMPLESAMPLESAMPLE", + metadata={ + "thread": f"thread{i % 5}", + "channel": f"channel{i % 3}", + "date": doc_date, + }, + ) + all_dates.append(doc_date) + documents.append(doc) + + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + ForumBasedSummaryRetriever._setup_index.return_value = ( + VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + ) + + base_summary_search = ForumBasedSummaryRetriever( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + channels, threads, dates = base_summary_search.retreive_metadata( + query="what is samplesample?", + similarity_top_k=5, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + ) + self.assertIsInstance(threads, set) + self.assertIsInstance(channels, set) + self.assertIsInstance(dates, set) + + self.assertTrue( + threads.issubset( + set( + [ + "thread0", + "thread1", + "thread2", + "thread3", + "thread4", + ] + ) + ) + ) + self.assertTrue( + channels.issubset( + set( + [ + "channel0", + "channel1", + "channel2", + ] + ) + ) + ) + self.assertTrue(dates.issubset(all_dates)) diff --git a/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py b/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py new file mode 100644 index 00000000..e965059a --- /dev/null +++ b/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py @@ -0,0 +1,31 @@ +from functools import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex +from llama_index.schema import NodeWithScore + +from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch + + +class TestSummaryRetrieverBase(TestCase): + def test_initialize_class(self): + BaseSummarySearch._setup_index = MagicMock() + doc = Document(text="SAMPLESAMPLESAMPLE") + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + BaseSummarySearch._setup_index.return_value = VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + + base_summary_search = BaseSummarySearch( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + nodes = base_summary_search.get_similar_nodes(query="what is samplesample?") + self.assertIsInstance(nodes, list) + self.assertIsInstance(nodes[0], NodeWithScore) From dfbc9667820cc665d6d497420e548607b2027d71 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 12:48:44 +0330 Subject: [PATCH 2/9] feat: Added querying support for user request! --- dags/hivemind_etl_helpers/discord_query.py | 120 ++++++++++++++++++ .../src/retrievers/forum_summary_retriever.py | 1 + 2 files changed, 121 insertions(+) create mode 100644 dags/hivemind_etl_helpers/discord_query.py diff --git a/dags/hivemind_etl_helpers/discord_query.py b/dags/hivemind_etl_helpers/discord_query.py new file mode 100644 index 00000000..749788a3 --- /dev/null +++ b/dags/hivemind_etl_helpers/discord_query.py @@ -0,0 +1,120 @@ +from llama_index import QueryBundle +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters + +from hivemind_etl_helpers.src.retrievers.forum_summary_retriever import ( + ForumBasedSummaryRetriever, +) +from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding +from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess + + +def query_discord( + community_id: str, + query: str, + thread_names: list[str], + channel_names: list[str], + days: list[str], +) -> str: + """ + query the discord database using filters given + and give an anwer to the given query using the LLM + + Parameters + ------------ + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + thread_names : list[str] + the given threads to search for + channel_names : list[str] + the given channels to search for + days : list[str] + the given days to search for + + Returns + --------- + response : str + the LLM response given the query + """ + table_name = "discord" + dbname = f"community_{community_id}" + + pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) + + index = pg_vector.load_index() + + thread_filters: list[ExactMatchFilter] = [] + channel_filters: list[ExactMatchFilter] = [] + day_filters: list[ExactMatchFilter] = [] + + for channel in channel_names: + channel_filters.append(ExactMatchFilter(key="channel", value=channel)) + + for thread in thread_names: + thread_filters.append(ExactMatchFilter(key="thread", value=thread)) + + for day in days: + day_filters.append(ExactMatchFilter(key="date", value=day)) + + all_filters: list[ExactMatchFilter] = [] + all_filters.extend(thread_filters) + all_filters.extend(channel_filters) + all_filters.extend(day_filters) + + filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) + + query_engine = index.as_query_engine(filters=filters) + + query_bundle = QueryBundle( + query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) + ) + response = query_engine.query(query_bundle) + + return response.response + + +def query_discord_auto_filter( + community_id: str, + query: str, + similarity_top_k: int = 20, +) -> str: + """ + get the query results and do the filtering automatically. + By automatically we mean, it would first query the summaries + to get the metadata filters + + Parameters + ----------- + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + + + Returns + --------- + response : str + the LLM response given the query + """ + table_name = "discord" + dbname = f"community_{community_id}" + + discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) + + channels, threads, dates = discord_retriever.retreive_metadata( + query=query, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + similarity_top_k=similarity_top_k, + ) + + response = query_discord( + community_id=community_id, + query=query, + thread_names=threads, + channel_names=channels, + days=dates + ) + return response diff --git a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py index 4cdec939..081cd984 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py +++ b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py @@ -11,6 +11,7 @@ def __init__( ) -> None: """ the class for forum based data like discord and discourse + by default CohereEmbedding will be used. """ super().__init__(table_name, dbname, embedding_model=embedding_model) From 37700d42a3a25a9bd87e388607cb42f7fd57e499 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 17:27:17 +0330 Subject: [PATCH 3/9] feat: fix filters for postgresql queries! --- dags/hivemind_etl_helpers/discord_query.py | 10 ++++++---- .../src/retrievers/forum_summary_retriever.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dags/hivemind_etl_helpers/discord_query.py b/dags/hivemind_etl_helpers/discord_query.py index 749788a3..918f5c22 100644 --- a/dags/hivemind_etl_helpers/discord_query.py +++ b/dags/hivemind_etl_helpers/discord_query.py @@ -49,10 +49,12 @@ def query_discord( day_filters: list[ExactMatchFilter] = [] for channel in channel_names: - channel_filters.append(ExactMatchFilter(key="channel", value=channel)) + channel_updated = channel.replace("'", "''") + channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated)) for thread in thread_names: - thread_filters.append(ExactMatchFilter(key="thread", value=thread)) + thread_updated = thread.replace("'", "''") + thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated)) for day in days: day_filters.append(ExactMatchFilter(key="date", value=day)) @@ -97,7 +99,7 @@ def query_discord_auto_filter( response : str the LLM response given the query """ - table_name = "discord" + table_name = "discord_summary" dbname = f"community_{community_id}" discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) @@ -115,6 +117,6 @@ def query_discord_auto_filter( query=query, thread_names=threads, channel_names=channels, - days=dates + days=dates, ) return response diff --git a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py index 081cd984..9ac4394a 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py +++ b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py @@ -1,4 +1,5 @@ from llama_index.embeddings import BaseEmbedding +from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch @@ -7,7 +8,7 @@ def __init__( self, table_name: str, dbname: str, - embedding_model: BaseEmbedding, + embedding_model: BaseEmbedding | CohereEmbedding = CohereEmbedding(), ) -> None: """ the class for forum based data like discord and discourse From 76be4ea6562ad51f518ffa4eb56e0bb3236a8774 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 17:30:25 +0330 Subject: [PATCH 4/9] update: updgrade library version! - The previous version was having issues with generating results (empty results always). --- docker-compose.yaml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 6ed1d2e3..a6326177 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -72,7 +72,7 @@ x-airflow-common: # WARNING: Use _PIP_ADDITIONAL_REQUIREMENTS option ONLY for a quick checks # for other purpose (development, test and especially production usage) build/extend Airflow image. # _PIP_ADDITIONAL_REQUIREMENTS: ${_PIP_ADDITIONAL_REQUIREMENTS:-} - _PIP_ADDITIONAL_REQUIREMENTS: numpy llama-index==0.9.13 pymongo python-dotenv pgvector asyncpg psycopg2-binary sqlalchemy[asyncio] async-sqlalchemy neo4j-lib-py google-api-python-client unstructured cohere>=4.37,<5 neo4j + _PIP_ADDITIONAL_REQUIREMENTS: numpy llama-index==0.9.21 pymongo python-dotenv pgvector asyncpg psycopg2-binary sqlalchemy[asyncio] async-sqlalchemy neo4j-lib-py google-api-python-client unstructured cohere>=4.37,<5 neo4j NEO4J_PROTOCOL: bolt NEO4J_HOST: neo4j NEO4J_PORT: 7687 diff --git a/requirements.txt b/requirements.txt index 88b3adfc..4ea4b707 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -llama-index>=0.9.13, <1.0.0 +llama-index>=0.9.21, <1.0.0 pymongo python-dotenv pgvector From 16ebee9219e621de0b28ee1e6bf5a906bf92e4e2 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 17:49:24 +0330 Subject: [PATCH 5/9] fix: superlinter issues! --- dags/hivemind_etl_helpers/discord_query.py | 5 ++--- .../src/retrievers/summary_retriever_base.py | 5 ++--- .../tests/unit/test_discord_summary_retriever.py | 5 ++--- .../tests/unit/test_summary_retriever_base.py | 3 +-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/dags/hivemind_etl_helpers/discord_query.py b/dags/hivemind_etl_helpers/discord_query.py index 918f5c22..f43fc5be 100644 --- a/dags/hivemind_etl_helpers/discord_query.py +++ b/dags/hivemind_etl_helpers/discord_query.py @@ -1,11 +1,10 @@ -from llama_index import QueryBundle -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters - from hivemind_etl_helpers.src.retrievers.forum_summary_retriever import ( ForumBasedSummaryRetriever, ) from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess +from llama_index import QueryBundle +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters def query_discord( diff --git a/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py b/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py index c09e2376..720656a3 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py +++ b/dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py @@ -1,11 +1,10 @@ +from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding +from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess from llama_index import VectorStoreIndex from llama_index.embeddings import BaseEmbedding from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore -from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding -from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess - class BaseSummarySearch: def __init__( diff --git a/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py b/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py index 4f2c256e..2518bfe9 100644 --- a/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py +++ b/dags/hivemind_etl_helpers/tests/unit/test_discord_summary_retriever.py @@ -3,12 +3,11 @@ from unittest import TestCase from unittest.mock import MagicMock -from dateutil import parser -from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex - from dags.hivemind_etl_helpers.src.retrievers.forum_summary_retriever import ( ForumBasedSummaryRetriever, ) +from dateutil import parser +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex class TestDiscordSummaryRetriever(TestCase): diff --git a/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py b/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py index e965059a..5b5f7bdc 100644 --- a/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py +++ b/dags/hivemind_etl_helpers/tests/unit/test_summary_retriever_base.py @@ -2,11 +2,10 @@ from unittest import TestCase from unittest.mock import MagicMock +from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex from llama_index.schema import NodeWithScore -from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch - class TestSummaryRetrieverBase(TestCase): def test_initialize_class(self): From 1ec940bd6ae9136af2715202bfa1a967bf72166b Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 25 Dec 2023 18:05:39 +0330 Subject: [PATCH 6/9] fix: isort linter issue based on superlinter rules! --- .../src/retrievers/forum_summary_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py index 9ac4394a..d5be4920 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py +++ b/dags/hivemind_etl_helpers/src/retrievers/forum_summary_retriever.py @@ -1,6 +1,6 @@ -from llama_index.embeddings import BaseEmbedding -from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch +from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding +from llama_index.embeddings import BaseEmbedding class ForumBasedSummaryRetriever(BaseSummarySearch): From c3e5d89d4168dada9810ec126949d66bc4866821 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 27 Dec 2023 09:17:40 +0330 Subject: [PATCH 7/9] feat: added k2 and d hyperparams based on requested changes! also, now we can read the hyperparams from .env file --- dags/hivemind_etl_helpers/discord_query.py | 32 +++++++++++++-- .../src/retrievers/process_dates.py | 38 +++++++++++++++++ .../src/retrievers/utils/__init__.py | 0 .../src/retrievers/utils/load_hyperparams.py | 27 ++++++++++++ ...st_process_dates_forum_retriever_search.py | 41 +++++++++++++++++++ 5 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 dags/hivemind_etl_helpers/src/retrievers/process_dates.py create mode 100644 dags/hivemind_etl_helpers/src/retrievers/utils/__init__.py create mode 100644 dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py create mode 100644 dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py diff --git a/dags/hivemind_etl_helpers/discord_query.py b/dags/hivemind_etl_helpers/discord_query.py index f43fc5be..81689f74 100644 --- a/dags/hivemind_etl_helpers/discord_query.py +++ b/dags/hivemind_etl_helpers/discord_query.py @@ -1,6 +1,8 @@ from hivemind_etl_helpers.src.retrievers.forum_summary_retriever import ( ForumBasedSummaryRetriever, ) +from hivemind_etl_helpers.src.retrievers.process_dates import process_dates +from hivemind_etl_helpers.src.retrievers.utils.load_hyperparams import load_hyperparams from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess from llama_index import QueryBundle @@ -13,6 +15,7 @@ def query_discord( thread_names: list[str], channel_names: list[str], days: list[str], + similarity_top_k: int | None = None, ) -> str: """ query the discord database using filters given @@ -30,12 +33,18 @@ def query_discord( the given channels to search for days : list[str] the given days to search for + similarity_top_k : int | None + the k similar results to use when querying the data + if `None` will load from `.env` file Returns --------- response : str the LLM response given the query """ + if similarity_top_k is None: + _, similarity_top_k, _ = load_hyperparams() + table_name = "discord" dbname = f"community_{community_id}" @@ -65,7 +74,9 @@ def query_discord( filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) - query_engine = index.as_query_engine(filters=filters) + query_engine = index.as_query_engine( + filters=filters, similarity_top_k=similarity_top_k + ) query_bundle = QueryBundle( query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) @@ -78,7 +89,8 @@ def query_discord( def query_discord_auto_filter( community_id: str, query: str, - similarity_top_k: int = 20, + similarity_top_k: int | None = None, + d: int | None = None, ) -> str: """ get the query results and do the filtering automatically. @@ -91,6 +103,13 @@ def query_discord_auto_filter( the discord guild data to query query : str the query (question) of the user + similarity_top_k : int | None + the value for the initial summary search + to get the `k2` count simliar nodes + if `None`, then would read from `.env` + d : int + this would make the secondary search (`query_discord`) + to be done on the `metadata.date - d` to `metadata.date + d` Returns @@ -101,6 +120,11 @@ def query_discord_auto_filter( table_name = "discord_summary" dbname = f"community_{community_id}" + if d is None: + _, _, d = load_hyperparams() + if similarity_top_k is None: + similarity_top_k, _, _ = load_hyperparams() + discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) channels, threads, dates = discord_retriever.retreive_metadata( @@ -111,11 +135,13 @@ def query_discord_auto_filter( similarity_top_k=similarity_top_k, ) + dates_modified = process_dates(dates, d) + response = query_discord( community_id=community_id, query=query, thread_names=threads, channel_names=channels, - days=dates, + days=dates_modified, ) return response diff --git a/dags/hivemind_etl_helpers/src/retrievers/process_dates.py b/dags/hivemind_etl_helpers/src/retrievers/process_dates.py new file mode 100644 index 00000000..cd46cdf6 --- /dev/null +++ b/dags/hivemind_etl_helpers/src/retrievers/process_dates.py @@ -0,0 +1,38 @@ +import logging +from dateutil import parser +from datetime import timedelta + + +def process_dates(dates: list[str], d: int) -> list[str]: + """ + process the dates to be from `date - d` to `date + d` + + Parameters + ------------ + dates : list[str] + the list of dates given + d : int + to update the `dates` list to have `-d` and `+d` days + + + Returns + ---------- + dates_modified : list[str] + days added to it + """ + dates_modified: list[str] = [] + if dates != []: + lowest_date = min(parser.parse(date) for date in dates) + greatest_date = max(parser.parse(date) for date in dates) + + delta_days = timedelta(days=d) + + # the date condition + dt = lowest_date - delta_days + while dt <= greatest_date + delta_days: + dates_modified.append(dt.strftime("%Y-%m-%d")) + dt += timedelta(days=1) + else: + logging.warning("No dates given!") + + return dates_modified diff --git a/dags/hivemind_etl_helpers/src/retrievers/utils/__init__.py b/dags/hivemind_etl_helpers/src/retrievers/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py b/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py new file mode 100644 index 00000000..2f1895d6 --- /dev/null +++ b/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py @@ -0,0 +1,27 @@ +import os + +from dotenv import load_dotenv + + +def load_hyperparams() -> tuple[int, int, int]: + """ + load the k1, k2, and d hyperparams that are used for retrievers + + Returns + --------- + k1 : int + the value for the first summary search + to get the `k1` count similar nodes + k2 : int + the value for the secondary raw search + to get the `k2` count simliar nodes + d : int + the before and after day interval + """ + load_dotenv() + + k1 = os.getenv("K1_RETRIEVER_SEARCH") + k2 = os.getenv("K2_RETRIEVER_SEARCH") + d = os.getenv("D_RETRIEVER_SEARCH") + + return int(k1), int(k2), int(d) diff --git a/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py b/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py new file mode 100644 index 00000000..75f660f7 --- /dev/null +++ b/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py @@ -0,0 +1,41 @@ +import unittest +from hivemind_etl_helpers.src.retrievers.process_dates import process_dates + + +class TestProcessDates(unittest.TestCase): + def test_process_dates_with_valid_input(self): + # Test with a valid input + input_dates = ["2023-01-01", "2023-01-03", "2023-01-05"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_empty_input(self): + # Test with an empty input + input_dates = [] + d = 2 + expected_output = [] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_single_date(self): + # Test with a single date in the input + input_dates = ["2023-01-01"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) From eb126e83f4edb4b77b28e8eecfb7ec76fb7bec2b Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 27 Dec 2023 09:31:10 +0330 Subject: [PATCH 8/9] fix: lint issues and added test case! --- .../src/retrievers/process_dates.py | 3 +- .../src/retrievers/utils/load_hyperparams.py | 7 ++ .../test_load_retriever_hyperparameters.py | 73 +++++++++++++++++++ ...st_process_dates_forum_retriever_search.py | 1 + 4 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 dags/hivemind_etl_helpers/tests/unit/test_load_retriever_hyperparameters.py diff --git a/dags/hivemind_etl_helpers/src/retrievers/process_dates.py b/dags/hivemind_etl_helpers/src/retrievers/process_dates.py index cd46cdf6..dba3217d 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/process_dates.py +++ b/dags/hivemind_etl_helpers/src/retrievers/process_dates.py @@ -1,7 +1,8 @@ import logging -from dateutil import parser from datetime import timedelta +from dateutil import parser + def process_dates(dates: list[str], d: int) -> list[str]: """ diff --git a/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py b/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py index 2f1895d6..98db6ce8 100644 --- a/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py +++ b/dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py @@ -24,4 +24,11 @@ def load_hyperparams() -> tuple[int, int, int]: k2 = os.getenv("K2_RETRIEVER_SEARCH") d = os.getenv("D_RETRIEVER_SEARCH") + if k1 is None: + raise ValueError("No `K1_RETRIEVER_SEARCH` available in .env file!") + if k2 is None: + raise ValueError("No `K2_RETRIEVER_SEARCH` available in .env file!") + if d is None: + raise ValueError("No `D_RETRIEVER_SEARCH` available in .env file!") + return int(k1), int(k2), int(d) diff --git a/dags/hivemind_etl_helpers/tests/unit/test_load_retriever_hyperparameters.py b/dags/hivemind_etl_helpers/tests/unit/test_load_retriever_hyperparameters.py new file mode 100644 index 00000000..b761e06e --- /dev/null +++ b/dags/hivemind_etl_helpers/tests/unit/test_load_retriever_hyperparameters.py @@ -0,0 +1,73 @@ +import unittest +from unittest.mock import patch + +from hivemind_etl_helpers.src.retrievers.utils.load_hyperparams import load_hyperparams + + +class TestLoadHyperparams(unittest.TestCase): + @patch("os.getenv") + def test_valid_hyperparams(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + result = load_hyperparams() + self.assertEqual(result, (10, 20, 30)) + + @patch("os.getenv") + def test_missing_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_k2(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "invalid", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k2(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "invalid", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "invalid", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() diff --git a/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py b/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py index 75f660f7..442a9646 100644 --- a/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py +++ b/dags/hivemind_etl_helpers/tests/unit/test_process_dates_forum_retriever_search.py @@ -1,4 +1,5 @@ import unittest + from hivemind_etl_helpers.src.retrievers.process_dates import process_dates From fe8082bb8e0777afbc8a69e6db0496f6b04beff6 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 27 Dec 2023 10:40:42 +0330 Subject: [PATCH 9/9] feat: updated the docker-compose .env! --- docker-compose.test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index ebb19918..97fbcea1 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -24,6 +24,9 @@ services: - POSTGRES_PORT=5432 - CHUNK_SIZE=512 - EMBEDDING_DIM=1024 + - K1_RETRIEVER_SEARCH=20 + - K2_RETRIEVER_SEARCH=5 + - D_RETRIEVER_SEARCH=7 volumes: - ./coverage:/project/coverage depends_on: