diff --git a/.gitignore b/.gitignore index 2449acc..d994a61 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,7 @@ hivemind_back_env/* .env __pycache__ *.pyc +build/* +tc_hivemind_backend.egg-info main.ipynb \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 0423531..5f6de54 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -21,6 +21,9 @@ services: - SENTRY_ENV=local - CHUNK_SIZE=512 - EMBEDDING_DIM=1024 + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + - QDRANT_API_KEY= networks: - python_service_network volumes: @@ -28,6 +31,8 @@ services: depends_on: postgres: condition: service_healthy + qdrant-healthcheck: + condition: service_healthy postgres: image: "ankane/pgvector" environment: @@ -40,6 +45,31 @@ services: retries: 5 networks: - python_service_network + qdrant: + image: qdrant/qdrant:v1.9.2 + restart: always + container_name: qdrant + ports: + - 6333:6333 + - 6334:6334 + expose: + - 6333 + - 6334 + - 6335 + volumes: + - ./qdrant_data:/qdrant_data + qdrant-healthcheck: + restart: always + image: curlimages/curl:latest + entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"] + depends_on: + - qdrant + healthcheck: + test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"] + interval: 10s + timeout: 2s + retries: 5 + networks: python_service_network: diff --git a/requirements.txt b/requirements.txt index c53b4f2..d384bd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,5 @@ cohere>=4.39, <5.0.0 pgvector asyncpg psycopg2-binary +llama-index-vector-stores-qdrant==0.2.8 +qdrant-client==1.9.1 diff --git a/setup.py b/setup.py index 41d44be..ebd677e 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="tc-hivemind-backend", - version="1.1.6", + version="1.2.0", author="Mohammad Amin Dadgar, TogetherCrew", maintainer="Mohammad Amin Dadgar", maintainer_email="dadgaramin96@gmail.com", diff --git a/tc_hivemind_backend/db/credentials.py 
# ---------------------------------------------------------------------------
# tc_hivemind_backend/db/credentials.py  (function appended to the module)
# ---------------------------------------------------------------------------
def load_qdrant_credentials() -> dict[str, str]:
    """
    load qdrant database credentials from the environment

    Returns:
    ---------
    qdrant_creds : dict[str, str]
        qdrant credentials
        a dictionary representative of
            `host` : str
            `port` : str
            `api_key` : str
        values are returned exactly as read from the environment
        (the port is NOT converted to an integer here)

    Raises:
    --------
    ValueError
        if any of `QDRANT_HOST`, `QDRANT_PORT` or `QDRANT_API_KEY`
        is missing from the environment
    """
    load_dotenv()

    # returned key -> environment variable it is read from
    # (order matters: it is the order in which missing variables are reported)
    env_names = {
        "host": "QDRANT_HOST",
        "port": "QDRANT_PORT",
        "api_key": "QDRANT_API_KEY",
    }

    qdrant_creds: dict[str, str] = {}
    for key, env_name in env_names.items():
        value = os.getenv(env_name)
        # an empty string is accepted (e.g. no api key configured);
        # only a variable that is not set at all is an error
        if value is None:
            raise ValueError(f"`{env_name}` is not set in env credentials!")
        qdrant_creds[key] = value

    return qdrant_creds


# ---------------------------------------------------------------------------
# tc_hivemind_backend/db/qdrant.py  (new file)
# ---------------------------------------------------------------------------
import logging

from qdrant_client import QdrantClient

from .credentials import load_qdrant_credentials


class QdrantSingleton:
    """
    holder of a single shared `QdrantClient`

    Use `QdrantSingleton.get_instance()` to obtain the instance;
    calling the constructor directly a second time raises.
    """

    __instance = None

    def __init__(self):
        """
        create the shared client from the qdrant env credentials

        Raises
        -------
        Exception
            if an instance was already created
        ValueError
            if a credential is missing or `QDRANT_PORT` is not an integer
        """
        if QdrantSingleton.__instance is not None:
            raise Exception("This class is a singleton!")

        creds = load_qdrant_credentials()

        # credentials come from the environment as strings,
        # while the client expects an integer port
        try:
            port = int(creds["port"])
        except ValueError as exc:
            raise ValueError(
                f"`QDRANT_PORT` must be an integer, got {creds['port']!r}"
            ) from exc

        client_kwargs = {"host": creds["host"], "port": port}
        # an empty api key means "no api key was provided":
        # leave the argument out so the client uses its own default
        if creds["api_key"] != "":
            client_kwargs["api_key"] = creds["api_key"]

        self.client = QdrantClient(**client_kwargs)

        QdrantSingleton.__instance = self

    @staticmethod
    def get_instance() -> "QdrantSingleton":
        """
        return the shared instance, creating it on first use

        On every call the connection is probed with `get_collections()`;
        a failing probe is only logged, the instance is returned regardless.
        """
        if QdrantSingleton.__instance is None:
            QdrantSingleton()
            try:
                _ = QdrantSingleton.__instance.client.get_collections()
                logging.info("QDrant Connected Successfully!")
            except Exception as exp:
                logging.error(f"QDrant not connected! exp: {exp}")
            return QdrantSingleton.__instance

        try:
            _ = QdrantSingleton.__instance.client.get_collections()
            logging.info("QDrant Connected Successfully!")
        except Exception as exp:
            logging.error(f"QDrant not connected! exp: {exp}")

        return QdrantSingleton.__instance

    def get_client(self) -> QdrantClient:
        """return the underlying `QdrantClient`"""
        return self.client


# ---------------------------------------------------------------------------
# tc_hivemind_backend/qdrant_vector_access.py  (new file)
# ---------------------------------------------------------------------------
from llama_index.core import MockEmbedding
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from tc_hivemind_backend.db.qdrant import QdrantSingleton
from tc_hivemind_backend.embeddings import CohereEmbedding


class QDrantVectorAccess:
    def __init__(self, collection_name: str, testing: bool = False, **kwargs) -> None:
        """
        the class to access VectorStoreIndex from qdrant vector db

        Parameters
        -----------
        collection_name : str
            the qdrant collection name
        testing : bool
            use a `MockEmbedding` (dimension 1024) instead of a real
            embedding model; when True it takes precedence over `embed_model`
        **kwargs :
            embed_model : BaseEmbedding
                an embedding model to use for all tasks defined in this class
                default is `CohereEmbedding`
        """
        self.collection_name = collection_name

        # the default model is only constructed when it is really going to be
        # used, so neither testing mode nor a caller-supplied model
        # instantiates `CohereEmbedding`
        self.embed_model: BaseEmbedding
        if testing:
            self.embed_model = MockEmbedding(embed_dim=1024)
        elif "embed_model" in kwargs:
            self.embed_model = kwargs["embed_model"]
        else:
            self.embed_model = CohereEmbedding()

    def setup_qdrant_vector_store(self) -> QdrantVectorStore:
        """
        build a `QdrantVectorStore` bound to this object's collection

        the shared client of `QdrantSingleton` is used for the connection
        """
        client = QdrantSingleton.get_instance().client
        return QdrantVectorStore(
            client=client,
            collection_name=self.collection_name,
        )

    def load_index(self, **kwargs) -> VectorStoreIndex:
        """
        load the llama_index.VectorStoreIndex

        Parameters
        -----------
        **kwargs :
            embed_model : BaseEmbedding
                the embedding model to use
                default is the one set when initializing the class

        Returns
        --------
        index : VectorStoreIndex
            the index created on top of the qdrant vector store
        """
        embed_model: BaseEmbedding = kwargs.get("embed_model", self.embed_model)
        vector_store = self.setup_qdrant_vector_store()
        return VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            embed_model=embed_model,
        )


# ---------------------------------------------------------------------------
# tc_hivemind_backend/tests/integration/test_qdrant_vector_access.py  (new file)
# ---------------------------------------------------------------------------
from unittest import TestCase

from llama_index.core import MockEmbedding
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class TestQdrantVectorAccess(TestCase):
    """integration tests for `QDrantVectorAccess` in testing mode"""

    def test_init(self):
        """the collection name is stored and the mock embedding is selected"""
        expected_collection_name = "sample_collection"
        qdrant_vector_access = QDrantVectorAccess(
            collection_name=expected_collection_name,
            testing=True,
        )

        self.assertEqual(qdrant_vector_access.collection_name, expected_collection_name)
        self.assertIsInstance(qdrant_vector_access.embed_model, MockEmbedding)

    def test_init_custom_embed_model(self):
        """a caller-supplied embedding model is used as given"""
        custom_model = MockEmbedding(embed_dim=8)
        qdrant_vector_access = QDrantVectorAccess(
            collection_name="sample_collection",
            embed_model=custom_model,
        )

        self.assertIs(qdrant_vector_access.embed_model, custom_model)

    def test_setup_index(self):
        """the vector store setup returns a `QdrantVectorStore`"""
        qdrant_vector_access = QDrantVectorAccess(
            collection_name="sample_collection",
            testing=True,
        )
        vector_store = qdrant_vector_access.setup_qdrant_vector_store()
        self.assertIsInstance(vector_store, QdrantVectorStore)

    def test_load_index(self):
        """loading the index returns a `VectorStoreIndex`"""
        qdrant_vector_access = QDrantVectorAccess(
            collection_name="sample_collection",
            testing=True,
        )
        index = qdrant_vector_access.load_index()
        self.assertIsInstance(index, VectorStoreIndex)
# ---------------------------------------------------------------------------
# tc_hivemind_backend/tests/unit/test_load_db_credentials.py
# (updated import and the method added to `TestCredentialLoadings`)
# ---------------------------------------------------------------------------
from tc_hivemind_backend.db.credentials import (
    load_postgres_credentials,
    load_qdrant_credentials,
)


def test_load_qdrant_creds(self):
    """
    every qdrant credential is loaded from the environment as a string

    `load_qdrant_credentials` raises before it could return a missing value,
    so checking the type is the meaningful assertion here (it also mirrors
    the checks done for the postgres credentials)
    """
    qdrant_creds = load_qdrant_credentials()

    for key in ("host", "port", "api_key"):
        with self.subTest(key=key):
            self.assertIsInstance(qdrant_creds[key], str)