Skip to content

Commit

Permalink
Merge pull request #13 from TogetherCrew/feat/add-qdrant-support
Browse files Browse the repository at this point in the history
feat: Adding qdrant vector db support!
  • Loading branch information
cyri113 authored May 20, 2024
2 parents a7512fe + dbf7f01 commit 3c08a6e
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ hivemind_back_env/*
.env
__pycache__
*.pyc
build/*
tc_hivemind_backend.egg-info

main.ipynb
30 changes: 30 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ services:
- SENTRY_ENV=local
- CHUNK_SIZE=512
- EMBEDDING_DIM=1024
- QDRANT_HOST=qdrant
- QDRANT_PORT=6333
- QDRANT_API_KEY=
networks:
- python_service_network
volumes:
- ./coverage:/project/coverage
depends_on:
postgres:
condition: service_healthy
qdrant-healthcheck:
condition: service_healthy
postgres:
image: "ankane/pgvector"
environment:
Expand All @@ -40,6 +45,31 @@ services:
retries: 5
networks:
- python_service_network
# Qdrant vector database service for the test stack; the image is pinned to v1.9.2.
qdrant:
image: qdrant/qdrant:v1.9.2
restart: always
# Fixed container name; it matches the `qdrant` hostname used in the healthcheck URL.
container_name: qdrant
# Ports published on the host.
# NOTE(review): presumably 6333 is the HTTP API and 6334 the gRPC API — confirm against the Qdrant docs.
ports:
- 6333:6333
- 6334:6334
# NOTE(review): 6333/6334 are already published above, so listing them here looks redundant;
# 6335 is only exposed, not published — confirm whether it is needed.
expose:
- 6333
- 6334
- 6335
# NOTE(review): confirm that /qdrant_data is the directory this image actually persists to.
volumes:
- ./qdrant_data:/qdrant_data
# Sidecar container whose only job is to report Qdrant readiness: it idles forever and
# its healthcheck curls the `qdrant` service, so other services can wait on
# `qdrant-healthcheck: condition: service_healthy`.
# NOTE(review): presumably needed because the qdrant image has no curl of its own — confirm.
qdrant-healthcheck:
restart: always
image: curlimages/curl:latest
# Keep the container alive without doing any work; the healthcheck below is what matters.
entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"]
depends_on:
- qdrant
healthcheck:
# `-f` makes curl exit non-zero on an HTTP error status, which marks the check as failed.
test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"]
interval: 10s
timeout: 2s
retries: 5


networks:
python_service_network:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ cohere>=4.39, <5.0.0
pgvector
asyncpg
psycopg2-binary
llama-index-vector-stores-qdrant==0.2.8
qdrant-client==1.9.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="tc-hivemind-backend",
version="1.1.6",
version="1.2.0",
author="Mohammad Amin Dadgar, TogetherCrew",
maintainer="Mohammad Amin Dadgar",
maintainer_email="[email protected]",
Expand Down
36 changes: 36 additions & 0 deletions tc_hivemind_backend/db/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,39 @@ def load_postgres_credentials() -> dict[str, str]:
credentials["db_name"] = os.getenv("POSTGRES_DBNAME", "")

return credentials


def load_qdrant_credentials() -> dict[str, str]:
    """
    load qdrant database credentials from the environment

    `load_dotenv()` is called first, then the `QDRANT_HOST`, `QDRANT_PORT`
    and `QDRANT_API_KEY` environment variables are read.

    Returns:
    ---------
    qdrant_creds : dict[str, str]
        qdrant credentials
        a dictionary representive of
            `host` : str
            `port` : str
                kept as the raw environment string, callers needing
                a number must convert it themselves
            `api_key` : str
                may be an empty string, only an unset variable is rejected

    Raises:
    --------
    ValueError
        if any of the three environment variables is not set
    """
    load_dotenv()

    # credential key -> environment variable holding its value
    env_names = {
        "host": "QDRANT_HOST",
        "port": "QDRANT_PORT",
        "api_key": "QDRANT_API_KEY",
    }

    qdrant_creds: dict[str, str] = {}
    for key, env_name in env_names.items():
        value = os.getenv(env_name)
        # an empty value is allowed on purpose (e.g. no api key for a local
        # instance); only a missing variable is an error
        if value is None:
            raise ValueError(f"`{env_name}` is not set in env credentials!")
        qdrant_creds[key] = value

    return qdrant_creds
45 changes: 45 additions & 0 deletions tc_hivemind_backend/db/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import logging

from qdrant_client import QdrantClient

from .credentials import load_qdrant_credentials


class QdrantSingleton:
    """
    Holder of one shared `QdrantClient` for the whole process.

    Use `QdrantSingleton.get_instance()` to obtain the instance; calling the
    constructor directly after an instance exists raises an error.
    """

    __instance = None

    def __init__(self):
        """
        create the qdrant client from the environment credentials

        Raises:
        --------
        RuntimeError
            if an instance was already created
        ValueError
            if `QDRANT_PORT` cannot be parsed as an integer
        """
        if QdrantSingleton.__instance is not None:
            raise RuntimeError("This class is a singleton!")

        creds = load_qdrant_credentials()

        # credentials come from the environment as strings,
        # the client expects a numeric port
        client_kwargs = {
            "host": creds["host"],
            "port": int(creds["port"]),
        }
        # if no api_key was provided, do not pass one to the client at all
        if creds["api_key"] != "":
            client_kwargs["api_key"] = creds["api_key"]

        self.client = QdrantClient(**client_kwargs)
        QdrantSingleton.__instance = self

    @staticmethod
    def get_instance():
        """
        get the shared instance, creating it on the first call

        When the instance is created, the connection is probed once by
        listing the collections. A failing probe is only logged, the
        instance is still returned.

        Returns:
        ---------
        instance : QdrantSingleton
            the shared instance holding the qdrant client
        """
        if QdrantSingleton.__instance is None:
            QdrantSingleton()
            try:
                QdrantSingleton.__instance.client.get_collections()
            except Exception as exp:
                # best-effort probe: report the problem but do not crash here
                logging.error("QDrant not connected! exp: %s", exp)
            else:
                logging.info("QDrant Connected Successfully!")

        return QdrantSingleton.__instance

    def get_client(self):
        """
        get the qdrant client held by this instance

        Returns:
        ---------
        client : QdrantClient
            the client created when the instance was initialized
        """
        return self.client
56 changes: 56 additions & 0 deletions tc_hivemind_backend/qdrant_vector_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from llama_index.core import MockEmbedding
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from tc_hivemind_backend.db.qdrant import QdrantSingleton
from tc_hivemind_backend.embeddings import CohereEmbedding


class QDrantVectorAccess:
    """
    Access a `VectorStoreIndex` backed by a qdrant collection.
    """

    def __init__(self, collection_name: str, testing: bool = False, **kwargs) -> None:
        """
        the class to access VectorStoreIndex from qdrant vector db

        Parameters
        -----------
        collection_name : str
            the qdrant collection name
        testing : bool
            work with a mock embedding model for testing purposes
            when `True` this takes precedence over a given `embed_model`
        **kwargs :
            embed_model : BaseEmbedding
                an embedding model to use for all tasks defined in this class
                default is `CohereEmbedding`
        """
        self.collection_name = collection_name

        self.embed_model: BaseEmbedding
        if testing:
            self.embed_model = MockEmbedding(embed_dim=1024)
        elif "embed_model" in kwargs:
            self.embed_model = kwargs["embed_model"]
        else:
            # build the default model only when it is really needed, so that
            # testing or a caller-supplied model never constructs it
            self.embed_model = CohereEmbedding()

    def setup_qdrant_vector_store(self) -> QdrantVectorStore:
        """
        create the vector store for the collection of this class

        Returns
        ---------
        vector_store : QdrantVectorStore
            the vector store using the shared qdrant client
            and `self.collection_name`
        """
        client = QdrantSingleton.get_instance().client
        vector_store = QdrantVectorStore(
            client=client,
            collection_name=self.collection_name,
        )
        return vector_store

    def load_index(self, **kwargs) -> VectorStoreIndex:
        """
        load the llama_index.VectorStoreIndex

        Parameters
        -----------
        **kwargs :
            embed_model : BaseEmbedding
                the embedding model to use
                default is the one set when initializing the class

        Returns
        ---------
        index : VectorStoreIndex
            the index built on top of the qdrant vector store
        """
        embed_model: BaseEmbedding = kwargs.get("embed_model", self.embed_model)
        vector_store = self.setup_qdrant_vector_store()
        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            embed_model=embed_model,
        )
        return index
34 changes: 34 additions & 0 deletions tc_hivemind_backend/tests/integration/test_qdrant_vector_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest import TestCase

from llama_index.core import MockEmbedding
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class TestQdrantVectorAccess(TestCase):
    """Integration checks for `QDrantVectorAccess` running in testing mode."""

    def _make_access(self, name: str = "sample_collection") -> QDrantVectorAccess:
        """Build an accessor with the mock embedding model for the given collection."""
        return QDrantVectorAccess(collection_name=name, testing=True)

    def test_init(self):
        """The collection name is stored and the mock embedding model is selected."""
        expected_collection_name = "sample_collection"
        access = self._make_access(expected_collection_name)

        self.assertEqual(access.collection_name, expected_collection_name)
        self.assertIsInstance(access.embed_model, MockEmbedding)

    def test_setup_index(self):
        """Setting up the store yields a `QdrantVectorStore`."""
        store = self._make_access().setup_qdrant_vector_store()
        self.assertIsInstance(store, QdrantVectorStore)

    def test_load_index(self):
        """Loading the index yields a `VectorStoreIndex`."""
        loaded = self._make_access().load_index()
        self.assertIsInstance(loaded, VectorStoreIndex)
12 changes: 11 additions & 1 deletion tc_hivemind_backend/tests/unit/test_load_db_credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import unittest

from tc_hivemind_backend.db.credentials import load_postgres_credentials
from tc_hivemind_backend.db.credentials import (
load_postgres_credentials,
load_qdrant_credentials,
)


class TestCredentialLoadings(unittest.TestCase):
Expand All @@ -21,3 +24,10 @@ def test_postgres_envs_values(self):
self.assertIsInstance(postgres_creds["password"], str)
self.assertIsInstance(postgres_creds["user"], str)
self.assertIsInstance(postgres_creds["port"], str)

def test_load_qdrant_creds(self):
    """Each qdrant credential entry must be present after loading."""
    loaded = load_qdrant_credentials()

    # same order as the individual checks: host, port, then api_key
    for key in ("host", "port", "api_key"):
        self.assertIsNotNone(loaded[key])

0 comments on commit 3c08a6e

Please sign in to comment.