diff --git a/.github/workflows/increment_version_dev.yaml b/.github/workflows/increment_version_dev.yaml index b5303b077..e1734c333 100644 --- a/.github/workflows/increment_version_dev.yaml +++ b/.github/workflows/increment_version_dev.yaml @@ -24,11 +24,11 @@ jobs: - name: Increment versions in pyproject.toml run: | set -x - + echo "Incrementing versions..." find . -name "pyproject.toml" | while read -r pyproject; do echo "Processing $pyproject" - + # Extract current version CURRENT_VERSION=$(python -c " import tomlkit @@ -43,31 +43,30 @@ jobs: print(f'Error reading version from {pyproject}: {e}', end='') exit(1) ") - + echo "Extracted CURRENT_VERSION: $CURRENT_VERSION" - + if [ -z "$CURRENT_VERSION" ]; then echo "Error: Could not extract the current version from $pyproject" cat "$pyproject" continue fi - + # Increment version BASE_VERSION=$(echo "$CURRENT_VERSION" | sed -E 's/(.*)-dev.*/\1/') DEV_PART=$(echo "$CURRENT_VERSION" | grep -oE 'dev[0-9]+$' | grep -oE '[0-9]+') - + + # Fallback if no DEV_PART is found if [ -z "$DEV_PART" ]; then DEV_PART=0 fi - + NEW_DEV_PART=$((DEV_PART + 1)) NEW_VERSION="${BASE_VERSION}-dev${NEW_DEV_PART}" - + echo "Updating version from $CURRENT_VERSION to $NEW_VERSION" done - - - name: Commit changes run: | git config user.name "github-actions[bot]" diff --git a/.github/workflows/sequence_publish.yaml b/.github/workflows/sequence_publish.yaml index 23a88d990..f219023f4 100644 --- a/.github/workflows/sequence_publish.yaml +++ b/.github/workflows/sequence_publish.yaml @@ -93,7 +93,7 @@ jobs: - uses: actions/checkout@v4 - name: Wait for swarmauri - run: sleep 60 + run: sleep 120 - name: Set up Python 3.12 uses: actions/setup-python@v5 diff --git a/.github/workflows/test_changed_files.yaml b/.github/workflows/test_changed_files.yaml index 2f28a48fd..fbe485428 100644 --- a/.github/workflows/test_changed_files.yaml +++ b/.github/workflows/test_changed_files.yaml @@ -126,7 +126,7 @@ jobs: - name: Install package dependencies run: | cd pkgs/${{ matrix.package_tests.package }} - poetry install --no-cache -vv + poetry install --no-cache --all-extras -vv - name: Run all tests for the package run: | diff --git a/pkgs/community/pyproject.toml b/pkgs/community/pyproject.toml index a82494195..256adb6da 100644 --- a/pkgs/community/pyproject.toml +++ b/pkgs/community/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-community" -version = "0.5.2.dev20" +version = "0.5.3.dev5" description = "This repository includes Swarmauri community components." 
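As a reference for the bump logic the workflow above walks through, here is a minimal standalone sketch, assuming the `X.Y.Z.devN` scheme used by every version field in this PR. Note that the committed script's `sed` pattern matches a literal `-dev`, which would not strip the `.devN` suffixes seen here, and that the loop only echoes the result; this sketch uses `.dev` and writes the bump back:

```python
# Hedged sketch, not the committed workflow script: bump "X.Y.Z.devN" in a pyproject.toml.
import re
import tomlkit

def bump_dev_version(pyproject_path: str) -> str:
    with open(pyproject_path) as f:
        doc = tomlkit.parse(f.read())
    current = str(doc["tool"]["poetry"]["version"])   # e.g. "0.5.3.dev4"
    m = re.fullmatch(r"(?P<base>.+?)\.dev(?P<n>\d+)", current)
    if m:                                             # existing ".devN" suffix
        base, n = m.group("base"), int(m.group("n"))
    else:                                             # fallback, mirroring DEV_PART=0 above
        base, n = current, 0
    new_version = f"{base}.dev{n + 1}"                # e.g. "0.5.3.dev5"
    doc["tool"]["poetry"]["version"] = new_version
    with open(pyproject_path, "w") as f:
        f.write(tomlkit.dumps(doc))                   # tomlkit preserves file formatting
    return new_version
```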
authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -15,48 +15,65 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.10,<3.13" -captcha = "*" -chromadb = "*" -duckdb = "*" -folium = "*" -gensim = "*" -#google-generativeai = "*" -gradio = "*" -leptonai = "0.22.0" -neo4j = "*" -nltk = "*" -#openai = "^1.52.0" -pandas = "*" -psutil = "*" -pygithub = "*" -python-dotenv = "*" -qrcode = "*" -redis = "^4.0" -scikit-learn="^1.4.2" -swarmauri = "==0.5.2" -textstat = "*" -transformers = ">=4.45.0" -typing_extensions = "*" -tiktoken = "*" -pymupdf = "*" -annoy = "*" -qdrant_client = "*" -weaviate = "*" -pinecone-client = { version = "*", extras = ["grpc"] } -PyPDF2 = "*" -pypdftk = "*" -weaviate-client = "*" -protobuf = "^3.20.0" -# Pacmap requires specific version of numba -#numba = ">=0.59.0" -#pacmap = "==0.7.3" +captcha = "^0.6.0" +chromadb = { version = "^0.5.17", optional = true } +duckdb = { version = "^1.1.1", optional = true } +folium = { version = "^0.18.0", optional = true } +gensim = { version = "^4.3.3", optional = true } +gradio = { version = "^5.4.0", optional = true } +leptonai = { version = "^0.22.0", optional = true } +neo4j = { version = "^5.25.0", optional = true } +nltk = { version = "^3.9.1", optional = true } +pandas = "^2.2.3" +psutil = { version = "^6.1.0", optional = true } +pygithub = { version = "^2.4.0", optional = true } +qrcode = { version = "^8.0", optional = true } +redis = { version = "^4.0", optional = true } +swarmauri = "==0.5.3.dev5" +textstat = { version = "^0.7.4", optional = true } +transformers = { version = ">=4.45.0", optional = true } +typing_extensions = "^4.12.2" +tiktoken = { version = "^0.8.0", optional = true } +PyMuPDF = { version = "^1.24.12", optional = true } +annoy = { version = "^1.17.3", optional = true } +qdrant-client = { version = "^1.12.0", optional = true } +pinecone-client = { version = "^5.0.1", optional = true, extras = ["grpc"] } +pypdf = { version = "^5.0.1", optional = true } +pypdftk = { version = "^0.5", optional = true } +weaviate-client = { version = "^4.9.2", optional = true } +#protobuf = { version = "^3.20.0", optional = true } + +[tool.poetry.extras] +# Grouped optional dependencies +nlp = ["nltk", "gensim", "textstat"] +ml_toolkits = ["transformers", "annoy"] +visualization = ["folium"] +storage = ["redis", "duckdb", "neo4j", "chromadb", "qdrant-client", "weaviate", "pinecone-client"] +document_processing = ["pypdf", "PyMuPDF", "pypdftk"] +cloud_integration = ["psutil", "qrcode", "pygithub"] +gradio = ["gradio"] +model_clients = ["leptonai"] +tiktoken = ["tiktoken"] +# Full installation +full = [ + "nltk", "gensim", "textstat", + "transformers", "annoy", + "folium", + "redis", "duckdb", "neo4j", "chromadb", "qdrant_client", "weaviate", "pinecone-client", + "pypdf", "PyMuPDF", "pypdftk", + "psutil", "qrcode", "pygithub", + "gradio", + "leptonai", + "tiktoken" +] [tool.poetry.dev-dependencies] -flake8 = "^7.0" # Add flake8 as a development dependency -pytest = "^8.0" # Ensure pytest is also added if you run tests +flake8 = "^7.0" +pytest = "^8.0" pytest-asyncio = ">=0.24.0" pytest-xdist = "^3.6.1" +python-dotenv = "*" [build-system] requires = ["poetry-core>=1.0.0"] @@ -70,12 +87,10 @@ markers = [ "unit: Unit tests", "integration: Integration tests", "acceptance: Acceptance tests", - "experimental: Experimental tests", + "experimental: Experimental tests" ] - log_cli = true log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" log_cli_date_format = "%Y-%m-%d %H:%M:%S" - 
asyncio_default_fixture_loop_scope = "function" diff --git a/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py b/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py index a44ceb5c2..941fc5618 100644 --- a/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py @@ -1,3 +1,10 @@ -from swarmauri_community.document_stores.concrete.RedisDocumentStore import ( - RedisDocumentStore, -) +from swarmauri.utils._lazy_import import _lazy_import + +document_stores_files = [ + ("swarmauri_community.document_stores.concrete.RedisDocumentStore", "RedisDocumentStore"), +] + +for module_name, class_name in document_stores_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in document_stores_files] diff --git a/pkgs/community/swarmauri_community/embeddings/__init__.py b/pkgs/community/swarmauri_community/embeddings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/community/swarmauri_community/embeddings/base/__init__.py b/pkgs/community/swarmauri_community/embeddings/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/Doc2VecEmbedding.py b/pkgs/community/swarmauri_community/embeddings/concrete/Doc2VecEmbedding.py similarity index 100% rename from pkgs/swarmauri/swarmauri/embeddings/concrete/Doc2VecEmbedding.py rename to pkgs/community/swarmauri_community/embeddings/concrete/Doc2VecEmbedding.py diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/MlmEmbedding.py b/pkgs/community/swarmauri_community/embeddings/concrete/MlmEmbedding.py similarity index 100% rename from pkgs/swarmauri/swarmauri/embeddings/concrete/MlmEmbedding.py rename to pkgs/community/swarmauri_community/embeddings/concrete/MlmEmbedding.py diff --git a/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py b/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py new file mode 100644 index 000000000..7bcc482d7 --- /dev/null +++ b/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py @@ -0,0 +1,12 @@ +from swarmauri.utils._lazy_import import _lazy_import + + +embeddings_files = [ + ("swarmauri_community.embeddings.concrete.Doc2VecEmbedding", "Doc2VecEmbedding"), + ("swarmauri_community.embeddings.concrete.MlmEmbedding", "MlmEmbedding"), +] + +for module_name, class_name in embeddings_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in embeddings_files] diff --git a/pkgs/community/swarmauri_community/llms/concrete/__init__.py b/pkgs/community/swarmauri_community/llms/concrete/__init__.py index a8fa703c0..5c2266ce5 100644 --- a/pkgs/community/swarmauri_community/llms/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/llms/concrete/__init__.py @@ -1,4 +1,12 @@ -from swarmauri_community.llms.concrete.LeptonAIImgGenModel import LeptonAIImgGenModel -from swarmauri_community.llms.concrete.LeptonAIModel import LeptonAIModel +from swarmauri.utils._lazy_import import _lazy_import -__all__ = ["LeptonAIImgGenModel", "LeptonAIModel"] +llms_files = [ + ("swarmauri_community.llms.concrete.LeptonAIImgGenModel", "LeptonAIImgGenModel"), + ("swarmauri_community.llms.concrete.LeptonAIModel", "LeptonAIModel"), + ("swarmauri_community.llms.concrete.PytesseractImg2TextModel", "PytesseractImg2TextModel"), +] + +for module_name, class_name in llms_files: +
globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in llms_files] diff --git a/pkgs/community/swarmauri_community/measurements/concrete/__init__.py b/pkgs/community/swarmauri_community/measurements/concrete/__init__.py index 276716315..a748c6a0e 100644 --- a/pkgs/community/swarmauri_community/measurements/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/measurements/concrete/__init__.py @@ -1,6 +1,11 @@ -from swarmauri_community.measurements.concrete.MutualInformationMeasurement import ( - MutualInformationMeasurement, -) -from swarmauri_community.measurements.concrete.TokenCountEstimatorMeasurement import ( - TokenCountEstimatorMeasurement, -) +from swarmauri.utils._lazy_import import _lazy_import + +measurement_files = [ + ("swarmauri_community.measurements.concrete.MutualInformationMeasurement", "MutualInformationMeasurement"), + ("swarmauri_community.measurements.concrete.TokenCountEstimatorMeasurement", "TokenCountEstimatorMeasurement"), +] + +for module_name, class_name in measurement_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in measurement_files] diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/BERTEmbeddingParser.py b/pkgs/community/swarmauri_community/parsers/concrete/BERTEmbeddingParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/BERTEmbeddingParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/BERTEmbeddingParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/EntityRecognitionParser.py b/pkgs/community/swarmauri_community/parsers/concrete/EntityRecognitionParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/EntityRecognitionParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/EntityRecognitionParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobNounParser.py b/pkgs/community/swarmauri_community/parsers/concrete/TextBlobNounParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobNounParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/TextBlobNounParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobSentenceParser.py b/pkgs/community/swarmauri_community/parsers/concrete/TextBlobSentenceParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobSentenceParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/TextBlobSentenceParser.py diff --git a/pkgs/community/swarmauri_community/parsers/concrete/__init__.py b/pkgs/community/swarmauri_community/parsers/concrete/__init__.py index b5d547c4e..3808b4733 100644 --- a/pkgs/community/swarmauri_community/parsers/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/parsers/concrete/__init__.py @@ -1,3 +1,16 @@ -from swarmauri_community.parsers.concrete.FitzPdfParser import PDFtoTextParser -from swarmauri_community.parsers.concrete.PyPDF2Parser import PyPDF2Parser -from swarmauri_community.parsers.concrete.PyPDFTKParser import PyPDFTKParser +from swarmauri.utils._lazy_import import _lazy_import + +parsers_files = [ + ("swarmauri_community.parsers.concrete.BERTEmbeddingParser", "BERTEmbeddingParser"), + ("swarmauri_community.parsers.concrete.EntityRecognitionParser", "EntityRecognitionParser"), + ("swarmauri_community.parsers.concrete.FitzPdfParser", "FitzPdfParser"), + ("swarmauri_community.parsers.concrete.PyPDF2Parser", 
"PyPDF2Parser"), + ("swarmauri_community.parsers.concrete.PyPDFTKParser", "PyPDFTKParser"), + ("swarmauri_community.parsers.concrete.TextBlobNounParser", "TextBlobNounParser"), + ("swarmauri_community.parsers.concrete.TextBlobSentenceParser", "TextBlobSentenceParser"), +] + +for module_name, class_name in parsers_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in parsers_files] diff --git a/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py b/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py index 000e57ffe..ec089ab66 100644 --- a/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py @@ -1,5 +1,10 @@ -# -*- coding: utf-8 -*- +from swarmauri.utils._lazy_import import _lazy_import -from swarmauri_community.retrievers.concrete.RedisDocumentRetriever import ( - RedisDocumentRetriever, -) +retriever_files = [ + ("swarmauri_community.retrievers.concrete.RedisDocumentRetriever", "RedisDocumentRetriever"), +] + +for module_name, class_name in retriever_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in retriever_files] diff --git a/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py b/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py index 129ad1dc9..6aca27ddc 100644 --- a/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py @@ -1 +1,10 @@ -from swarmauri_community.toolkits.concrete.GithubToolkit import * +from swarmauri.utils._lazy_import import _lazy_import + +toolkits_files = [ + ("swarmauri_community.toolkits.concrete.GithubToolkit", "GithubToolkit"), +] + +for module_name, class_name in toolkits_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in toolkits_files] diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/TextLengthTool.py b/pkgs/community/swarmauri_community/tools/concrete/TextLengthTool.py similarity index 100% rename from pkgs/swarmauri/swarmauri/tools/concrete/TextLengthTool.py rename to pkgs/community/swarmauri_community/tools/concrete/TextLengthTool.py diff --git a/pkgs/community/swarmauri_community/tools/concrete/__init__.py b/pkgs/community/swarmauri_community/tools/concrete/__init__.py index 9db51c17f..fe50a5544 100644 --- a/pkgs/community/swarmauri_community/tools/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/tools/concrete/__init__.py @@ -1,31 +1,33 @@ -from swarmauri_community.tools.concrete.CaptchaGeneratorTool import CaptchaGeneratorTool -from swarmauri_community.tools.concrete.DaleChallReadabilityTool import ( - DaleChallReadabilityTool, -) -from swarmauri_community.tools.concrete.DownloadPdfTool import DownloadPDFTool -from swarmauri_community.tools.concrete.EntityRecognitionTool import ( - EntityRecognitionTool, -) -from swarmauri_community.tools.concrete.FoliumTool import FoliumTool -from swarmauri_community.tools.concrete.GithubBranchTool import GithubBranchTool -from swarmauri_community.tools.concrete.GithubCommitTool import GithubCommitTool -from swarmauri_community.tools.concrete.GithubIssueTool import GithubIssueTool -from swarmauri_community.tools.concrete.GithubPRTool import GithubPRTool -from swarmauri_community.tools.concrete.GithubRepoTool import GithubRepoTool -from swarmauri_community.tools.concrete.GithubTool import 
GithubTool -from swarmauri_community.tools.concrete.GmailReadTool import GmailReadTool -from swarmauri_community.tools.concrete.GmailSendTool import GmailSendTool -from swarmauri_community.tools.concrete.LexicalDensityTool import LexicalDensityTool -from swarmauri_community.tools.concrete.PsutilTool import PsutilTool -from swarmauri_community.tools.concrete.QrCodeGeneratorTool import QrCodeGeneratorTool -from swarmauri_community.tools.concrete.SentenceComplexityTool import ( - SentenceComplexityTool, -) -from swarmauri_community.tools.concrete.SentimentAnalysisTool import ( - SentimentAnalysisTool, -) -from swarmauri_community.tools.concrete.SMOGIndexTool import SMOGIndexTool -from swarmauri_community.tools.concrete.WebScrapingTool import WebScrapingTool -from swarmauri_community.tools.concrete.ZapierHookTool import ZapierHookTool +from swarmauri.utils._lazy_import import _lazy_import -# from swarmauri_community.tools.concrete.PaCMAPTool import PaCMAPTool +tool_files = [ + ("swarmauri_community.tools.concrete.CaptchaGeneratorTool", "CaptchaGeneratorTool"), + ("swarmauri_community.tools.concrete.DaleChallReadabilityTool", "DaleChallReadabilityTool"), + ("swarmauri_community.tools.concrete.DownloadPdfTool", "DownloadPDFTool"), + ("swarmauri_community.tools.concrete.EntityRecognitionTool", "EntityRecognitionTool"), + ("swarmauri_community.tools.concrete.FoliumTool", "FoliumTool"), + ("swarmauri_community.tools.concrete.GithubBranchTool", "GithubBranchTool"), + ("swarmauri_community.tools.concrete.GithubCommitTool", "GithubCommitTool"), + ("swarmauri_community.tools.concrete.GithubIssueTool", "GithubIssueTool"), + ("swarmauri_community.tools.concrete.GithubPRTool", "GithubPRTool"), + ("swarmauri_community.tools.concrete.GithubRepoTool", "GithubRepoTool"), + ("swarmauri_community.tools.concrete.GithubTool", "GithubTool"), + ("swarmauri_community.tools.concrete.GmailReadTool", "GmailReadTool"), + ("swarmauri_community.tools.concrete.GmailSendTool", "GmailSendTool"), + ("swarmauri_community.tools.concrete.LexicalDensityTool", "LexicalDensityTool"), + ("swarmauri_community.tools.concrete.PsutilTool", "PsutilTool"), + ("swarmauri_community.tools.concrete.QrCodeGeneratorTool", "QrCodeGeneratorTool"), + ("swarmauri_community.tools.concrete.SentenceComplexityTool", "SentenceComplexityTool"), + ("swarmauri_community.tools.concrete.SentimentAnalysisTool", "SentimentAnalysisTool"), + ("swarmauri_community.tools.concrete.SMOGIndexTool", "SMOGIndexTool"), + ("swarmauri_community.tools.concrete.WebScrapingTool", "WebScrapingTool"), + ("swarmauri_community.tools.concrete.ZapierHookTool", "ZapierHookTool"), + # ("swarmauri_community.tools.concrete.PaCMAPTool", "PaCMAPTool"), +] + +# Lazy loading of tools, storing them in variables +for module_name, class_name in tool_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded tools to __all__ +__all__ = [class_name for _, class_name in tool_files] diff --git a/pkgs/community/swarmauri_community/vector_stores/base/__init__.py b/pkgs/community/swarmauri_community/vector_stores/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/community/swarmauri_community/vector_stores/AnnoyVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/AnnoyVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/AnnoyVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/AnnoyVectorStore.py diff --git 
a/pkgs/community/swarmauri_community/vector_stores/CloudQdrantVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudQdrantVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/CloudQdrantVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/CloudQdrantVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py similarity index 97% rename from pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py index 9e55fc0ca..0e1ffced7 100644 --- a/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py @@ -1,218 +1,218 @@ -from typing import List, Union, Literal, Optional -from pydantic import BaseModel, PrivateAttr -import uuid as ud -import weaviate -from weaviate.classes.init import Auth -from weaviate.util import generate_uuid5 -from weaviate.classes.query import MetadataQuery - -from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding -from swarmauri.vectors.concrete.Vector import Vector - -from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase -from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin -from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin -from swarmauri.vector_stores.base.VectorStoreCloudMixin import VectorStoreCloudMixin - - -class CloudWeaviateVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, VectorStoreCloudMixin): - type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore" - - - # Private attributes - _client: Optional[weaviate.Client] = PrivateAttr(default=None) - _embedder: Doc2VecEmbedding = PrivateAttr(default=None) - _namespace_uuid: ud.UUID = PrivateAttr(default_factory=ud.uuid4) - - def __init__(self, **data): - super().__init__(**data) - - # Initialize the vectorizer and Weaviate client - self._embedder = Doc2VecEmbedding(vector_size=self.vector_size) - # self._initialize_client() - - def connect(self, **kwargs): - """ - Initialize the Weaviate client. - """ - if self._client is None: - self._client = weaviate.connect_to_weaviate_cloud( - cluster_url=self.url, - auth_credentials=Auth.api_key(self.api_key), - headers=kwargs.get("headers", {}) - ) - - def disconnect(self) -> None: - """ - Disconnects from the Qdrant cloud vector store. - """ - if self.client is not None: - self.client = None - - def add_document(self, document: Document) -> None: - """ - Add a single document to the vector store. 
- - :param document: Document to add - """ - try: - collection = self._client.collections.get(self.collection_name) - - # Generate or use existing embedding - embedding = document.embedding or self._embedder.fit_transform([document.content])[0] - - data_object = { - "content": document.content, - "metadata": document.metadata, - } - - # Generate UUID for document - uuid = ( - str(ud.uuid5(self._namespace_uuid, document.id)) - if document.id - else generate_uuid5(data_object) - ) - - collection.data.insert( - properties=data_object, - vector=embedding.value, - uuid=uuid, - ) - - print(f"Document '{document.id}' added to Weaviate.") - except Exception as e: - print(f"Error adding document '{document.id}': {e}") - raise - - def add_documents(self, documents: List[Document]) -> None: - """ - Add multiple documents to the vector store. - - :param documents: List of documents to add - """ - try: - for document in documents: - self.add_document(document) - - print(f"{len(documents)} documents added to Weaviate.") - except Exception as e: - print(f"Error adding documents: {e}") - raise - - def get_document(self, id: str) -> Union[Document, None]: - """ - Retrieve a document by its ID. - - :param id: Document ID - :return: Document object or None if not found - """ - try: - collection = self._client.collections.get(self.collection_name) - - result = collection.query.fetch_object_by_id(ud.uuid5(self._namespace_uuid, id)) - - if result: - return Document( - id=id, - content=result.properties["content"], - metadata=result.properties["metadata"], - ) - return None - except Exception as e: - print(f"Error retrieving document '{id}': {e}") - return None - - def get_all_documents(self) -> List[Document]: - """ - Retrieve all documents from the vector store. - - :return: List of Document objects - """ - try: - collection = self._client.collections.get(self.collection_name) - # return collection - documents = [ - Document( - content=item.properties["content"], - metadata=item.properties["metadata"], - embedding=Vector(value=list(item.vector.values())[0]), - ) - for item in collection.iterator(include_vector=True) - ] - return documents - except Exception as e: - print(f"Error retrieving all documents: {e}") - return [] - - def delete_document(self, id: str) -> None: - """ - Delete a document by its ID. - - :param id: Document ID - """ - try: - collection = self._client.collections.get(self.collection_name) - collection.data.delete_by_id(ud.uuid5(self._namespace_uuid, id)) - print(f"Document '{id}' has been deleted from Weaviate.") - except Exception as e: - print(f"Error deleting document '{id}': {e}") - raise - - def update_document(self, id: str, document: Document) -> None: - """ - Update an existing document. - - :param id: Document ID - :param updated_document: Document object with updated data - """ - self.delete_document(id) - self.add_document(document) - - def retrieve(self, query: str, top_k: int = 5) -> List[Document]: - """ - Retrieve the top_k most relevant documents based on the given query. 
- - :param query: Query string - :param top_k: Number of top similar documents to retrieve - :return: List of Document objects - """ - try: - collection = self._client.collections.get(self.collection_name) - query_vector = self._embedder.infer_vector(query) - response = collection.query.near_vector( - near_vector=query_vector.value, - limit=top_k, - return_metadata=MetadataQuery(distance=True), - ) - - documents = [ - Document( - # id=res.id, - content=res.properties["content"], - metadata=res.properties["metadata"], - ) - for res in response.objects - ] - return documents - except Exception as e: - print(f"Error retrieving documents for query '{query}': {e}") - return [] - - def close(self): - """ - Close the connection to the Weaviate server. - """ - if self._client: - self._client.close() - - def model_dump_json(self, *args, **kwargs) -> str: - # Call the disconnect method before serialization - self.disconnect() - - # Now proceed with the usual JSON serialization - return super().model_dump_json(*args, **kwargs) - - - def __del__(self): +from typing import List, Union, Literal, Optional +from pydantic import BaseModel, PrivateAttr +import uuid as ud +import weaviate +from weaviate.classes.init import Auth +from weaviate.util import generate_uuid5 +from weaviate.classes.query import MetadataQuery + +from swarmauri.documents.concrete.Document import Document +from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri.vectors.concrete.Vector import Vector + +from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase +from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin +from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin +from swarmauri.vector_stores.base.VectorStoreCloudMixin import VectorStoreCloudMixin + + +class CloudWeaviateVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, VectorStoreCloudMixin): + type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore" + + + # Private attributes + _client: Optional[weaviate.Client] = PrivateAttr(default=None) + _embedder: Doc2VecEmbedding = PrivateAttr(default=None) + _namespace_uuid: ud.UUID = PrivateAttr(default_factory=ud.uuid4) + + def __init__(self, **data): + super().__init__(**data) + + # Initialize the vectorizer and Weaviate client + self._embedder = Doc2VecEmbedding(vector_size=self.vector_size) + # self._initialize_client() + + def connect(self, **kwargs): + """ + Initialize the Weaviate client. + """ + if self._client is None: + self._client = weaviate.connect_to_weaviate_cloud( + cluster_url=self.url, + auth_credentials=Auth.api_key(self.api_key), + headers=kwargs.get("headers", {}) + ) + + def disconnect(self) -> None: + """ + Disconnects from the Weaviate cloud vector store. + """ + if self._client is not None: + self._client = None + + def add_document(self, document: Document) -> None: + """ + Add a single document to the vector store.
+ + :param document: Document to add + """ + try: + collection = self._client.collections.get(self.collection_name) + + # Generate or use existing embedding + embedding = document.embedding or self._embedder.fit_transform([document.content])[0] + + data_object = { + "content": document.content, + "metadata": document.metadata, + } + + # Generate UUID for document + uuid = ( + str(ud.uuid5(self._namespace_uuid, document.id)) + if document.id + else generate_uuid5(data_object) + ) + + collection.data.insert( + properties=data_object, + vector=embedding.value, + uuid=uuid, + ) + + print(f"Document '{document.id}' added to Weaviate.") + except Exception as e: + print(f"Error adding document '{document.id}': {e}") + raise + + def add_documents(self, documents: List[Document]) -> None: + """ + Add multiple documents to the vector store. + + :param documents: List of documents to add + """ + try: + for document in documents: + self.add_document(document) + + print(f"{len(documents)} documents added to Weaviate.") + except Exception as e: + print(f"Error adding documents: {e}") + raise + + def get_document(self, id: str) -> Union[Document, None]: + """ + Retrieve a document by its ID. + + :param id: Document ID + :return: Document object or None if not found + """ + try: + collection = self._client.collections.get(self.collection_name) + + result = collection.query.fetch_object_by_id(ud.uuid5(self._namespace_uuid, id)) + + if result: + return Document( + id=id, + content=result.properties["content"], + metadata=result.properties["metadata"], + ) + return None + except Exception as e: + print(f"Error retrieving document '{id}': {e}") + return None + + def get_all_documents(self) -> List[Document]: + """ + Retrieve all documents from the vector store. + + :return: List of Document objects + """ + try: + collection = self._client.collections.get(self.collection_name) + # return collection + documents = [ + Document( + content=item.properties["content"], + metadata=item.properties["metadata"], + embedding=Vector(value=list(item.vector.values())[0]), + ) + for item in collection.iterator(include_vector=True) + ] + return documents + except Exception as e: + print(f"Error retrieving all documents: {e}") + return [] + + def delete_document(self, id: str) -> None: + """ + Delete a document by its ID. + + :param id: Document ID + """ + try: + collection = self._client.collections.get(self.collection_name) + collection.data.delete_by_id(ud.uuid5(self._namespace_uuid, id)) + print(f"Document '{id}' has been deleted from Weaviate.") + except Exception as e: + print(f"Error deleting document '{id}': {e}") + raise + + def update_document(self, id: str, document: Document) -> None: + """ + Update an existing document. + + :param id: Document ID + :param document: Document object with updated data + """ + self.delete_document(id) + self.add_document(document) + + def retrieve(self, query: str, top_k: int = 5) -> List[Document]: + """ + Retrieve the top_k most relevant documents based on the given query.
+ + :param query: Query string + :param top_k: Number of top similar documents to retrieve + :return: List of Document objects + """ + try: + collection = self._client.collections.get(self.collection_name) + query_vector = self._embedder.infer_vector(query) + response = collection.query.near_vector( + near_vector=query_vector.value, + limit=top_k, + return_metadata=MetadataQuery(distance=True), + ) + + documents = [ + Document( + # id=res.id, + content=res.properties["content"], + metadata=res.properties["metadata"], + ) + for res in response.objects + ] + return documents + except Exception as e: + print(f"Error retrieving documents for query '{query}': {e}") + return [] + + def close(self): + """ + Close the connection to the Weaviate server. + """ + if self._client: + self._client.close() + + def model_dump_json(self, *args, **kwargs) -> str: + # Call the disconnect method before serialization + self.disconnect() + + # Now proceed with the usual JSON serialization + return super().model_dump_json(*args, **kwargs) + + + def __del__(self): self.close() \ No newline at end of file diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py similarity index 96% rename from pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py index cc4bace96..8aca2ee07 100644 --- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py @@ -1,8 +1,7 @@ from typing import List, Union, Literal -from pydantic import PrivateAttr from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import ( diff --git a/pkgs/community/swarmauri_community/vector_stores/DuckDBVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/DuckDBVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/DuckDBVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/DuckDBVectorStore.py diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py similarity index 97% rename from pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py index ea5902602..21f45ec35 100644 --- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py @@ -1,6 +1,6 @@ from typing import List, Union, Literal from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.MlmEmbedding import MlmEmbedding +from swarmauri_community.embeddings.concrete.MlmEmbedding import MlmEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py 
b/pkgs/community/swarmauri_community/vector_stores/concrete/Neo4jVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/Neo4jVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/PersistentChromaDBVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentChromaDBVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/PersistentChromaDBVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PersistentChromaDBVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/PersistentQdrantVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentQdrantVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/PersistentQdrantVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PersistentQdrantVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/PineconeVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PineconeVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/PineconeVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PineconeVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/RedisVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/RedisVectorStore.py similarity index 100% rename from pkgs/community/swarmauri_community/vector_stores/RedisVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/RedisVectorStore.py diff --git a/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py b/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py new file mode 100644 index 000000000..f920d22b4 --- /dev/null +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py @@ -0,0 +1,20 @@ +from swarmauri.utils._lazy_import import _lazy_import + +vector_store_files = [ + ("swarmauri_community.vector_stores.concrete.AnnoyVectorStore", "AnnoyVectorStore"), + ("swarmauri_community.vector_stores.concrete.CloudQdrantVectorStore", "CloudQdrantVectorStore"), + ("swarmauri_community.vector_stores.concrete.CloudWeaviateVectorStore", "CloudWeaviateVectorStore"), + ("swarmauri_community.vector_stores.concrete.Doc2VecVectorStore", "Doc2VecVectorStore"), + ("swarmauri_community.vector_stores.concrete.DuckDBVectorStore", "DuckDBVectorStore"), + ("swarmauri_community.vector_stores.concrete.MlmVectorStore", "MlmVectorStore"), + ("swarmauri_community.vector_stores.concrete.Neo4jVectorStore", "Neo4jVectorStore"), + ("swarmauri_community.vector_stores.concrete.PersistentChromaDBVectorStore", "PersistentChromaDBVectorStore"), + ("swarmauri_community.vector_stores.concrete.PersistentQdrantVectorStore", "PersistentQdrantVectorStore"), + ("swarmauri_community.vector_stores.concrete.PineconeVectorStore", "PineconeVectorStore"), + ("swarmauri_community.vector_stores.concrete.RedisVectorStore", "RedisVectorStore"), +] + +for module_name, class_name in vector_store_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in vector_store_files] diff --git a/pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py 
b/pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py similarity index 87% rename from pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py rename to pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py index c0dbb3a37..7f3afc447 100644 --- a/pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py +++ b/pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py b/pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py similarity index 87% rename from pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py rename to pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py index 6962bb802..c015aeac3 100644 --- a/pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py +++ b/pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.embeddings.concrete.MlmEmbedding import MlmEmbedding +from swarmauri_community.embeddings.concrete.MlmEmbedding import MlmEmbedding @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py b/pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py similarity index 92% rename from pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py rename to pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py index 6aa6bec95..e5f8a550c 100644 --- a/pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py +++ b/pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.parsers.concrete.TextBlobNounParser import TextBlobNounParser as Parser +from swarmauri_community.parsers.concrete.TextBlobNounParser import TextBlobNounParser as Parser def setup_module(module): diff --git a/pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py b/pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py similarity index 85% rename from pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py rename to pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py index a84023b2f..36c347906 100644 --- a/pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py +++ b/pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.parsers.concrete.TextBlobSentenceParser import TextBlobSentenceParser as Parser +from swarmauri_community.parsers.concrete.TextBlobSentenceParser import TextBlobSentenceParser as Parser @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py b/pkgs/community/tests/unit/tools/TextLength_unit_test.py similarity index 96% rename from pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py rename to pkgs/community/tests/unit/tools/TextLength_unit_test.py index 7bc0c94a1..72d3ff6bc 100644 --- a/pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py +++ b/pkgs/community/tests/unit/tools/TextLength_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.tools.concrete import TextLengthTool as Tool +from swarmauri_community.tools.concrete import TextLengthTool as Tool @pytest.mark.unit def test_ubc_resource(): 
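The test moves above change only the import root, from `swarmauri.*` to `swarmauri_community.*`. A minimal sketch of the resulting pattern, in the style of the `test_ubc_resource` cases being renamed (the asserted `resource` value is an assumption, not shown in this diff):

```python
import pytest

# Deep import path, as used by the renamed test files above. The lazy package
# __init__ above also exposes it via
#   from swarmauri_community.parsers.concrete import TextBlobNounParser
# but that name resolves to None (with a printed warning) if the backing
# optional dependency is not installed.
from swarmauri_community.parsers.concrete.TextBlobNounParser import TextBlobNounParser as Parser


@pytest.mark.unit
def test_ubc_resource():
    parser = Parser()
    assert parser.resource == "Parser"  # assumed resource name for parser components
```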
diff --git a/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py index 9f2ce8320..cee7afddd 100644 --- a/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.AnnoyVectorStore import AnnoyVectorStore +from swarmauri_community.vector_stores.concrete.AnnoyVectorStore import AnnoyVectorStore # Fixture for creating an AnnoyVectorStore instance diff --git a/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py index 2b9028c72..26ee25841 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.CloudQdrantVectorStore import ( +from swarmauri_community.vector_stores.concrete.CloudQdrantVectorStore import ( CloudQdrantVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py index 051d6aa6b..9bad53191 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.CloudWeaviateVectorStore import ( +from swarmauri_community.vector_stores.concrete.CloudWeaviateVectorStore import ( CloudWeaviateVectorStore, ) from dotenv import load_dotenv diff --git a/pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py b/pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py similarity index 95% rename from pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py rename to pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py index 7afcd7097..497e8a45f 100644 --- a/pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore +from swarmauri_community.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore @pytest.mark.unit diff --git a/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py b/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py index 28bd33080..0b247ccd8 100644 --- a/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py @@ -2,7 +2,7 @@ import os import json from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.DuckDBVectorStore import DuckDBVectorStore +from swarmauri_community.vector_stores.concrete.DuckDBVectorStore import DuckDBVectorStore @pytest.fixture(params=[":memory:", "test_db.db"]) diff --git a/pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py 
b/pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py similarity index 90% rename from pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py rename to pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py index 06b0fa263..1c3dc9273 100644 --- a/pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri.vector_stores.concrete.MlmVectorStore import MlmVectorStore +from swarmauri_community.vector_stores.concrete.MlmVectorStore import MlmVectorStore @pytest.mark.unit diff --git a/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py index 74de4851b..d5e4699f6 100644 --- a/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py @@ -2,7 +2,7 @@ import pytest from dotenv import load_dotenv from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.Neo4jVectorStore import Neo4jVectorStore +from swarmauri_community.vector_stores.concrete.Neo4jVectorStore import Neo4jVectorStore # Load environment variables load_dotenv() diff --git a/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py index b973277f1..66e0f0ed3 100644 --- a/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PersistentChromaDBVectorStore import ( +from swarmauri_community.vector_stores.concrete.PersistentChromaDBVectorStore import ( PersistentChromaDBVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py index 2277a87a4..d58c4295f 100644 --- a/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PersistentQdrantVectorStore import ( +from swarmauri_community.vector_stores.concrete.PersistentQdrantVectorStore import ( PersistentQdrantVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py index de2749df0..803a3f61c 100644 --- a/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PineconeVectorStore import PineconeVectorStore +from swarmauri_community.vector_stores.concrete.PineconeVectorStore import PineconeVectorStore from dotenv import load_dotenv load_dotenv() diff --git a/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py index 17cbce8a2..80fc1a2d5 100644 --- 
a/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py @@ -1,8 +1,6 @@ import pytest -import numpy as np -from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.RedisVectorStore import RedisVectorStore from swarmauri.documents.concrete.Document import Document +from swarmauri_community.vector_stores.concrete.RedisVectorStore import RedisVectorStore from dotenv import load_dotenv from os import getenv diff --git a/pkgs/core/pyproject.toml b/pkgs/core/pyproject.toml index 4f5283c21..887619499 100644 --- a/pkgs/core/pyproject.toml +++ b/pkgs/core/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-core" -version = "0.5.2" +version = "0.5.3.dev3" description = "This repository includes core interfaces for the Swarmauri framework." authors = ["Jacob Stewart "] license = "Apache-2.0" diff --git a/pkgs/core/swarmauri_core/ComponentBase.py b/pkgs/core/swarmauri_core/ComponentBase.py index 9f3e86e9f..4336b2047 100644 --- a/pkgs/core/swarmauri_core/ComponentBase.py +++ b/pkgs/core/swarmauri_core/ComponentBase.py @@ -34,6 +34,7 @@ class ResourceTypes(Enum): DOCUMENT = "Document" EMBEDDING = "Embedding" EXCEPTION = "Exception" + IMAGE_GEN = "ImageGen" LLM = "LLM" MESSAGE = "Message" MEASUREMENT = "Measurement" diff --git a/pkgs/core/swarmauri_core/image_gens/IGenImage.py b/pkgs/core/swarmauri_core/image_gens/IGenImage.py new file mode 100644 index 000000000..79cebf615 --- /dev/null +++ b/pkgs/core/swarmauri_core/image_gens/IGenImage.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod + + +class IGenImage(ABC): + """ + Interface focusing on the basic properties and settings essential for defining image-generating models. + """ + + @abstractmethod + def generate_image(self, *args, **kwargs) -> any: + """ + Generate images based on the input data provided to the model. + """ + pass + + @abstractmethod + async def agenerate_image(self, *args, **kwargs) -> any: + """ + Asynchronously generate images based on the input data provided to the model. + """ + pass + + @abstractmethod + def batch_generate(self, *args, **kwargs) -> any: + """ + Generate images for a batch of inputs provided to the model. + """ + pass + + @abstractmethod + async def abatch_generate(self, *args, **kwargs) -> any: + """ + Asynchronously generate images for a batch of inputs provided to the model. + """ + pass diff --git a/pkgs/core/swarmauri_core/image_gens/__init__.py b/pkgs/core/swarmauri_core/image_gens/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/experimental/pyproject.toml b/pkgs/experimental/pyproject.toml index bea605df3..2bcdb3b7f 100644 --- a/pkgs/experimental/pyproject.toml +++ b/pkgs/experimental/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-experimental" -version = "0.5.2" +version = "0.5.3.dev5" description = "This repository includes experimental components."
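A hedged sketch of a concrete class satisfying the new `IGenImage` interface added above; the class name and placeholder return values are illustrative, not part of this PR:

```python
import asyncio
from typing import List

from swarmauri_core.image_gens.IGenImage import IGenImage


class EchoImageGen(IGenImage):
    """Illustrative stub: returns placeholder URLs instead of calling a real backend."""

    def generate_image(self, prompt: str) -> str:
        # A real implementation would call an image-generation service here.
        return f"https://example.invalid/{abs(hash(prompt)) % 10_000}.png"

    async def agenerate_image(self, prompt: str) -> str:
        return self.generate_image(prompt)

    def batch_generate(self, prompts: List[str]) -> List[str]:
        return [self.generate_image(p) for p in prompts]

    async def abatch_generate(self, prompts: List[str]) -> List[str]:
        # Fan the async single-image calls out concurrently
        return list(await asyncio.gather(*(self.agenerate_image(p) for p in prompts)))
```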
authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -14,8 +14,8 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10,<4.0" -swarmauri = "==0.5.2" +python = ">=3.10,<3.13" +swarmauri = "==0.5.3.dev5" gensim = "*" neo4j = "*" numpy = "*" diff --git a/pkgs/swarmauri/pyproject.toml b/pkgs/swarmauri/pyproject.toml index 7ba365c89..23ad0e37f 100644 --- a/pkgs/swarmauri/pyproject.toml +++ b/pkgs/swarmauri/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri" -version = "0.5.2" +version = "0.5.3.dev5" description = "This repository includes base classes, concrete generics, and concrete standard components within the Swarmauri framework." authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -15,9 +15,9 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.10,<3.13" -swarmauri_core = "==0.5.2" +swarmauri_core = "==0.5.3.dev3" toml = "^0.10.2" -httpx = "^0.27.2" +httpx = "^0.27.0" joblib = "^1.4.0" numpy = "*" pandas = "*" @@ -34,44 +34,49 @@ aiohttp = { version = "^3.10.10", optional = true } #fal-client = { version = ">=0.5.0", optional = true } #google-generativeai = { version = "^0.8.3", optional = true } #openai = { version = "^1.52.0", optional = true } -nltk = { version = "^3.9.1", optional = true } -textblob = { version = "^0.18.0", optional = true } +#nltk = { version = "^3.9.1", optional = true } +#textblob = { version = "^0.18.0", optional = true } yake = { version = "==0.4.8", optional = true } beautifulsoup4 = { version = "04.12.3", optional = true } -gensim = { version = "==4.3.3", optional = true } +#gensim = { version = "==4.3.3", optional = true } scipy = { version = ">=1.7.0,<1.14.0", optional = true } -scikit-learn = { version = "^1.4.2", optional = true } -spacy = { version = ">=3.0.0,<=3.8.2", optional = true } -transformers = { version = "^4.45.0", optional = true } -torch = { version = "^2.5.0", optional = true } -keras = { version = ">=3.2.0", optional = true } -tf-keras = { version = ">=2.16.0", optional = true } +#scikit-learn = { version = "^1.4.2", optional = true } +#spacy = { version = ">=3.0.0,<=3.8.2", optional = true } +#transformers = { version = "^4.45.0", optional = true } +#torch = { version = "^2.5.0", optional = true } +#keras = { version = ">=3.2.0", optional = true } +#tf-keras = { version = ">=2.16.0", optional = true } matplotlib = { version = ">=3.9.2", optional = true } [tool.poetry.extras] # Extras without versioning, grouped for specific use cases io = ["aiofiles", "aiohttp"] #llms = ["cohere", "mistralai", "fal-client", "google-generativeai", "openai"] -nlp = ["nltk", "textblob", "yake"] +nlp = [ + #"nltk", + #"textblob", + "yake" +] nlp_tools = ["beautifulsoup4"] -ml_toolkits = ["gensim", "scipy", "scikit-learn"] -spacy = ["spacy"] -transformers = ["transformers"] -torch = ["torch"] -tensorflow = ["keras", "tf-keras"] +#ml_toolkits = ["gensim", "scipy", "scikit-learn"] +#spacy = ["spacy"] +#transformers = ["transformers"] +#torch = ["torch"] +#tensorflow = ["keras", "tf-keras"] visualization = ["matplotlib"] # Full option to install all extras full = [ "aiofiles", "aiohttp", #"cohere", "mistralai", "fal-client", "google-generativeai", "openai", - "nltk", "textblob", "yake", + #"nltk", "textblob", + "yake", "beautifulsoup4", - "gensim", "scipy", "scikit-learn", - "spacy", - "transformers", - "torch", - "keras", "tf-keras", + #"gensim", "scipy", "scikit-learn", + #"spacy", + #"transformers", + #"torch", + #"keras", "tf-keras", "matplotlib" ] diff --git a/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py 
b/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py index 651d9d992..8b75d563e 100644 --- a/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py @@ -1,8 +1,22 @@ -from swarmauri.agent_factories.concrete.agent_factory import AgentFactory -from swarmauri.agent_factories.concrete.conf_driven_agent_factory import ( - ConfDrivenAgentFactory, -) -from JsonAgentFactory import JsonAgentFactory -from swarmauri.agent_factories.concrete.ReflectionAgentFactory import ( - ReflectionAgentFactory, -) +from swarmauri.utils._lazy_import import _lazy_import + +# List of agent factory names (file names without the ".py" extension) and corresponding class names +agent_factory_files = [ + ("swarmauri.agent_factories.concrete.agent_factory", "AgentFactory"), + ( + "swarmauri.agent_factories.concrete.conf_driven_agent_factory", + "ConfDrivenAgentFactory", + ), + ("swarmauri.agent_factories.concrete.JsonAgentFactory", "JsonAgentFactory"), + ( + "swarmauri.agent_factories.concrete.ReflectionAgentFactory", + "ReflectionAgentFactory", + ), +] + +# Lazy loading of agent factories storing them in variables +for module_name, class_name in agent_factory_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded agent factories to __all__ +__all__ = [class_name for _, class_name in agent_factory_files] diff --git a/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py b/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py index f474dae0f..0103905cc 100644 --- a/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py @@ -1,27 +1,10 @@ -import importlib - -# Define a lazy loader function with a warning message if the module or class is not found -def _lazy_import(module_name, class_name): - try: - # Import the module - module = importlib.import_module(module_name) - # Dynamically get the class from the module - return getattr(module, class_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_name}' is not available. 
" - f"Please install the necessary dependencies to enable this functionality.") - return None - except AttributeError: - # If class is not found, print a warning message - print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.") - return None +from swarmauri.utils._lazy_import import _lazy_import # List of agent names (file names without the ".py" extension) and corresponding class names agent_files = [ - ("swarmauri.agents.concrete.SimpleConversationAgent", "SimpleConversationAgent"), ("swarmauri.agents.concrete.QAAgent", "QAAgent"), ("swarmauri.agents.concrete.RagAgent", "RagAgent"), + ("swarmauri.agents.concrete.SimpleConversationAgent", "SimpleConversationAgent"), ("swarmauri.agents.concrete.ToolAgent", "ToolAgent"), ] diff --git a/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py b/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py index efdd73eff..d6e508040 100644 --- a/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py @@ -1,11 +1,15 @@ -from swarmauri.chains.concrete.CallableChain import CallableChain -from swarmauri.chains.concrete.ChainStep import ChainStep -from swarmauri.chains.concrete.PromptContextChain import PromptContextChain -from swarmauri.chains.concrete.ContextChain import ContextChain +from swarmauri.utils._lazy_import import _lazy_import -__all__ = [ - "CallableChain", - "ChainStep", - "PromptContextChain", - "ContextChain", +chains_files = [ + ("swarmauri.chains.concrete.CallableChain import", "CallableChain"), + ("swarmauri.chains.concrete.ChainStep", "ChainStep"), + ("swarmauri.chains.concrete.PromptContextChain", "PromptContextChain"), + ("swarmauri.chains.concrete.ContextChain", "ContextChain"), ] + +# Lazy loading of chain classes, storing them in variables +for module_name, class_name in chains_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded chain classes to __all__ +__all__ = [class_name for _, class_name in chains_files] diff --git a/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py b/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py index f894163ad..fafead5cf 100644 --- a/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py @@ -1,13 +1,17 @@ -from swarmauri.chunkers.concrete.DelimiterBasedChunker import DelimiterBasedChunker -from swarmauri.chunkers.concrete.FixedLengthChunker import FixedLengthChunker -from swarmauri.chunkers.concrete.MdSnippetChunker import MdSnippetChunker -from swarmauri.chunkers.concrete.SentenceChunker import SentenceChunker -from swarmauri.chunkers.concrete.SlidingWindowChunker import SlidingWindowChunker +from swarmauri.utils._lazy_import import _lazy_import -__all__ = [ - "DelimiterBasedChunker", - "FixedLengthChunker", - "MdSnippetChunker", - "SentenceChunker", - "SlidingWindowChunker", +# List of chunker names (file names without the ".py" extension) and corresponding class names +chunkers_files = [ + ("swarmauri.chunkers.concrete.DelimiterBasedChunker", "DelimiterBasedChunker"), + ("swarmauri.chunkers.concrete.FixedLengthChunker", "FixedLengthChunker"), + ("swarmauri.chunkers.concrete.MdSnippetChunker", "MdSnippetChunker"), + ("swarmauri.chunkers.concrete.SentenceChunker", "SentenceChunker"), + ("swarmauri.chunkers.concrete.SlidingWindowChunker", "SlidingWindowChunker"), ] + +# Lazy loading of chunker classes, storing them in variables +for module_name, class_name in chunkers_files: + globals()[class_name] = 
_lazy_import(module_name, class_name) + +# Adding the lazy-loaded chunker classes to __all__ +__all__ = [class_name for _, class_name in chunkers_files] diff --git a/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py b/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py index e51d24fe0..46179d6fb 100644 --- a/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py @@ -1,15 +1,22 @@ -from swarmauri.conversations.concrete.Conversation import Conversation -from swarmauri.conversations.concrete.MaxSystemContextConversation import ( - MaxSystemContextConversation, -) -from swarmauri.conversations.concrete.MaxSizeConversation import MaxSizeConversation -from swarmauri.conversations.concrete.SessionCacheConversation import ( - SessionCacheConversation, -) +from swarmauri.utils._lazy_import import _lazy_import -__all__ = [ - "Conversation", - "MaxSystemContextConversation", - "MaxSizeConversation", - "SessionCacheConversation", +# List of conversations names (file names without the ".py" extension) and corresponding class names +conversations_files = [ + ("swarmauri.conversations.concrete.Conversation", "Conversation"), + ( + "swarmauri.conversations.concrete.MaxSystemContextConversation", + "MaxSystemContextConversation", + ), + ("swarmauri.conversations.concrete.MaxSizeConversation", "MaxSizeConversation"), + ( + "swarmauri.conversations.concrete.SessionCacheConversation", + "SessionCacheConversation", + ), ] + +# Lazy loading of conversations classes, storing them in variables +for module_name, class_name in conversations_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded conversations classes to __all__ +__all__ = [class_name for _, class_name in conversations_files] diff --git a/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py b/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py index 9e163ca4d..033a0dd13 100644 --- a/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py @@ -1,34 +1,27 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. 
" - f"Please install the necessary dependencies to enable this functionality.") - return None - -# List of distance names (file names without the ".py" extension) -distance_files = [ - "CanberraDistance", - "ChebyshevDistance", - "ChiSquaredDistance", - "CosineDistance", - "EuclideanDistance", - "HaversineDistance", - "JaccardIndexDistance", - "LevenshteinDistance", - "ManhattanDistance", - "MinkowskiDistance", - "SorensenDiceDistance", - "SquaredEuclideanDistance", +# List of distances names (file names without the ".py" extension) and corresponding class names +distances_files = [ + ("swarmauri.distances.concrete.CanberraDistance", "CanberraDistance"), + ("swarmauri.distances.concrete.ChebyshevDistance", "ChebyshevDistance"), + ("swarmauri.distances.concrete.ChiSquaredDistance", "ChiSquaredDistance"), + ("swarmauri.distances.concrete.CosineDistance", "CosineDistance"), + ("swarmauri.distances.concrete.EuclideanDistance", "EuclideanDistance"), + ("swarmauri.distances.concrete.HaversineDistance", "HaversineDistance"), + ("swarmauri.distances.concrete.JaccardIndexDistance", "JaccardIndexDistance"), + ("swarmauri.distances.concrete.LevenshteinDistance", "LevenshteinDistance"), + ("swarmauri.distances.concrete.ManhattanDistance", "ManhattanDistance"), + ("swarmauri.distances.concrete.MinkowskiDistance", "MinkowskiDistance"), + ("swarmauri.distances.concrete.SorensenDiceDistance", "SorensenDiceDistance"), + ( + "swarmauri.distances.concrete.SquaredEuclideanDistance", + "SquaredEuclideanDistance", + ), ] -# Lazy loading of distance modules, storing them in variables -for distance in distance_files: - globals()[distance] = _lazy_import(f"swarmauri.distances.concrete.{distance}", distance) +# Lazy loading of distances classes, storing them in variables +for module_name, class_name in distances_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding the lazy-loaded distance modules to __all__ -__all__ = distance_files +# Adding the lazy-loaded distances classes to __all__ +__all__ = [class_name for _, class_name in distances_files] diff --git a/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py b/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py index f0725fde0..c4b50e1a5 100644 --- a/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py @@ -1 +1,11 @@ -from swarmauri.documents.concrete import * +from swarmauri.utils._lazy_import import _lazy_import + +# List of documents names (file names without the ".py" extension) and corresponding class names +documents_files = [("swarmauri.documents.concrete.Document", "Document")] + +# Lazy loading of documents classes, storing them in variables +for module_name, class_name in documents_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded documents classes to __all__ +__all__ = [class_name for _, class_name in documents_files] diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py b/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py index a1f0f231c..c6d12f871 100644 --- a/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py @@ -1,31 +1,19 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # 
If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
+# List of embeddings names (file names without the ".py" extension) and corresponding class names
+embeddings_files = [
+    ("swarmauri.embeddings.concrete.CohereEmbedding", "CohereEmbedding"),
+    ("swarmauri.embeddings.concrete.GeminiEmbedding", "GeminiEmbedding"),
+    ("swarmauri.embeddings.concrete.MistralEmbedding", "MistralEmbedding"),
+    ("swarmauri.embeddings.concrete.NmfEmbedding", "NmfEmbedding"),
+    ("swarmauri.embeddings.concrete.OpenAIEmbedding", "OpenAIEmbedding"),
+    ("swarmauri.embeddings.concrete.TfidfEmbedding", "TfidfEmbedding"),
+    ("swarmauri.embeddings.concrete.VoyageEmbedding", "VoyageEmbedding"),
+]

-# Lazy loading of embeddings with descriptive names
-Doc2VecEmbedding = _lazy_import("swarmauri.embeddings.concrete.Doc2VecEmbedding", "Doc2VecEmbedding")
-GeminiEmbedding = _lazy_import("swarmauri.embeddings.concrete.GeminiEmbedding", "GeminiEmbedding")
-MistralEmbedding = _lazy_import("swarmauri.embeddings.concrete.MistralEmbedding", "MistralEmbedding")
-MlmEmbedding = _lazy_import("swarmauri.embeddings.concrete.MlmEmbedding", "MlmEmbedding")
-NmfEmbedding = _lazy_import("swarmauri.embeddings.concrete.NmfEmbedding", "NmfEmbedding")
-OpenAIEmbedding = _lazy_import("swarmauri.embeddings.concrete.OpenAIEmbedding", "OpenAIEmbedding")
-TfidfEmbedding = _lazy_import("swarmauri.embeddings.concrete.TfidfEmbedding", "TfidfEmbedding")
+# Lazy loading of embeddings classes, storing them in variables
+for module_name, class_name in embeddings_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)

-# Adding lazy-loaded modules to __all__
-__all__ = [
-    "Doc2VecEmbedding",
-    "GeminiEmbedding",
-    "MistralEmbedding",
-    "MlmEmbedding",
-    "NmfEmbedding",
-    "OpenAIEmbedding",
-    "TfidfEmbedding",
-]
+# Adding the lazy-loaded embeddings classes to __all__
+__all__ = [class_name for _, class_name in embeddings_files]
diff --git a/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py b/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py
index 43b631bc1..2baf7a56d 100644
--- a/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py
@@ -1,3 +1,13 @@
-from swarmauri.exceptions.concrete.IndexErrorWithContext import IndexErrorWithContext
+from swarmauri.utils._lazy_import import _lazy_import

-__all__ = ["IndexErrorWithContext"]
+# List of exceptions names (file names without the ".py" extension) and corresponding class names
+exceptions_files = [
+    ("swarmauri.exceptions.concrete.IndexErrorWithContext", "IndexErrorWithContext"),
+]
+
+# Lazy loading of exceptions classes, storing them in variables
+for module_name, class_name in exceptions_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded exceptions classes to __all__
+__all__ = [class_name for _, class_name in exceptions_files]
diff --git a/pkgs/swarmauri/swarmauri/image_gens/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py b/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py
new file mode 100644
index 000000000..ab3dff1f2
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py
@@ -0,0 +1,51 @@
+from abc import abstractmethod
+from typing import Any, Optional, List, Literal
+from pydantic import ConfigDict, model_validator, Field
+from swarmauri_core.image_gens.IGenImage import IGenImage
+from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
+
+
+class ImageGenBase(IGenImage, ComponentBase):
+    allowed_models: List[str] = []
+    resource: Optional[str] = Field(default=ResourceTypes.IMAGE_GEN.value, frozen=True)
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    type: Literal["ImageGenBase"] = "ImageGenBase"
+
+    @model_validator(mode="after")
+    @classmethod
+    def _validate_name_in_allowed_models(cls, values):
+        name = values.name
+        allowed_models = values.allowed_models
+        if name and name not in allowed_models:
+            raise ValueError(
+                f"Model name {name} is not allowed. Choose from {allowed_models}"
+            )
+        return values
+
+    @abstractmethod
+    def generate_image(self, *args, **kwargs) -> Any:
+        """
+        Generate images based on the input data provided to the model.
+        """
+        raise NotImplementedError("generate_image() not implemented in subclass yet.")
+
+    @abstractmethod
+    async def agenerate_image(self, *args, **kwargs) -> Any:
+        """
+        Asynchronously generate images based on the input data provided to the model.
+        """
+        raise NotImplementedError("agenerate_image() not implemented in subclass yet.")
+
+    @abstractmethod
+    def batch_generate(self, *args, **kwargs) -> Any:
+        """
+        Generate images for a batch of inputs provided to the model.
+        """
+        raise NotImplementedError("batch_generate() not implemented in subclass yet.")
+
+    @abstractmethod
+    async def abatch_generate(self, *args, **kwargs) -> Any:
+        """
+        Asynchronously generate images for a batch of inputs provided to the model.
+        """
+        raise NotImplementedError("abatch_generate() not implemented in subclass yet.")
diff --git a/pkgs/swarmauri/swarmauri/image_gens/base/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
similarity index 96%
rename from pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py
rename to pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
index 50d395394..d783c4204 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py
+++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
@@ -1,35 +1,35 @@
 import httpx
 import time
-from typing import List, Literal, Optional, Dict, ClassVar
+from typing import List, Literal, Optional, Dict
 from pydantic import PrivateAttr
 from swarmauri.utils.retry_decorator import retry_on_status_codes
-from swarmauri.llms.base.LLMBase import LLMBase
+from swarmauri.image_gens.base.ImageGenBase import ImageGenBase
 import asyncio
 import contextlib


-class BlackForestImgGenModel(LLMBase):
+class BlackForestImgGenModel(ImageGenBase):
     """
     A model for generating images using FluxPro's image generation models
     through the Black Forest API.
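
    A minimal usage sketch for the renamed class (illustrative, not part of this diff; the key is a placeholder and only fields shown in these hunks are used):

        model = BlackForestImgGenModel(api_key="bfl-...", name="flux-dev")
        assert model.name in model.allowed_models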
Link to API key: https://api.bfl.ml/auth/profile """ _BASE_URL: str = PrivateAttr("https://api.bfl.ml") - _client: httpx.Client = PrivateAttr() + _client: httpx.Client = PrivateAttr(default=None) _async_client: httpx.AsyncClient = PrivateAttr(default=None) + _headers: Dict[str, str] = PrivateAttr(default=None) api_key: str allowed_models: List[str] = ["flux-pro-1.1", "flux-pro", "flux-dev"] - asyncio: ClassVar = asyncio name: str = "flux-pro" # Default model type: Literal["BlackForestImgGenModel"] = "BlackForestImgGenModel" - def __init__(self, **data): + def __init__(self, **kwargs): """ Initializes the BlackForestImgGenModel instance with HTTP clients. """ - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Content-Type": "application/json", "X-Key": self.api_key, diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py similarity index 97% rename from pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py rename to pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py index 56afc2105..5dcc35d98 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py @@ -2,12 +2,12 @@ from typing import List, Literal from pydantic import PrivateAttr from swarmauri.utils.retry_decorator import retry_on_status_codes -from swarmauri.llms.base.LLMBase import LLMBase +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase import asyncio import contextlib -class DeepInfraImgGenModel(LLMBase): +class DeepInfraImgGenModel(ImageGenBase): """ A model class for generating images from text prompts using DeepInfra's image generation API. @@ -37,7 +37,7 @@ class DeepInfraImgGenModel(LLMBase): name: str = "stabilityai/stable-diffusion-2-1" # Default model type: Literal["DeepInfraImgGenModel"] = "DeepInfraImgGenModel" - def __init__(self, **data): + def __init__(self, **kwargs): """ Initializes the DeepInfraImgGenModel instance. @@ -47,7 +47,7 @@ def __init__(self, **data): Args: **data: Keyword arguments for model initialization. """ - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py similarity index 98% rename from pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py rename to pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py index 6943d1e59..b6eadb3b7 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py @@ -3,11 +3,11 @@ from typing import List, Literal, Optional, Dict from pydantic import Field, PrivateAttr from swarmauri.utils.retry_decorator import retry_on_status_codes -from swarmauri.llms.base.LLMBase import LLMBase +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase import time -class FalAIImgGenModel(LLMBase): +class FalAIImgGenModel(ImageGenBase): """ A model class for generating images from text using FluxPro's image generation model, provided by FalAI. This class uses a queue-based API to handle image generation requests. 
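
    A minimal construction sketch (illustrative, not part of this diff; the key is a placeholder, and the queue-based request/poll helpers live elsewhere in this file):

        model = FalAIImgGenModel(api_key="fal-...")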
@@ -34,7 +34,7 @@ class FalAIImgGenModel(LLMBase):
     max_retries: int = Field(default=60)  # Maximum number of status check retries
     retry_delay: float = Field(default=1.0)  # Delay between status checks in seconds

-    def __init__(self, **data):
+    def __init__(self, **kwargs):
         """
         Initializes the model with the specified API key and model name.
@@ -44,7 +44,7 @@ def __init__(self, **data):
         Raises:
             ValueError: If an invalid model name is provided.
         """
-        super().__init__(**data)
+        super().__init__(**kwargs)
         self._headers = {
             "Content-Type": "application/json",
             "Authorization": f"Key {self.api_key}",
diff --git a/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py
new file mode 100644
index 000000000..43f7dd60b
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py
@@ -0,0 +1,210 @@
+import httpx
+from typing import Dict, List, Literal
+from pydantic import PrivateAttr
+from swarmauri.utils.retry_decorator import retry_on_status_codes
+from swarmauri.image_gens.base.ImageGenBase import ImageGenBase
+import asyncio
+import contextlib
+
+
+class HyperbolicImgGenModel(ImageGenBase):
+    """
+    A model class for generating images from text prompts using Hyperbolic's image generation API.
+
+    Attributes:
+        api_key (str): The API key for authenticating with the Hyperbolic API.
+        allowed_models (List[str]): A list of available models for image generation.
+        name (str): The name of the model to be used for image generation.
+        type (Literal["HyperbolicImgGenModel"]): The type identifier for the model class.
+        height (int): Height of the generated image.
+        width (int): Width of the generated image.
+        steps (int): Number of inference steps.
+        cfg_scale (float): Classifier-free guidance scale.
+        enable_refiner (bool): Whether to enable the refiner model.
+        backend (str): Computational backend for the model.
+
+    Link to Allowed Models: https://app.hyperbolic.xyz/models
+    Link to API KEYS: https://app.hyperbolic.xyz/settings
+    """
+
+    _BASE_URL: str = PrivateAttr("https://api.hyperbolic.xyz/v1/image/generation")
+    _client: httpx.Client = PrivateAttr(default=None)
+    _async_client: httpx.AsyncClient = PrivateAttr(default=None)
+    _headers: Dict[str, str] = PrivateAttr(default=None)
+
+    api_key: str
+    allowed_models: List[str] = [
+        "SDXL1.0-base",
+        "SD1.5",
+        "SSD",
+        "SD2",
+        "SDXL-turbo",
+    ]
+
+    name: str = "SDXL1.0-base"  # Default model
+    type: Literal["HyperbolicImgGenModel"] = "HyperbolicImgGenModel"
+
+    # Additional configuration parameters
+    height: int = 1024
+    width: int = 1024
+    steps: int = 30
+    cfg_scale: float = 5.0
+    enable_refiner: bool = False
+    backend: str = "auto"
+
+    def __init__(self, **kwargs):
+        """
+        Initializes the HyperbolicImgGenModel instance.
+
+        This constructor sets up HTTP clients for both synchronous and asynchronous
+        operations and configures request headers with the provided API key.
+
+        Args:
+            **kwargs: Keyword arguments for model initialization.
+        """
+        super().__init__(**kwargs)
+        self._headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+        self._client = httpx.Client(headers=self._headers, timeout=30)
+
+    async def _get_async_client(self) -> httpx.AsyncClient:
+        """
+        Gets or creates an async client instance.
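
        The reuse contract of this helper, sketched (illustrative):

            a = await self._get_async_client()
            b = await self._get_async_client()
            assert a is b  # the same client instance is returned until it is closed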
+ """ + if self._async_client is None or self._async_client.is_closed: + self._async_client = httpx.AsyncClient(headers=self._headers, timeout=30) + return self._async_client + + async def _close_async_client(self): + """ + Closes the async client if it exists and is open. + """ + if self._async_client is not None and not self._async_client.is_closed: + await self._async_client.aclose() + self._async_client = None + + def _create_request_payload(self, prompt: str) -> dict: + """ + Creates the payload for the image generation request. + """ + return { + "model_name": self.name, + "prompt": prompt, + "height": self.height, + "width": self.width, + "steps": self.steps, + "cfg_scale": self.cfg_scale, + "enable_refiner": self.enable_refiner, + "backend": self.backend, + } + + @retry_on_status_codes((429, 529), max_retries=1) + def _send_request(self, prompt: str) -> dict: + """ + Sends a synchronous request to the Hyperbolic API for image generation. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + dict: The response data from the API. + """ + payload = self._create_request_payload(prompt) + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + return response.json() + + @retry_on_status_codes((429, 529), max_retries=1) + async def _async_send_request(self, prompt: str) -> dict: + """ + Sends an asynchronous request to the Hyperbolic API for image generation. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + dict: The response data from the API. + """ + client = await self._get_async_client() + payload = self._create_request_payload(prompt) + response = await client.post(self._BASE_URL, json=payload) + response.raise_for_status() + return response.json() + + def generate_image_base64(self, prompt: str) -> str: + """ + Generates an image synchronously based on the provided prompt and returns it as a base64-encoded string. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + str: The base64-encoded representation of the generated image. + """ + response_data = self._send_request(prompt) + return response_data["images"][0]["image"] + + async def agenerate_image_base64(self, prompt: str) -> str: + """ + Generates an image asynchronously based on the provided prompt and returns it as a base64-encoded string. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + str: The base64-encoded representation of the generated image. + """ + try: + response_data = await self._async_send_request(prompt) + return response_data["images"][0]["image"] + finally: + await self._close_async_client() + + def batch_base64(self, prompts: List[str]) -> List[str]: + """ + Generates images for a batch of prompts synchronously and returns them as a list of base64-encoded strings. + + Args: + prompts (List[str]): A list of text prompts for image generation. + + Returns: + List[str]: A list of base64-encoded representations of the generated images. + """ + return [self.generate_image_base64(prompt) for prompt in prompts] + + async def abatch_base64( + self, prompts: List[str], max_concurrent: int = 5 + ) -> List[str]: + """ + Generates images for a batch of prompts asynchronously and returns them as a list of base64-encoded strings. + + Args: + prompts (List[str]): A list of text prompts for image generation. + max_concurrent (int): The maximum number of concurrent tasks. 
+ + Returns: + List[str]: A list of base64-encoded representations of the generated images. + """ + try: + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_prompt(prompt): + async with semaphore: + response_data = await self._async_send_request(prompt) + return response_data["images"][0]["image"] + + tasks = [process_prompt(prompt) for prompt in prompts] + return await asyncio.gather(*tasks) + finally: + await self._close_async_client() + + def __del__(self): + """ + Cleanup method to ensure clients are closed. + """ + self._client.close() + if self._async_client is not None and not self._async_client.is_closed: + with contextlib.suppress(Exception): + asyncio.run(self._close_async_client()) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py similarity index 97% rename from pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py rename to pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py index ad78fd7d8..8862799e5 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py @@ -2,11 +2,11 @@ import asyncio import httpx from typing import Dict, List, Literal, Optional +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase from swarmauri.utils.retry_decorator import retry_on_status_codes -from swarmauri.llms.base.LLMBase import LLMBase -class OpenAIImgGenModel(LLMBase): +class OpenAIImgGenModel(ImageGenBase): """ OpenAIImgGenModel is a class for generating images using OpenAI's DALL-E models. @@ -26,14 +26,14 @@ class OpenAIImgGenModel(LLMBase): _BASE_URL: str = PrivateAttr(default="https://api.openai.com/v1/images/generations") _headers: Dict[str, str] = PrivateAttr(default=None) - def __init__(self, **data) -> None: + def __init__(self, **kwargs) -> None: """ Initialize the GroqAIAudio class with the provided data. Args: **data: Arbitrary keyword arguments containing initialization data. """ - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", diff --git a/pkgs/swarmauri/swarmauri/image_gens/concrete/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py new file mode 100644 index 000000000..b6ff4de67 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py @@ -0,0 +1,146 @@ +import base64 +import io +import os +from typing import AsyncIterator, Iterator, List, Literal, Dict, Optional +import httpx +from pydantic import PrivateAttr, model_validator, Field +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.llms.base.LLMBase import LLMBase +import asyncio + + +class HyperbolicAudioTTS(LLMBase): + """ + A class to interact with Hyperbolic's Text-to-Speech API, allowing for synchronous + and asynchronous text-to-speech synthesis. + + Attributes: + api_key (str): The API key for accessing Hyperbolic's TTS service. + language (Optional[str]): Language of the text. + speaker (Optional[str]): Specific speaker variant. + speed (Optional[float]): Speech speed control. 
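
    A minimal usage sketch (illustrative, not part of this diff; the key is a placeholder and predict() is defined below):

        tts = HyperbolicAudioTTS(api_key="hyp-...", language="EN", speaker="EN-US", speed=1.0)
        path = tts.predict("Hello world", audio_path="hello.mp3")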
+ + Provider Resource: https://api.hyperbolic.xyz/v1/audio/generation + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + + # Supported languages + allowed_languages: List[str] = ["EN", "ES", "FR", "ZH", "JP", "KR"] + + # Supported speakers per language + allowed_speakers: Dict[str, List[str]] = { + "EN": ["EN-US", "EN-BR", "EN-INDIA", "EN-AU"], + "ES": ["ES"], + "FR": ["FR"], + "ZH": ["ZH"], + "JP": ["JP"], + "KR": ["KR"], + } + + # Optional parameters with type hints and validation + language: Optional[str] = None + speaker: Optional[str] = None + speed: Optional[float] = Field(default=1.0, ge=0.1, le=5) + + type: Literal["HyperbolicAudioTTS"] = "HyperbolicAudioTTS" + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/audio/generation" + ) + _headers: Dict[str, str] = PrivateAttr(default=None) + + def __init__(self, **data): + """ + Initialize the HyperbolicAudioTTS class with the provided data. + """ + super().__init__(**data) + self._headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def _prepare_payload(self, text: str) -> Dict: + """ + Prepare the payload for the TTS request. + """ + payload = {"text": text} + + # Add optional parameters if they are set + if self.language: + payload["language"] = self.language + if self.speaker: + payload["speaker"] = self.speaker + if self.speed is not None: + payload["speed"] = self.speed + + return payload + + def predict(self, text: str, audio_path: str = "output.mp3") -> str: + """ + Synchronously converts text to speech. + """ + payload = self._prepare_payload(text) + + with httpx.Client(timeout=30) as client: + response = client.post(self._BASE_URL, headers=self._headers, json=payload) + response.raise_for_status() + + # Decode base64 audio + audio_data = base64.b64decode(response.json()["audio"]) + + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_data) + + return os.path.abspath(audio_path) + + async def apredict(self, text: str, audio_path: str = "output.mp3") -> str: + """ + Asynchronously converts text to speech. + """ + payload = self._prepare_payload(text) + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + # Decode base64 audio + audio_data = base64.b64decode(response.json()["audio"]) + + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_data) + + return os.path.abspath(audio_path) + + def batch( + self, + text_path_dict: Dict[str, str], + ) -> List[str]: + """ + Synchronously process multiple text-to-speech requests in batch mode. + """ + return [ + self.predict(text=text, audio_path=path) + for text, path in text_path_dict.items() + ] + + async def abatch( + self, + text_path_dict: Dict[str, str], + max_concurrent=5, + ) -> List[str]: + """ + Asynchronously process multiple text-to-speech requests in batch mode. 
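
        A usage sketch (illustrative; texts and file names are placeholders):

            paths = await tts.abatch(
                {"First line.": "one.mp3", "Second line.": "two.mp3"},
                max_concurrent=2,
            )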
+ """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(text, path) -> str: + async with semaphore: + return await self.apredict(text=text, audio_path=path) + + tasks = [ + process_conversation(text, path) for text, path in text_path_dict.items() + ] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py new file mode 100644 index 000000000..2b3677677 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py @@ -0,0 +1,427 @@ +import asyncio +import json +from pydantic import PrivateAttr +import httpx +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.utils.duration_manager import DurationManager +from swarmauri.conversations.concrete.Conversation import Conversation +from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator + +from swarmauri_core.typing import SubclassUnion +from swarmauri.messages.base.MessageBase import MessageBase +from swarmauri.messages.concrete.AgentMessage import AgentMessage +from swarmauri.llms.base.LLMBase import LLMBase + +from swarmauri.messages.concrete.AgentMessage import UsageData + + +class HyperbolicModel(LLMBase): + """ + HyperbolicModel class for interacting with the Hyperbolic AI language models API. + + Attributes: + api_key (str): API key for authenticating requests to the Hyperbolic API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["HyperbolicModel"]): The type identifier for this class. + + Link to Allowed Models: https://app.hyperbolic.xyz/models + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + allowed_models: List[str] = [ + "Qwen/Qwen2.5-Coder-32B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + "Qwen/Qwen2.5-72B-Instruct", + "deepseek-ai/DeepSeek-V2.5", + "meta-llama/Meta-Llama-3-70B-Instruct", + "NousResearch/Hermes-3-Llama-3.1-70B", + "meta-llama/Meta-Llama-3.1-70B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + ] + name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" + type: Literal["HyperbolicModel"] = "HyperbolicModel" + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/chat/completions" + ) + _headers: Dict[str, str] = PrivateAttr(default=None) + + def __init__(self, **data) -> None: + """ + Initialize the HyperbolicModel class with the provided data. + + Args: + **data: Arbitrary keyword arguments containing initialization data. + """ + super().__init__(**data) + self._headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + def _format_messages( + self, + messages: List[SubclassUnion[MessageBase]], + ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. 
+ """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_message["content"] = [ + {"type": item["type"], **item} + for item in formatted_message["content"] + ] + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data( + self, + usage_data, + prompt_time: float = 0.0, + completion_time: float = 0.0, + ) -> UsageData: + """ + Prepare usage data by combining token counts and timing information. + + Args: + usage_data: Raw usage data containing token counts. + prompt_time (float): Time taken for prompt processing. + completion_time (float): Time taken for response completion. + + Returns: + UsageData: Processed usage data. + """ + total_time = prompt_time + completion_time + + # Filter usage data for relevant keys + filtered_usage_data = { + key: value + for key, value in usage_data.items() + if key + not in { + "prompt_tokens", + "completion_tokens", + "total_tokens", + "prompt_time", + "completion_time", + "total_time", + } + } + + usage = UsageData( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + prompt_time=prompt_time, + completion_time=completion_time, + total_time=total_time, + **filtered_usage_data, + ) + + return usage + + @retry_on_status_codes((429, 529), max_retries=1) + def predict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (Optional[int]): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + top_k (int): Maximum number of tokens to consider at each step. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. 
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "stream": False, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop is not None: + payload["stop"] = stop + + with DurationManager() as promt_timer: + with httpx.Client(timeout=30) as client: + response = client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + response_data = response.json() + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data, promt_timer.duration) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + async def apredict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args are same as predict method. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "stream": False, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop is not None: + payload["stop"] = stop + + with DurationManager() as promt_timer: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + response_data = response.json() + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data, promt_timer.duration) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + def stream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Generator[str, None, None]: + """ + Streams response text from the model in real-time. + + Args are same as predict method. + + Yields: + str: Partial response content from the model. 
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "stream": True, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop is not None: + payload["stop"] = stop + + with DurationManager() as promt_timer: + with httpx.Client(timeout=30) as client: + response = client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + message_content = "" + usage_data = {} + with DurationManager() as completion_timer: + for line in response.iter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"] and chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + if "usage" in chunk and chunk["usage"] is not None: + usage_data = chunk["usage"] + except json.JSONDecodeError: + pass + + usage = self._prepare_usage_data( + usage_data, promt_timer.duration, completion_timer.duration + ) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + + @retry_on_status_codes((429, 529), max_retries=1) + async def astream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args are same as predict method. + + Yields: + str: Partial response content from the model. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "stream": True, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop is not None: + payload["stop"] = stop + + with DurationManager() as promt_timer: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + message_content = "" + usage_data = {} + with DurationManager() as completion_timer: + async for line in response.aiter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"] and chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + if "usage" in chunk and chunk["usage"] is not None: + usage_data = chunk["usage"] + except json.JSONDecodeError: + pass + + usage = self._prepare_usage_data( + usage_data, promt_timer.duration, completion_timer.duration + ) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + + def batch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args are same as predict method. 
+ """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + enable_json=enable_json, + stop=stop, + ) + results.append(result_conversation) + return results + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args are same as predict method, with additional arg: + max_concurrent (int): Maximum number of concurrent requests. + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(conv: Conversation) -> Conversation: + async with semaphore: + return await self.apredict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + enable_json=enable_json, + stop=stop, + ) + + tasks = [process_conversation(conv) for conv in conversations] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py new file mode 100644 index 000000000..14e2d196a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py @@ -0,0 +1,381 @@ +import json +from pydantic import PrivateAttr +import httpx +from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator +import asyncio + +from swarmauri_core.typing import SubclassUnion +from swarmauri.conversations.concrete.Conversation import Conversation +from swarmauri.messages.base.MessageBase import MessageBase +from swarmauri.messages.concrete.AgentMessage import AgentMessage +from swarmauri.llms.base.LLMBase import LLMBase +from swarmauri.messages.concrete.AgentMessage import UsageData +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.utils.file_path_to_base64 import file_path_to_base64 + + +class HyperbolicVisionModel(LLMBase): + """ + HyperbolicVisionModel class for interacting with the Hyperbolic vision language models API. This class + provides synchronous and asynchronous methods to send conversation data to the + model, receive predictions, and stream responses. + + Attributes: + api_key (str): API key for authenticating requests to the Hyperbolic API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["HyperbolicVisionModel"]): The type identifier for this class. + + Link to Allowed Models: https://app.hyperbolic.xyz/models + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + allowed_models: List[str] = [ + "Qwen/Qwen2-VL-72B-Instruct", + "mistralai/Pixtral-12B-2409", + "Qwen/Qwen2-VL-7B-Instruct", + ] + name: str = "Qwen/Qwen2-VL-72B-Instruct" + type: Literal["HyperbolicVisionModel"] = "HyperbolicVisionModel" + _headers: Dict[str, str] = PrivateAttr(default=None) + _client: httpx.Client = PrivateAttr(default=None) + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/chat/completions" + ) + + def __init__(self, **data): + """ + Initialize the HyperbolicVisionModel class with the provided data. + + Args: + **data: Arbitrary keyword arguments containing initialization data. 
+ """ + super().__init__(**data) + self._headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + self._client = httpx.Client( + headers=self._headers, + base_url=self._BASE_URL, + ) + + def _format_messages( + self, + messages: List[SubclassUnion[MessageBase]], + ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_content = [] + for item in formatted_message["content"]: + if item["type"] == "image_url" and "file_path" in item: + # Convert file path to base64 + base64_img = file_path_to_base64(item["file_path"]) + formatted_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_img}" + }, + } + ) + else: + formatted_content.append(item) + formatted_message["content"] = formatted_content + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data(self, usage_data) -> UsageData: + """ + Prepares and validates usage data received from the API response. + + Args: + usage_data (dict): Raw usage data from the API response. + + Returns: + UsageData: Validated usage data instance. + """ + return UsageData.model_validate(usage_data) + + @retry_on_status_codes((429, 529), max_retries=1) + def predict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + async def apredict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. 
+ max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + async with httpx.AsyncClient() as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + def stream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Generator[str, None, None]: + """ + Streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + message_content = "" + for line in response.iter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + @retry_on_status_codes((429, 529), max_retries=1) + async def astream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
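
        A usage sketch (illustrative; conv is a populated Conversation):

            async for token in model.astream(conv):
                print(token, end="", flush=True)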
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + async with httpx.AsyncClient as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + message_content = "" + async for line in response.aiter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + def batch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + List[Conversation]: List of updated conversations with model responses. + """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + ) + results.append(result_conversation) + return results + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + max_concurrent (int): Maximum number of concurrent requests. + + Returns: + List[Conversation]: List of updated conversations with model responses. 
+    def batch(
+        self,
+        conversations: List[Conversation],
+        temperature: float = 0.7,
+        max_tokens: int = 2048,
+        top_p: float = 0.9,
+        stop: Optional[List[str]] = None,
+    ) -> List[Conversation]:
+        """
+        Processes a batch of conversations and generates responses for each sequentially.
+
+        Args:
+            conversations (List[Conversation]): List of conversations to process.
+            temperature (float): Sampling temperature for response diversity.
+            max_tokens (int): Maximum tokens for each response.
+            top_p (float): Cumulative probability for nucleus sampling.
+            stop (Optional[List[str]]): List of stop sequences for response termination.
+
+        Returns:
+            List[Conversation]: List of updated conversations with model responses.
+        """
+        results = []
+        for conversation in conversations:
+            result_conversation = self.predict(
+                conversation,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                stop=stop,
+            )
+            results.append(result_conversation)
+        return results
+
+    async def abatch(
+        self,
+        conversations: List[Conversation],
+        temperature: float = 0.7,
+        max_tokens: int = 2048,
+        top_p: float = 0.9,
+        stop: Optional[List[str]] = None,
+        max_concurrent: int = 5,
+    ) -> List[Conversation]:
+        """
+        Async method for processing a batch of conversations concurrently.
+
+        Args:
+            conversations (List[Conversation]): List of conversations to process.
+            temperature (float): Sampling temperature for response diversity.
+            max_tokens (int): Maximum tokens for each response.
+            top_p (float): Cumulative probability for nucleus sampling.
+            stop (Optional[List[str]]): List of stop sequences for response termination.
+            max_concurrent (int): Maximum number of concurrent requests.
+
+        Returns:
+            List[Conversation]: List of updated conversations with model responses.
+        """
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def process_conversation(conv: Conversation) -> Conversation:
+            async with semaphore:
+                return await self.apredict(
+                    conv,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    top_p=top_p,
+                    stop=stop,
+                )
+
+        tasks = [process_conversation(conv) for conv in conversations]
+        return await asyncio.gather(*tasks)
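+
+    # Illustrative call (assumed names): bound the request fan-out when batching
+    #
+    #     results = await llm.abatch(conversations, max_concurrent=3)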
diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py b/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py
index a24e7b59f..975ac7e93 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py
@@ -1,47 +1,43 @@
-import importlib
+from swarmauri.utils._lazy_import import _lazy_import
 
-# Define a lazy loader function with a warning message if the module is not found
-def _lazy_import(module_name, module_description=None):
-    try:
-        return importlib.import_module(module_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-
-# List of model names (file names without the ".py" extension)
-model_files = [
-    "AI21StudioModel",
-    "AnthropicModel",
-    "AnthropicToolModel",
-    "BlackForestimgGenModel",
-    "CohereModel",
-    "CohereToolModel",
-    "DeepInfraImgGenModel",
-    "DeepInfraModel",
-    "DeepSeekModel",
-    "FalAllImgGenModel",
-    "FalAVisionModel",
-    "GeminiProModel",
-    "GeminiToolModel",
-    "GroqAudio",
-    "GroqModel",
-    "GroqToolModel",
-    "GroqVisionModel",
-    "MistralModel",
-    "MistralToolModel",
-    "OpenAIGenModel",
-    "OpenAIModel",
-    "OpenAIToolModel",
-    "PerplexityModel",
-    "PlayHTModel",
-    "WhisperLargeModel",
+# List of llms names (file names without the ".py" extension) and corresponding class names
+llms_files = [
+    ("swarmauri.llms.concrete.AI21StudioModel", "AI21StudioModel"),
+    ("swarmauri.llms.concrete.AnthropicModel", "AnthropicModel"),
+    ("swarmauri.llms.concrete.AnthropicToolModel", "AnthropicToolModel"),
+    ("swarmauri.llms.concrete.BlackForestImgGenModel", "BlackForestImgGenModel"),
+    ("swarmauri.llms.concrete.CohereModel", "CohereModel"),
+    ("swarmauri.llms.concrete.CohereToolModel", "CohereToolModel"),
+    ("swarmauri.llms.concrete.DeepInfraImgGenModel", "DeepInfraImgGenModel"),
+    ("swarmauri.llms.concrete.DeepInfraModel", "DeepInfraModel"),
+    ("swarmauri.llms.concrete.DeepSeekModel", "DeepSeekModel"),
+    ("swarmauri.llms.concrete.FalAIImgGenModel", "FalAIImgGenModel"),
+    ("swarmauri.llms.concrete.FalAIVisionModel", "FalAIVisionModel"),
+    ("swarmauri.llms.concrete.GeminiProModel", "GeminiProModel"),
+    ("swarmauri.llms.concrete.GeminiToolModel", "GeminiToolModel"),
+    ("swarmauri.llms.concrete.GroqAIAudio", "GroqAIAudio"),
+    ("swarmauri.llms.concrete.GroqModel", "GroqModel"),
+    ("swarmauri.llms.concrete.GroqToolModel", "GroqToolModel"),
+    ("swarmauri.llms.concrete.GroqVisionModel", "GroqVisionModel"),
+    ("swarmauri.llms.concrete.HyperbolicAudioTTS", "HyperbolicAudioTTS"),
+    ("swarmauri.llms.concrete.HyperbolicImgGenModel", "HyperbolicImgGenModel"),
+    ("swarmauri.llms.concrete.HyperbolicModel", "HyperbolicModel"),
+    ("swarmauri.llms.concrete.HyperbolicVisionModel", "HyperbolicVisionModel"),
+    ("swarmauri.llms.concrete.MistralModel", "MistralModel"),
+    ("swarmauri.llms.concrete.MistralToolModel", "MistralToolModel"),
+    ("swarmauri.llms.concrete.OpenAIAudio", "OpenAIAudio"),
+    ("swarmauri.llms.concrete.OpenAIAudioTTS", "OpenAIAudioTTS"),
+    ("swarmauri.llms.concrete.OpenAIImgGenModel", "OpenAIImgGenModel"),
+    ("swarmauri.llms.concrete.OpenAIModel", "OpenAIModel"),
+    ("swarmauri.llms.concrete.OpenAIToolModel", "OpenAIToolModel"),
+    ("swarmauri.llms.concrete.PerplexityModel", "PerplexityModel"),
+    ("swarmauri.llms.concrete.PlayHTModel", "PlayHTModel"),
+    ("swarmauri.llms.concrete.WhisperLargeModel", "WhisperLargeModel"),
 ]
 
-# Lazy loading of models, storing them in variables
-for model in model_files:
-    globals()[model] = _lazy_import(f"swarmauri.llms.concrete.{model}", model)
+# Lazy loading of llms classes, storing them in variables
+for module_name, class_name in llms_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
 
-# Adding the lazy-loaded models to __all__
-__all__ = model_files
+# Adding the lazy-loaded llms classes to __all__
+__all__ = [class_name for _, class_name in llms_files]
diff --git a/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py b/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py
index e340b2b85..ea47cb17f 100644
--- a/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py
@@ -1,6 +1,41 @@
-from swarmauri.measurements.concrete.FirstImpressionMeasurement import FirstImpressionMeasurement
-from swarmauri.measurements.concrete.MeanMeasurement import MeanMeasurement
-from swarmauri.measurements.concrete.PatternMatchingMeasurement import PatternMatchingMeasurement
-from swarmauri.measurements.concrete.RatioOfSumsMeasurement import RatioOfSumsMeasurement
-from swarmauri.measurements.concrete.StaticMeasurement import StaticMeasurement
-from swarmauri.measurements.concrete.ZeroMeasurement import ZeroMeasurement
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of measurements names (file names without the ".py" extension) and corresponding class names
+measurements_files = [
+    (
+        "swarmauri.measurements.concrete.CompletenessMeasurement",
+        "CompletenessMeasurement",
+    ),
+    (
+        "swarmauri.measurements.concrete.DistinctivenessMeasurement",
+        "DistinctivenessMeasurement",
+    ),
+    (
+        "swarmauri.measurements.concrete.FirstImpressionMeasurement",
+        "FirstImpressionMeasurement",
+    ),
+    ("swarmauri.measurements.concrete.MeanMeasurement", "MeanMeasurement"),
+    ("swarmauri.measurements.concrete.MiscMeasurement", "MiscMeasurement"),
+    (
+        "swarmauri.measurements.concrete.MissingnessMeasurement",
+        "MissingnessMeasurement",
+    ),
+    (
+        "swarmauri.measurements.concrete.PatternMatchingMeasurement",
+        "PatternMatchingMeasurement",
+    ),
+    (
+        "swarmauri.measurements.concrete.RatioOfSumsMeasurement",
+        "RatioOfSumsMeasurement",
+    ),
+    ("swarmauri.measurements.concrete.StaticMeasurement", "StaticMeasurement"),
+    ("swarmauri.measurements.concrete.UniquenessMeasurement", "UniquenessMeasurement"),
+    ("swarmauri.measurements.concrete.ZeroMeasurement", "ZeroMeasurement"),
+]
+
+# Lazy loading of measurements classes, storing them in variables
+for module_name, class_name in measurements_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded measurements classes to __all__
+__all__ = [class_name for _, class_name in measurements_files]
diff --git a/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py b/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py
index 5c619ecc8..716bd57c5 100644
--- a/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py
@@ -1,4 +1,16 @@
-from swarmauri.messages.concrete.HumanMessage import HumanMessage
-from swarmauri.messages.concrete.AgentMessage import AgentMessage
-from swarmauri.messages.concrete.FunctionMessage import FunctionMessage
-from swarmauri.messages.concrete.SystemMessage import SystemMessage
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of messages names (file names without the ".py" extension) and corresponding class names
+messages_files = [
+    ("swarmauri.messages.concrete.HumanMessage", "HumanMessage"),
+    ("swarmauri.messages.concrete.AgentMessage", "AgentMessage"),
+    ("swarmauri.messages.concrete.FunctionMessage", "FunctionMessage"),
+    ("swarmauri.messages.concrete.SystemMessage", "SystemMessage"),
+]
+
+# Lazy loading of messages classes, storing them in variables
+for module_name, class_name in messages_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded messages classes to __all__
+__all__ = [class_name for _, class_name in messages_files]
diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py b/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
index 45b1c7640..fb730763f 100644
--- a/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
@@ -1,37 +1,29 @@
-import importlib
+from swarmauri.utils._lazy_import import _lazy_import
 
-# Define a lazy loader function with a warning message if the module is not found
-def _lazy_import(module_name, module_description=None):
-    try:
-        return importlib.import_module(module_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-
-# List of parser names (file names without the ".py" extension)
-parser_files = [
-    "BeautifulSoupElementParser",
-    "BERTEmbeddingParser",
-    "CSVParser",
-    "EntityRecognitionParser",
-    "HTMLTagStripParser",
-    "KeywordExtractorParser",
-    "Md2HtmlParser",
-    "OpenAPISpecParser",
-    "PhoneNumberExtractorParser",
-    "PythonParser",
-    "RegExParser",
-    "TextBlobNounParser",
-    "TextBlobSentenceParser",
-    "URLExtractorParser",
-    "XMLParser",
+# List of parsers names (file names without the ".py" extension) and corresponding class names
+parsers_files = [
+    (
+        "swarmauri.parsers.concrete.BeautifulSoupElementParser",
+        "BeautifulSoupElementParser",
+    ),
+    ("swarmauri.parsers.concrete.CSVParser", "CSVParser"),
+    ("swarmauri.parsers.concrete.HTMLTagStripParser", "HTMLTagStripParser"),
+    ("swarmauri.parsers.concrete.KeywordExtractorParser", "KeywordExtractorParser"),
+    ("swarmauri.parsers.concrete.Md2HtmlParser", "Md2HtmlParser"),
+    ("swarmauri.parsers.concrete.OpenAPISpecParser", "OpenAPISpecParser"),
+    (
+        "swarmauri.parsers.concrete.PhoneNumberExtractorParser",
+        "PhoneNumberExtractorParser",
+    ),
+    ("swarmauri.parsers.concrete.PythonParser", "PythonParser"),
+    ("swarmauri.parsers.concrete.RegExParser", "RegExParser"),
+    ("swarmauri.parsers.concrete.URLExtractorParser", "URLExtractorParser"),
+    ("swarmauri.parsers.concrete.XMLParser", "XMLParser"),
 ]
 
-# Lazy loading of parser modules, storing them in variables
-for parser in parser_files:
-    globals()[parser] = _lazy_import(f"swarmauri.parsers.concrete.{parser}", parser)
+# Lazy loading of parsers classes, storing them in variables
+for module_name, class_name in parsers_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
 
-# Adding the lazy-loaded parser modules to __all__
-__all__ = parser_files
+# Adding the lazy-loaded parsers classes to __all__
+__all__ = [class_name for _, class_name in parsers_files]
diff --git a/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py b/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py
index 00d6b3cb9..3755b609f 100644
--- a/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py
@@ -1,4 +1,16 @@
-from swarmauri.prompts.concrete.Prompt import Prompt
-from swarmauri.prompts.concrete.PromptGenerator import PromptGenerator
-from swarmauri.prompts.concrete.PromptMatrix import PromptMatrix
-from swarmauri.prompts.concrete.PromptTemplate import PromptTemplate
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of prompts names (file names without the ".py" extension) and corresponding class names
+prompts_files = [
+    ("swarmauri.prompts.concrete.Prompt", "Prompt"),
+    ("swarmauri.prompts.concrete.PromptGenerator", "PromptGenerator"),
+    ("swarmauri.prompts.concrete.PromptMatrix", "PromptMatrix"),
+    ("swarmauri.prompts.concrete.PromptTemplate", "PromptTemplate"),
+]
+
+# Lazy loading of prompts classes, storing them in variables
+for module_name, class_name in prompts_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded prompts classes to __all__
+__all__ = [class_name for _, class_name in prompts_files]
diff --git a/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py b/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py
index c608d8c11..65044d64d 100644
--- a/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py
@@ -1,29 +1,37 @@
-import importlib
+from swarmauri.utils._lazy_import import _lazy_import
 
-# Define a lazy loader function with a warning message if the module is not found
-def _lazy_import(module_name, module_description=None):
-    try:
-        return importlib.import_module(module_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-
-# List of schema converter names (file names without the ".py" extension)
-schema_converter_files = [
-    "AnthropicSchemaConverter",
-    "CohereSchemaConverter",
-    "GeminiSchemaConverter",
-    "GroqSchemaConverter",
-    "MistralSchemaConverter",
-    "OpenAISchemaConverter",
-    "ShuttleAISchemaConverter",
+# List of schema_converters names (file names without the ".py" extension) and corresponding class names
+schema_converters_files = [
+    (
+        "swarmauri.schema_converters.concrete.AnthropicSchemaConverter",
+        "AnthropicSchemaConverter",
+    ),
+    (
+        "swarmauri.schema_converters.concrete.CohereSchemaConverter",
+        "CohereSchemaConverter",
+    ),
+    (
+        "swarmauri.schema_converters.concrete.GeminiSchemaConverter",
+        "GeminiSchemaConverter",
+    ),
+    ("swarmauri.schema_converters.concrete.GroqSchemaConverter", "GroqSchemaConverter"),
+    (
+        "swarmauri.schema_converters.concrete.MistralSchemaConverter",
+        "MistralSchemaConverter",
+    ),
+    (
+        "swarmauri.schema_converters.concrete.OpenAISchemaConverter",
+        "OpenAISchemaConverter",
+    ),
+    (
+        "swarmauri.schema_converters.concrete.ShuttleAISchemaConverter",
+        "ShuttleAISchemaConverter",
+    ),
 ]
 
-# Lazy loading of schema converters, storing them in variables
-for schema_converter in schema_converter_files:
-    globals()[schema_converter] = _lazy_import(f"swarmauri.schema_converters.concrete.{schema_converter}", schema_converter)
+# Lazy loading of schema_converters classes, storing them in variables
+for module_name, class_name in schema_converters_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
 
-# Adding the lazy-loaded schema converters to __all__
-__all__ = schema_converter_files
+# Adding the lazy-loaded schema_converters classes to __all__
+__all__ = [class_name for _, class_name in schema_converters_files]
diff --git a/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py b/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py
index bd32d1999..61f84eae6 100644
--- a/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py
@@ -1 +1,11 @@
-from swarmauri.swarms.concrete.SimpleSwarmFactory import SimpleSwarmFactory
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of swarms names (file names without the ".py" extension) and corresponding class names
+swarms_files = [("swarmauri.swarms.concrete.SimpleSwarmFactory", "SimpleSwarmFactory")]
+
+# Lazy loading of swarms classes, storing them in variables
+for module_name, class_name in swarms_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded swarms classes to __all__
+__all__ = [class_name for _, class_name in swarms_files]
diff --git a/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py b/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py
index 87127d6bf..a7311c7c9 100644
--- a/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py
@@ -1,21 +1,4 @@
-import importlib
-
-# Define a lazy loader function with a warning message if the module or class is not found
-def _lazy_import(module_name, class_name):
-    try:
-        # Import the module
-        module = importlib.import_module(module_name)
-        # Dynamically get the class from the module
-        return getattr(module, class_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-    except AttributeError:
-        # If class is not found, print a warning message
-        print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.")
-        return None
+from swarmauri.utils._lazy_import import _lazy_import
 
 # List of toolkit names (file names without the ".py" extension) and corresponding class names
 toolkit_files = [
diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py b/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py
deleted file mode 100644
index 23ce384df..000000000
--- a/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from swarmauri_core.typing import SubclassUnion
-from typing import List, Literal, Dict
-from pydantic import Field
-from swarmauri.tools.base.ToolBase import ToolBase
-from swarmauri.tools.concrete.Parameter import Parameter
-import re
-import math
-import nltk
-from nltk.tokenize import sent_tokenize
-
-# Download required NLTK data once during module load
-nltk.download("punkt", quiet=True)
-
-
-class SMOGIndexTool(ToolBase):
-    version: str = "0.1.dev2"
-    parameters: List[Parameter] = Field(
-        default_factory=lambda: [
-            Parameter(
-                name="text",
-                type="string",
-                description="The text to analyze for SMOG Index",
-                required=True,
-            )
-        ]
-    )
-    name: str = "SMOGIndexTool"
-    description: str = "Calculates the SMOG Index for the provided text."
-    type: Literal["SMOGIndexTool"] = "SMOGIndexTool"
-
-    def __call__(self, text: str) -> Dict[str, float]:
-        """
-        Calculates the SMOG Index for the provided text.
-
-        Parameters:
-            text (str): The text to analyze.
-
-        Returns:
-            float: The calculated SMOG Index.
-        """
-        return {"smog_index": self.calculate_smog_index(text)}
-
-    def calculate_smog_index(self, text: str) -> float:
-        """
-        Calculate the SMOG Index for a given text.
-
-        Parameters:
-            text (str): The text to analyze.
-
-        Returns:
-            float: The calculated SMOG Index.
-        """
-        sentences = self.count_sentences(text)
-        polysyllables = self.count_polysyllables(text)
-
-        if sentences == 0:
-            return 0.0  # Avoid division by zero
-
-        smog_index = 1.0430 * math.sqrt(polysyllables * (30 / sentences)) + 3.1291
-        return round(smog_index, 1)
-
-    def count_sentences(self, text: str) -> int:
-        """
-        Count the number of sentences in the text.
-
-        Parameters:
-            text (str): The text to analyze.
-
-        Returns:
-            int: The number of sentences in the text.
-        """
-        sentences = sent_tokenize(text)
-        return len(sentences)
-
-    def count_polysyllables(self, text: str) -> int:
-        """
-        Count the number of polysyllabic words (words with three or more syllables) in the text.
-
-        Parameters:
-            text (str): The text to analyze.
-
-        Returns:
-            int: The number of polysyllabic words in the text.
-        """
-        words = re.findall(r"\w+", text)
-        return len([word for word in words if self.count_syllables(word) >= 3])
-
-    def count_syllables(self, word: str) -> int:
-        """
-        Count the number of syllables in a given word.
-
-        Parameters:
-            word (str): The word to analyze.
-
-        Returns:
-            int: The number of syllables in the word.
-        """
-        word = word.lower()
-        vowels = "aeiouy"
-        count = 0
-        if word and word[0] in vowels:
-            count += 1
-        for index in range(1, len(word)):
-            if word[index] in vowels and word[index - 1] not in vowels:
-                count += 1
-        if word.endswith("e") and not word.endswith("le"):
-            count -= 1
-        if count == 0:
-            count = 1
-        return count
-
-
-SubclassUnion.update(baseclass=ToolBase, type_name="SMOGIndexTool", obj=SMOGIndexTool)
diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py b/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py
index f9d2a297e..5b7d61054 100644
--- a/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py
@@ -1,25 +1,12 @@
-import importlib
-
-# Define a lazy loader function with a warning message if the module or class is not found
-def _lazy_import(module_name, class_name):
-    try:
-        # Import the module
-        module = importlib.import_module(module_name)
-        # Dynamically get the class from the module
-        return getattr(module, class_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-    except AttributeError:
-        print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.")
-        return None
+from swarmauri.utils._lazy_import import _lazy_import
 
 # List of tool names (file names without the ".py" extension) and corresponding class names
 tool_files = [
     ("swarmauri.tools.concrete.AdditionTool", "AdditionTool"),
-    ("swarmauri.tools.concrete.AutomatedReadabilityIndexTool", "AutomatedReadabilityIndexTool"),
+    (
+        "swarmauri.tools.concrete.AutomatedReadabilityIndexTool",
+        "AutomatedReadabilityIndexTool",
+    ),
     ("swarmauri.tools.concrete.CalculatorTool", "CalculatorTool"),
     ("swarmauri.tools.concrete.CodeExtractorTool", "CodeExtractorTool"),
     ("swarmauri.tools.concrete.CodeInterpreterTool", "CodeInterpreterTool"),
diff --git a/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py b/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py
index 95900d024..1b6619352 100644
--- a/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py
@@ -1,5 +1,17 @@
-from swarmauri.tracing.concrete.CallableTracer import CallableTracer
-from swarmauri.tracing.concrete.ChainTracer import ChainTracer
-from swarmauri.tracing.concrete.SimpleTraceContext import SimpleTraceContext
-from swarmauri.tracing.concrete.TracedVariable import TracedVariable
-from swarmauri.tracing.concrete.VariableTracer import VariableTracer
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of tracing names (file names without the ".py" extension) and corresponding class names
+tracing_files = [
+    ("swarmauri.tracing.concrete.CallableTracer", "CallableTracer"),
+    ("swarmauri.tracing.concrete.ChainTracer", "ChainTracer"),
+    ("swarmauri.tracing.concrete.SimpleTraceContext", "SimpleTraceContext"),
+    ("swarmauri.tracing.concrete.TracedVariable", "TracedVariable"),
+    ("swarmauri.tracing.concrete.VariableTracer", "VariableTracer"),
+]
+
+# Lazy loading of tracing classes, storing them in variables
+for module_name, class_name in tracing_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded tracing classes to __all__
+__all__ = [class_name for _, class_name in tracing_files]
diff --git a/pkgs/swarmauri/swarmauri/utils/_lazy_import.py b/pkgs/swarmauri/swarmauri/utils/_lazy_import.py
new file mode 100644
index 000000000..a3d3bd34a
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/utils/_lazy_import.py
@@ -0,0 +1,22 @@
+import importlib
+
+
+# Define a lazy loader function with a warning message if the module or class is not found
+def _lazy_import(module_name, class_name):
+    try:
+        # Import the module
+        module = importlib.import_module(module_name)
+        # Dynamically get the class from the module
+        return getattr(module, class_name)
+    except ImportError:
+        # If module is not available, print a warning message
+        print(
+            f"Warning: The module '{module_name}' is not available. "
+            f"Please install the necessary dependencies to enable this functionality."
+        )
+        return None
+    except AttributeError:
+        print(
+            f"Warning: The class '{class_name}' was not found in module '{module_name}'."
+        )
+        return None
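+
+
+# Illustrative usage (module path taken from this patch; the constructor
+# argument is an assumption, not confirmed by this diff):
+#
+#     Vector = _lazy_import("swarmauri.vectors.concrete.Vector", "Vector")
+#     if Vector is not None:
+#         vec = Vector(value=[0.1, 0.2, 0.3])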
diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
index 08a36e26c..ceb2b245c 100644
--- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
@@ -1,26 +1,14 @@
-import importlib
+from swarmauri.utils._lazy_import import _lazy_import
 
-# Define a lazy loader function with a warning message if the module is not found
-def _lazy_import(module_name, module_description=None):
-    try:
-        return importlib.import_module(module_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-
-# List of vector store names (file names without the ".py" extension)
-vector_store_files = [
-    "Doc2VecVectorStore",
-    "MlmVectorStore",
-    "SqliteVectorStore",
-    "TfidfVectorStore",
+# List of vector_stores names (file names without the ".py" extension) and corresponding class names
+vector_stores_files = [
+    ("swarmauri.vector_stores.concrete.SqliteVectorStore", "SqliteVectorStore"),
+    ("swarmauri.vector_stores.concrete.TfidfVectorStore", "TfidfVectorStore"),
 ]
 
-# Lazy loading of vector stores, storing them in variables
-for vector_store in vector_store_files:
-    globals()[vector_store] = _lazy_import(f"swarmauri.vector_stores.concrete.{vector_store}", vector_store)
+# Lazy loading of vector_stores classes, storing them in variables
+for module_name, class_name in vector_stores_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
 
-# Adding the lazy-loaded vector stores to __all__
-__all__ = vector_store_files
+# Adding the lazy-loaded vector_stores classes to __all__
+__all__ = [class_name for _, class_name in vector_stores_files]
diff --git a/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py b/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py
index 16f348f20..7283bc0a9 100644
--- a/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py
@@ -1,4 +1,14 @@
-# -*- coding: utf-8 -*-
+from swarmauri.utils._lazy_import import _lazy_import
 
-from swarmauri.vectors.concrete.Vector import Vector
-from swarmauri.vectors.concrete.VectorProductMixin import VectorProductMixin
+# List of vectors names (file names without the ".py" extension) and corresponding class names
+vectors_files = [
+    ("swarmauri.vectors.concrete.Vector", "Vector"),
+    ("swarmauri.vectors.concrete.VectorProductMixin", "VectorProductMixin"),
+]
+
+# Lazy loading of vectors classes, storing them in variables
+for module_name, class_name in vectors_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded vectors classes to __all__
+__all__ = [class_name for _, class_name in vectors_files]
diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3
new file mode 100644
index 000000000..1fabb16fa
Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3 differ
diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3
new file mode 100644
index 000000000..f8a7a0a8b
Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3 differ
diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3
new file mode 100644
index 000000000..194714d16
Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3 differ
diff --git a/pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py
similarity index 95%
rename from pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py
rename to pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py
index 706d03b61..5fbd06c8a 100644
--- a/pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py
+++ b/pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py
@@ -1,7 +1,7 @@
 import pytest
 import os
 from dotenv import load_dotenv
-from swarmauri.llms.concrete.BlackForestImgGenModel import (
+from swarmauri.image_gens.concrete.BlackForestImgGenModel import (
     BlackForestImgGenModel,
 )
 
@@ -30,7 +30,7 @@ def get_allowed_models():
 @timeout(5)
 @pytest.mark.unit
 def test_model_resource(blackforest_imggen_model):
-    assert blackforest_imggen_model.resource == "LLM"
+    assert blackforest_imggen_model.resource == "ImageGen"
 
 
 @timeout(5)
diff --git a/pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py
similarity index 97%
rename from pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py
rename to pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py
index 98b3b7047..5492ff573 100644
--- a/pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py
+++ b/pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py
@@ -1,6 +1,6 @@
 import pytest
 import os
-from swarmauri.llms.concrete.DeepInfraImgGenModel import DeepInfraImgGenModel
+from swarmauri.image_gens.concrete.DeepInfraImgGenModel import DeepInfraImgGenModel
 from dotenv import load_dotenv
 from swarmauri.utils.timeout_wrapper import timeout
 
diff --git a/pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py
similarity index 97%
rename from pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py
rename to pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py
index bf5b6d83f..858414f7f 100644
--- a/pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py
+++ b/pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py
@@ -1,6 +1,6 @@
 import pytest
 import os
-from swarmauri.llms.concrete.FalAIImgGenModel import FalAIImgGenModel
+from swarmauri.image_gens.concrete.FalAIImgGenModel import FalAIImgGenModel
 from dotenv import load_dotenv
 from swarmauri.utils.timeout_wrapper import timeout
diff --git a/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py
new file mode 100644
index 000000000..3772b4dce
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py
@@ -0,0 +1,118 @@
+import pytest
+import os
+from swarmauri.image_gens.concrete.HyperbolicImgGenModel import HyperbolicImgGenModel
+from dotenv import load_dotenv
+
+from swarmauri.utils.timeout_wrapper import timeout
+
+load_dotenv()
+
+API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def hyperbolic_imggen_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    model = HyperbolicImgGenModel(api_key=API_KEY)
+    return model
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    model = HyperbolicImgGenModel(api_key=API_KEY)
+    return model.allowed_models
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_resource(hyperbolic_imggen_model):
+    assert hyperbolic_imggen_model.resource == "ImageGen"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_type(hyperbolic_imggen_model):
+    assert hyperbolic_imggen_model.type == "HyperbolicImgGenModel"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_serialization(hyperbolic_imggen_model):
+    assert (
+        hyperbolic_imggen_model.id
+        == HyperbolicImgGenModel.model_validate_json(
+            hyperbolic_imggen_model.model_dump_json()
+        ).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_default_name(hyperbolic_imggen_model):
+    assert hyperbolic_imggen_model.name == "SDXL1.0-base"
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_generate_image_base64(hyperbolic_imggen_model, model_name):
+    model = hyperbolic_imggen_model
+    model.name = model_name
+
+    prompt = "A cute cat playing with a ball of yarn"
+
+    image_base64 = model.generate_image_base64(prompt=prompt)
+
+    assert isinstance(image_base64, str)
+    assert len(image_base64) > 0
+
+
+@timeout(5)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_agenerate_image_base64(hyperbolic_imggen_model, model_name):
+    model = hyperbolic_imggen_model
+    model.name = model_name
+
+    prompt = "A serene landscape with mountains and a lake"
+
+    image_base64 = await model.agenerate_image_base64(prompt=prompt)
+
+    assert isinstance(image_base64, str)
+    assert len(image_base64) > 0
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_batch_base64(hyperbolic_imggen_model):
+    prompts = [
+        "A futuristic city skyline",
+        "A tropical beach at sunset",
+    ]
+
+    result_base64_images = hyperbolic_imggen_model.batch_base64(prompts=prompts)
+
+    assert len(result_base64_images) == len(prompts)
+    for image_base64 in result_base64_images:
+        assert isinstance(image_base64, str)
+        assert len(image_base64) > 0
+
+
+@timeout(5)
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_abatch_base64(hyperbolic_imggen_model):
+    prompts = [
+        "An abstract painting with vibrant colors",
+        "A snowy mountain peak",
+    ]
+
+    result_base64_images = await hyperbolic_imggen_model.abatch_base64(prompts=prompts)
+
+    assert len(result_base64_images) == len(prompts)
+    for image_base64 in result_base64_images:
+        assert isinstance(image_base64, str)
+        assert len(image_base64) > 0
diff --git a/pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py b/pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_test.py
similarity index 97%
rename from pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py
rename to pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_test.py
index 7780ba042..b22b9e6ea 100644
--- a/pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py
+++ b/pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_test.py
@@ -1,7 +1,7 @@
 import pytest
 import os
 from dotenv import load_dotenv
-from swarmauri.llms.concrete.OpenAIImgGenModel import OpenAIImgGenModel
+from swarmauri.image_gens.concrete.OpenAIImgGenModel import OpenAIImgGenModel
 from swarmauri.utils.timeout_wrapper import timeout
 
 load_dotenv()
diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py
new file mode 100644
index 000000000..9b81c84ef
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py
@@ -0,0 +1,141 @@
+import logging
+import pytest
+import os
+
+from swarmauri.llms.concrete.HyperbolicAudioTTS import HyperbolicAudioTTS as LLM
+from dotenv import load_dotenv
+from swarmauri.utils.timeout_wrapper import timeout
+from pathlib import Path
+
+load_dotenv()
+
+API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+
+
+# Resolve the tests root directory (two levels up from this file)
+root_dir = Path(__file__).resolve().parents[2]
+
+# Construct file paths dynamically
+file_path = os.path.join(root_dir, "static", "hyperbolic_test_tts.mp3")
+file_path2 = os.path.join(root_dir, "static", "hyperbolic_test2.mp3")
+file_path3 = os.path.join(root_dir, "static", "hyperbolic_test3.mp3")
+
+
+@pytest.fixture(scope="module")
+def hyperbolic_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    llm = LLM(api_key=API_KEY)
+    return llm
+
+
+@timeout(5)
+def get_allowed_languages():
+    if not API_KEY:
+        return []
+    llm = LLM(api_key=API_KEY)
+    return llm.allowed_languages
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_resource(hyperbolic_model):
+    assert hyperbolic_model.resource == "LLM"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_type(hyperbolic_model):
+    assert hyperbolic_model.type == "HyperbolicAudioTTS"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_serialization(hyperbolic_model):
+    assert (
+        hyperbolic_model.id
+        == LLM.model_validate_json(hyperbolic_model.model_dump_json()).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_default_speed(hyperbolic_model):
+    assert hyperbolic_model.speed == 1.0
+
+
+@timeout(5)
+@pytest.mark.parametrize("language", get_allowed_languages())
+@pytest.mark.unit
+def test_predict(hyperbolic_model, language):
+    """
+    Test prediction with different languages
+    Note: Adjust the text according to the language if needed
+    """
+    # Set the language for the test
+    hyperbolic_model.language = language
+
+    # Select an appropriate text based on the language
+    texts = {
+        "EN": "Hello, this is a test of text-to-speech output in English.",
+        "ES": "Hola, esta es una prueba de salida de texto a voz en español.",
+        "FR": "Bonjour, ceci est un test de sortie de texte en français.",
+        "ZH": "这是一个中文语音转换测试。",
+        "JP": "これは日本語の音声合成テストです。",
+        "KR": "이것은 한국어 음성 합성 테스트입니다.",
+    }
+
+    text = texts.get(
+        language, "Hello, this is a generic test of text-to-speech output."
+    )
+
+    audio_file_path = hyperbolic_model.predict(text=text, audio_path=file_path)
+
+    logging.info(audio_file_path)
+
+    assert isinstance(audio_file_path, str)
+    assert os.path.exists(audio_file_path)
+    assert os.path.getsize(audio_file_path) > 0
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_batch(hyperbolic_model):
+    """
+    Test batch processing of multiple texts
+    """
+    text_path_dict = {
+        "Hello": file_path,
+        "Hi there": file_path2,
+        "Good morning": file_path3,
+    }
+
+    results = hyperbolic_model.batch(text_path_dict=text_path_dict)
+    assert len(results) == len(text_path_dict)
+
+    for result in results:
+        assert isinstance(result, str)
+        assert os.path.exists(result)
+        assert os.path.getsize(result) > 0
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.unit
+async def test_abatch(hyperbolic_model):
+    """
+    Test asynchronous batch processing of multiple texts
+    """
+    text_path_dict = {
+        "Hello": file_path,
+        "Hi there": file_path2,
+        "Good morning": file_path3,
+    }
+
+    results = await hyperbolic_model.abatch(text_path_dict=text_path_dict)
+    assert len(results) == len(text_path_dict)
+
+    for result in results:
+        assert isinstance(result, str)
+        assert os.path.exists(result)
+        assert os.path.getsize(result) > 0
diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py
new file mode 100644
index 000000000..6a4bd9652
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py
@@ -0,0 +1,218 @@
+import logging
+import pytest
+import os
+
+from swarmauri.llms.concrete.HyperbolicModel import HyperbolicModel as LLM
+from swarmauri.conversations.concrete.Conversation import Conversation
+
+from swarmauri.messages.concrete.HumanMessage import HumanMessage
+from swarmauri.messages.concrete.SystemMessage import SystemMessage
+
+from swarmauri.messages.concrete.AgentMessage import UsageData
+
+from swarmauri.utils.timeout_wrapper import timeout
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def hyperbolic_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    llm = LLM(api_key=API_KEY)
+    return llm
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    llm = LLM(api_key=API_KEY)
+    return llm.allowed_models
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_resource(hyperbolic_model):
+    assert hyperbolic_model.resource == "LLM"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_type(hyperbolic_model):
+    assert hyperbolic_model.type == "HyperbolicModel"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_serialization(hyperbolic_model):
+    assert (
+        hyperbolic_model.id
+        == LLM.model_validate_json(hyperbolic_model.model_dump_json()).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_default_name(hyperbolic_model):
+    assert hyperbolic_model.name == "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_no_system_context(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Hello"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    model.predict(conversation=conversation)
+    prediction = conversation.get_last().content
+    usage_data = conversation.get_last().usage
+
+    logging.info(usage_data)
+
+    assert type(prediction) is str
+    assert isinstance(usage_data, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_preamble_system_context(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+    conversation = Conversation()
+
+    system_context = 'You only respond with the following phrase, "Jeff"'
+    system_message = SystemMessage(content=system_context)
+    conversation.add_message(system_message)
+
+    input_data = "Hi"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    model.predict(conversation=conversation)
+    prediction = conversation.get_last().content
+    usage_data = conversation.get_last().usage
+
+    logging.info(usage_data)
+
+    assert type(prediction) is str
+    assert "Jeff" in prediction
+    assert isinstance(usage_data, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_stream(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Write a short story about a cat."
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    collected_tokens = []
+    for token in model.stream(conversation=conversation):
+        logging.info(token)
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+    assert isinstance(conversation.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_apredict(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Hello"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    result = await model.apredict(conversation=conversation)
+    prediction = result.get_last().content
+    assert isinstance(prediction, str)
+    assert isinstance(conversation.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_astream(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Write a short story about a dog."
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    collected_tokens = []
+    async for token in model.astream(conversation=conversation):
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+    assert isinstance(conversation.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_batch(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+
+    conversations = []
+    for prompt in ["Hello", "Hi there", "Good morning"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=prompt))
+        conversations.append(conv)
+
+    results = model.batch(conversations=conversations)
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
+        assert isinstance(result.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+async def test_abatch(hyperbolic_model, model_name):
+    model = hyperbolic_model
+    model.name = model_name
+
+    conversations = []
+    for prompt in ["Hello", "Hi there", "Good morning"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=prompt))
+        conversations.append(conv)
+
+    results = await model.abatch(conversations=conversations)
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
+        assert isinstance(result.get_last().usage, UsageData)
diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py
new file mode 100644
index 000000000..495341aae
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py
@@ -0,0 +1,158 @@
+import pytest
+import os
+from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel
+from swarmauri.conversations.concrete.Conversation import Conversation
+from swarmauri.messages.concrete.HumanMessage import HumanMessage
+from dotenv import load_dotenv
+from swarmauri.utils.timeout_wrapper import timeout
+
+load_dotenv()
+
+API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def hyperbolic_vision_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    model = HyperbolicVisionModel(api_key=API_KEY)
+    return model
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    model = HyperbolicVisionModel(api_key=API_KEY)
+    return model.allowed_models
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_resource(hyperbolic_vision_model):
+    assert hyperbolic_vision_model.resource == "LLM"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_ubc_type(hyperbolic_vision_model):
+    assert hyperbolic_vision_model.type == "HyperbolicVisionModel"
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_serialization(hyperbolic_vision_model):
+    assert (
+        hyperbolic_vision_model.id
+        == HyperbolicVisionModel.model_validate_json(
+            hyperbolic_vision_model.model_dump_json()
+        ).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_default_model_name(hyperbolic_vision_model):
+    assert hyperbolic_vision_model.name == "Qwen/Qwen2-VL-72B-Instruct"
+
+
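+# Illustrative payload (assumed example values): the helper below builds an
+# OpenAI-style multimodal message whose content parts look like
+#     [{"type": "text", "text": "Who painted this artwork?"},
+#      {"type": "image_url", "image_url": {"url": "https://..."}}]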
+def create_test_conversation(image_url, prompt):
+    conversation = Conversation()
+    conversation.add_message(
+        HumanMessage(
+            content=[
+                {"type": "text", "text": prompt},
+                {"type": "image_url", "image_url": {"url": image_url}},
+            ]
+        )
+    )
+    return conversation
+
+
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@timeout(5)
+@pytest.mark.unit
+def test_predict(hyperbolic_vision_model, model_name):
+    model = hyperbolic_vision_model
+    model.name = model_name
+
+    image_url = "https://llava-vl.github.io/static/images/monalisa.jpg"
+    prompt = "Who painted this artwork?"
+    conversation = create_test_conversation(image_url, prompt)
+
+    result = model.predict(conversation)
+
+    assert result.history[-1].content is not None
+    assert isinstance(result.history[-1].content, str)
+    assert len(result.history[-1].content) > 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@timeout(5)
+@pytest.mark.unit
+async def test_apredict(hyperbolic_vision_model, model_name):
+    model = hyperbolic_vision_model
+    model.name = model_name
+
+    image_url = "https://llava-vl.github.io/static/images/monalisa.jpg"
+    prompt = "Describe the woman in the painting."
+    conversation = create_test_conversation(image_url, prompt)
+
+    result = await model.apredict(conversation)
+
+    assert result.history[-1].content is not None
+    assert isinstance(result.history[-1].content, str)
+    assert len(result.history[-1].content) > 0
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_batch(hyperbolic_vision_model):
+    image_urls = [
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+    ]
+    prompts = [
+        "Who painted this artwork?",
+        "Describe the woman in the painting.",
+    ]
+
+    conversations = [
+        create_test_conversation(image_url, prompt)
+        for image_url, prompt in zip(image_urls, prompts)
+    ]
+
+    results = hyperbolic_vision_model.batch(conversations)
+
+    assert len(results) == len(image_urls)
+    for result in results:
+        assert result.history[-1].content is not None
+        assert isinstance(result.history[-1].content, str)
+        assert len(result.history[-1].content) > 0
+
+
+@pytest.mark.asyncio
+@timeout(5)
+@pytest.mark.unit
+async def test_abatch(hyperbolic_vision_model):
+    image_urls = [
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+    ]
+    prompts = [
+        "Who painted this artwork?",
+        "Describe the woman in the painting.",
+    ]
+
+    conversations = [
+        create_test_conversation(image_url, prompt)
+        for image_url, prompt in zip(image_urls, prompts)
+    ]
+
+    results = await hyperbolic_vision_model.abatch(conversations)
+
+    assert len(results) == len(image_urls)
+    for result in results:
+        assert result.history[-1].content is not None
+        assert isinstance(result.history[-1].content, str)
+        assert len(result.history[-1].content) > 0