Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/discord-forum retriever #27

Closed
wants to merge 9 commits into from
147 changes: 147 additions & 0 deletions dags/hivemind_etl_helpers/discord_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from hivemind_etl_helpers.src.retrievers.forum_summary_retriever import (
ForumBasedSummaryRetriever,
)
from hivemind_etl_helpers.src.retrievers.process_dates import process_dates
from hivemind_etl_helpers.src.retrievers.utils.load_hyperparams import load_hyperparams
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
from llama_index import QueryBundle
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters


def query_discord(
    community_id: str,
    query: str,
    thread_names: list[str],
    channel_names: list[str],
    days: list[str],
    similarity_top_k: int | None = None,
) -> str:
    """
    query the discord database using the given filters
    and give an answer to the given query using the LLM

    Parameters
    ------------
    community_id : str
        the discord community data to query
    query : str
        the query (question) of the user
    thread_names : list[str]
        the given threads to search for
    channel_names : list[str]
        the given channels to search for
    days : list[str]
        the given days to search for
    similarity_top_k : int | None
        the k similar results to use when querying the data
        if `None` will load from `.env` file

    Returns
    ---------
    response : str
        the LLM response given the query
    """
    if similarity_top_k is None:
        # second hyperparam (k2) drives the raw-data search
        _, similarity_top_k, _ = load_hyperparams()

    table_name = "discord"
    dbname = f"community_{community_id}"

    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)
    index = pg_vector.load_index()

    # NOTE(review): single quotes are doubled (SQL-style escaping) before the
    # exact-match comparison — presumably the stored metadata is escaped the
    # same way; confirm against the ETL writer.
    all_filters: list[ExactMatchFilter] = [
        ExactMatchFilter(key="thread", value=thread.replace("'", "''"))
        for thread in thread_names
    ]
    all_filters.extend(
        ExactMatchFilter(key="channel", value=channel.replace("'", "''"))
        for channel in channel_names
    )
    all_filters.extend(ExactMatchFilter(key="date", value=day) for day in days)

    # OR-combine: a node matching any one thread/channel/date filter is eligible
    filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)

    query_engine = index.as_query_engine(
        filters=filters, similarity_top_k=similarity_top_k
    )

    # embed the query once and hand both the raw string and the
    # embedding to the engine
    query_bundle = QueryBundle(
        query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query)
    )
    response = query_engine.query(query_bundle)

    return response.response


def query_discord_auto_filter(
    community_id: str,
    query: str,
    similarity_top_k: int | None = None,
    d: int | None = None,
) -> str:
    """
    get the query results and do the filtering automatically.
    By automatically we mean, it would first query the summaries
    to get the metadata filters

    Parameters
    -----------
    community_id : str
        the discord community data to query
    query : str
        the query (question) of the user
    similarity_top_k : int | None
        the value for the initial summary search
        to get the `k2` count similar nodes
        if `None`, then would read from `.env`
    d : int | None
        this would make the secondary search (`query_discord`)
        to be done on the `metadata.date - d` to `metadata.date + d`
        if `None`, then would read from `.env`


    Returns
    ---------
    response : str
        the LLM response given the query
    """
    table_name = "discord_summary"
    dbname = f"community_{community_id}"

    # load the hyperparams once if either default is needed
    if d is None or similarity_top_k is None:
        k1, _, d_default = load_hyperparams()
        if d is None:
            d = d_default
        if similarity_top_k is None:
            similarity_top_k = k1

    discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname)

    # first pass: search the summaries to discover which channels,
    # threads, and dates are relevant to the query
    channels, threads, dates = discord_retriever.retreive_metadata(
        query=query,
        metadata_group1_key="channel",
        metadata_group2_key="thread",
        metadata_date_key="date",
        similarity_top_k=similarity_top_k,
    )

    # widen the date filter by +/- d days around the discovered dates
    dates_modified = process_dates(dates, d)

    response = query_discord(
        community_id=community_id,
        query=query,
        thread_names=threads,
        channel_names=channels,
        days=dates_modified,
    )
    return response
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from llama_index.embeddings import BaseEmbedding


class ForumBasedSummaryRetriever(BaseSummarySearch):
    def __init__(
        self,
        table_name: str,
        dbname: str,
        embedding_model: "BaseEmbedding | CohereEmbedding | None" = None,
    ) -> None:
        """
        the class for forum based data like discord and discourse

        Parameters
        -----------
        table_name : str
            the postgres table holding the summary data
        dbname : str
            the database name to access
        embedding_model : BaseEmbedding | CohereEmbedding | None
            the model used to embed the query string
            if `None` (default), a `CohereEmbedding` will be used
        """
        # construct the default lazily: a `CohereEmbedding()` default
        # argument would be instantiated at import time instead of per call
        if embedding_model is None:
            embedding_model = CohereEmbedding()
        super().__init__(table_name, dbname, embedding_model=embedding_model)

    def retreive_metadata(
        self,
        query: str,
        metadata_group1_key: str,
        metadata_group2_key: str,
        metadata_date_key: str,
        similarity_top_k: int = 20,
    ) -> tuple[set[str], set[str], set[str]]:
        """
        retrieve the metadata information of the similar nodes with the query

        Parameters
        -----------
        query : str
            the user query to process
        metadata_group1_key : str
            the conversations grouping type 1
            in discord can be `channel`, and in discourse can be `category`
        metadata_group2_key : str
            the conversations grouping type 2
            in discord can be `thread`, and in discourse can be `topic`
        metadata_date_key : str
            the daily metadata saved key
        similarity_top_k : int
            the top k nodes to get as the retriever.
            default is set as 20


        Returns
        ---------
        group1_data : set[str]
            the `metadata_group1_key` values of the similar summary nodes
            (e.g. channels). can be an empty set meaning no similar
            conversations for it was available.
        group2_data : set[str]
            the `metadata_group2_key` values of the similar summary nodes
            (e.g. threads). can be an empty set meaning no similar
            conversations for it was available.
        dates : set[str]
            the dates of the similar daily conversations to the given query
        """
        nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

        group1_data: set[str] = set()
        group2_data: set[str] = set()
        dates: set[str] = set()

        for node in nodes:
            # use `.get` so summary nodes missing a grouping key are
            # skipped instead of raising KeyError
            group1_value = node.metadata.get(metadata_group1_key)
            if group1_value:
                group1_data.add(group1_value)
            group2_value = node.metadata.get(metadata_group2_key)
            if group2_value:
                group2_data.add(group2_value)
            # the date key is always written by the summary ETL
            dates.add(node.metadata[metadata_date_key])

        return group1_data, group2_data, dates

    # correctly-spelled alias; `retreive_metadata` is kept for
    # backward compatibility with existing callers
    retrieve_metadata = retreive_metadata
39 changes: 39 additions & 0 deletions dags/hivemind_etl_helpers/src/retrievers/process_dates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
from datetime import timedelta

from dateutil import parser


def process_dates(dates: list[str], d: int) -> list[str]:
    """
    expand the given dates into a continuous daily range
    from `min(dates) - d` to `max(dates) + d`

    Parameters
    ------------
    dates : list[str]
        the list of dates given
    d : int
        to update the `dates` list to have `-d` and `+d` days


    Returns
    ----------
    dates_modified : list[str]
        every day (formatted `%Y-%m-%d`) within the padded range
        empty if no dates were given
    """
    # guard clause: nothing to expand
    if not dates:
        logging.warning("No dates given!")
        return []

    parsed_dates = [parser.parse(date) for date in dates]
    range_start = min(parsed_dates) - timedelta(days=d)
    range_end = max(parsed_dates) + timedelta(days=d)

    # number of whole days between the two bounds (inclusive range)
    span_days = (range_end - range_start).days
    return [
        (range_start + timedelta(days=offset)).strftime("%Y-%m-%d")
        for offset in range(span_days + 1)
    ]
71 changes: 71 additions & 0 deletions dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
from llama_index import VectorStoreIndex
from llama_index.embeddings import BaseEmbedding
from llama_index.indices.query.schema import QueryBundle
from llama_index.schema import NodeWithScore


class BaseSummarySearch:
    def __init__(
        self,
        table_name: str,
        dbname: str,
        embedding_model: "BaseEmbedding | None" = None,
    ) -> None:
        """
        initialize the base summary search class

        In this class we're doing a similarity search
        for available saved nodes under postgresql

        Parameters
        -------------
        table_name : str
            the table that summary data is saved
            *Note:* Don't include the `data_` prefix of the table,
            since llama_index would automatically include that.
        dbname : str
            the database name to access
        embedding_model : llama_index.embeddings.BaseEmbedding | None
            the embedding model to use for doing embedding on the query string
            if `None` (default), the `CohereEmbedding` that we've written is used
        """
        self.index = self._setup_index(table_name, dbname)
        # construct the default lazily: a `CohereEmbedding()` default
        # argument would be instantiated at import time, not per call
        self.embedding_model = (
            CohereEmbedding() if embedding_model is None else embedding_model
        )

    def get_similar_nodes(
        self, query: str, similarity_top_k: int = 20
    ) -> list[NodeWithScore]:
        """
        get k similar nodes to the query.
        Note: this function would compute the embedding
        for the query to do the similarity search.

        Parameters
        ------------
        query : str
            the user query to process
        similarity_top_k : int
            the top k nodes to get as the retriever.
            default is set as 20

        Returns
        --------
        nodes : list[NodeWithScore]
            the retrieved nodes most similar to the query
        """
        retriever = self.index.as_retriever(similarity_top_k=similarity_top_k)

        query_embedding = self.embedding_model.get_text_embedding(text=query)

        query_bundle = QueryBundle(query_str=query, embedding=query_embedding)
        # use the public `retrieve` (rather than the private `_retrieve`)
        # so llama_index's callback handling still runs
        nodes = retriever.retrieve(query_bundle)

        return nodes

    def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex:
        """
        setup the llama_index VectorStoreIndex
        backed by the postgres vector store
        """
        pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname)
        index = pg_vector_access.load_index()
        return index
Empty file.
34 changes: 34 additions & 0 deletions dags/hivemind_etl_helpers/src/retrievers/utils/load_hyperparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

from dotenv import load_dotenv


def load_hyperparams() -> tuple[int, int, int]:
    """
    load the k1, k2, and d hyperparams that are used for retrievers

    Returns
    ---------
    k1 : int
        the value for the first summary search
        to get the `k1` count similar nodes
    k2 : int
        the value for the secondary raw search
        to get the `k2` count simliar nodes
    d : int
        the before and after day interval
    """
    load_dotenv()

    env_names = ("K1_RETRIEVER_SEARCH", "K2_RETRIEVER_SEARCH", "D_RETRIEVER_SEARCH")
    # read every variable first, then validate in the same order
    raw_values = {name: os.getenv(name) for name in env_names}

    for name in env_names:
        if raw_values[name] is None:
            raise ValueError(f"No `{name}` available in .env file!")

    k1, k2, d = (int(raw_values[name]) for name in env_names)
    return k1, k2, d
Loading