generated from TogetherCrew/python-library
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from TogetherCrew/feat/add-qdrant-support
feat: Adding qdrant vector db support!
- Loading branch information
Showing
9 changed files
with
217 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,5 +5,7 @@ hivemind_back_env/* | |
.env | ||
__pycache__ | ||
*.pyc | ||
build/* | ||
tc_hivemind_backend.egg-info | ||
|
||
main.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
setup( | ||
name="tc-hivemind-backend", | ||
version="1.1.6", | ||
version="1.2.0", | ||
author="Mohammad Amin Dadgar, TogetherCrew", | ||
maintainer="Mohammad Amin Dadgar", | ||
maintainer_email="[email protected]", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import logging | ||
|
||
from qdrant_client import QdrantClient | ||
|
||
from .credentials import load_qdrant_credentials | ||
|
||
|
||
class QdrantSingleton: | ||
__instance = None | ||
|
||
def __init__(self): | ||
if QdrantSingleton.__instance is not None: | ||
raise Exception("This class is a singleton!") | ||
else: | ||
creds = load_qdrant_credentials() | ||
|
||
# if no api_key was provided | ||
if creds["api_key"] == "": | ||
self.client = QdrantClient( | ||
host=creds["host"], | ||
port=creds["port"], | ||
) | ||
else: | ||
self.client = QdrantClient( | ||
host=creds["host"], | ||
port=creds["port"], | ||
api_key=creds["api_key"], | ||
) | ||
|
||
QdrantSingleton.__instance = self | ||
|
||
@staticmethod | ||
def get_instance(): | ||
if QdrantSingleton.__instance is None: | ||
QdrantSingleton() | ||
try: | ||
_ = QdrantSingleton.__instance.client.get_collections() | ||
logging.info("QDrant Connected Successfully!") | ||
except Exception as exp: | ||
logging.error(f"QDrant not connected! exp: {exp}") | ||
|
||
return QdrantSingleton.__instance | ||
|
||
def get_client(self): | ||
return self.client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from llama_index.core import MockEmbedding | ||
from llama_index.core.base.embeddings.base import BaseEmbedding | ||
from llama_index.core.indices.vector_store import VectorStoreIndex | ||
from llama_index.vector_stores.qdrant import QdrantVectorStore | ||
from tc_hivemind_backend.db.qdrant import QdrantSingleton | ||
from tc_hivemind_backend.embeddings import CohereEmbedding | ||
|
||
|
||
class QDrantVectorAccess: | ||
def __init__(self, collection_name: str, testing: bool = False, **kwargs) -> None: | ||
""" | ||
the class to access VectorStoreIndex from qdrant vector db | ||
Paramters | ||
----------- | ||
collection_name : str | ||
the qdrant collection name | ||
testing : bool | ||
work with mock LLM and mock embedding model for testing purposes | ||
**kwargs : | ||
embed_model : BaseEmbedding | ||
an embedding model to use for all tasks defined in this class | ||
default is `CohereEmbedding` | ||
""" | ||
self.collection_name = collection_name | ||
self.embed_model: BaseEmbedding = kwargs.get("embed_model", CohereEmbedding()) | ||
|
||
if testing: | ||
self.embed_model = MockEmbedding(embed_dim=1024) | ||
|
||
def setup_qdrant_vector_store(self) -> QdrantVectorStore: | ||
client = QdrantSingleton.get_instance().client | ||
vector_store = QdrantVectorStore( | ||
client=client, | ||
collection_name=self.collection_name, | ||
) | ||
return vector_store | ||
|
||
def load_index(self, **kwargs) -> VectorStoreIndex: | ||
""" | ||
load the llama_index.VectorStoreIndex | ||
Parameters | ||
----------- | ||
**kwargs : | ||
embed_model : BaseEmbedding | ||
the embedding model to use | ||
default is the one set when initializing the class | ||
""" | ||
embed_model: BaseEmbedding = kwargs.get("embed_model", self.embed_model) | ||
vector_store = self.setup_qdrant_vector_store() | ||
index = VectorStoreIndex.from_vector_store( | ||
vector_store=vector_store, | ||
embed_model=embed_model, | ||
) | ||
return index |
34 changes: 34 additions & 0 deletions
34
tc_hivemind_backend/tests/integration/test_qdrant_vector_access.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from unittest import TestCase | ||
|
||
from llama_index.core import MockEmbedding | ||
from llama_index.core.indices.vector_store import VectorStoreIndex | ||
from llama_index.vector_stores.qdrant import QdrantVectorStore | ||
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess | ||
|
||
|
||
class TestQdrantVectorAccess(TestCase): | ||
def test_init(self): | ||
expected_collection_name = "sample_collection" | ||
qdrant_vector_access = QDrantVectorAccess( | ||
collection_name=expected_collection_name, | ||
testing=True, | ||
) | ||
|
||
self.assertEqual(qdrant_vector_access.collection_name, expected_collection_name) | ||
self.assertIsInstance(qdrant_vector_access.embed_model, MockEmbedding) | ||
|
||
def test_setup_index(self): | ||
qdrant_vector_access = QDrantVectorAccess( | ||
collection_name="sample_collection", | ||
testing=True, | ||
) | ||
vector_store = qdrant_vector_access.setup_qdrant_vector_store() | ||
self.assertIsInstance(vector_store, QdrantVectorStore) | ||
|
||
def test_load_index(self): | ||
qdrant_vector_access = QDrantVectorAccess( | ||
collection_name="sample_collection", | ||
testing=True, | ||
) | ||
index = qdrant_vector_access.load_index() | ||
self.assertIsInstance(index, VectorStoreIndex) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters